Skip to content

Commit 9738b0f

Browse files
[nodes] Add Edge data type (invoke-ai#2958)
Adds an `Edge` data type, replacing the current tuple used for edges.
2 parents 6eeaf8d + 3021c78 commit 9738b0f

6 files changed

Lines changed: 136 additions & 123 deletions

File tree

invokeai/app/api/routers/sessions.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ...invocations import *
1111
from ...invocations.baseinvocation import BaseInvocation
1212
from ...services.graph import (
13+
Edge,
1314
EdgeConnection,
1415
Graph,
1516
GraphExecutionState,
@@ -92,7 +93,7 @@ async def get_session(
9293
async def add_node(
9394
session_id: str = Path(description="The id of the session"),
9495
node: Annotated[
95-
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
96+
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
9697
] = Body(description="The node to add"),
9798
) -> str:
9899
"""Adds a node to the graph"""
@@ -125,7 +126,7 @@ async def update_node(
125126
session_id: str = Path(description="The id of the session"),
126127
node_path: str = Path(description="The path to the node in the graph"),
127128
node: Annotated[
128-
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
129+
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
129130
] = Body(description="The new node"),
130131
) -> GraphExecutionState:
131132
"""Updates a node in the graph and removes all linked edges"""
@@ -186,7 +187,7 @@ async def delete_node(
186187
)
187188
async def add_edge(
188189
session_id: str = Path(description="The id of the session"),
189-
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
190+
edge: Edge = Body(description="The edge to add"),
190191
) -> GraphExecutionState:
191192
"""Adds an edge to the graph"""
192193
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -228,9 +229,9 @@ async def delete_edge(
228229
return Response(status_code=404)
229230

230231
try:
231-
edge = (
232-
EdgeConnection(node_id=from_node_id, field=from_field),
233-
EdgeConnection(node_id=to_node_id, field=to_field),
232+
edge = Edge(
233+
source=EdgeConnection(node_id=from_node_id, field=from_field),
234+
destination=EdgeConnection(node_id=to_node_id, field=to_field)
234235
)
235236
session.delete_edge(edge)
236237
ApiDependencies.invoker.services.graph_execution_manager.set(

invokeai/app/cli_app.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .services.events import EventServiceBase
2020
from .services.model_manager_initializer import get_model_manager
2121
from .services.restoration_services import RestorationServices
22-
from .services.graph import EdgeConnection, GraphExecutionState
22+
from .services.graph import Edge, EdgeConnection, GraphExecutionState
2323
from .services.image_storage import DiskImageStorage
2424
from .services.invocation_queue import MemoryInvocationQueue
2525
from .services.invocation_services import InvocationServices
@@ -77,7 +77,7 @@ def exit(*args, **kwargs):
7777

7878
def generate_matching_edges(
7979
a: BaseInvocation, b: BaseInvocation
80-
) -> list[tuple[EdgeConnection, EdgeConnection]]:
80+
) -> list[Edge]:
8181
"""Generates all possible edges between two invocations"""
8282
atype = type(a)
8383
btype = type(b)
@@ -94,9 +94,9 @@ def generate_matching_edges(
9494
matching_fields = matching_fields.difference(invalid_fields)
9595

9696
edges = [
97-
(
98-
EdgeConnection(node_id=a.id, field=field),
99-
EdgeConnection(node_id=b.id, field=field),
97+
Edge(
98+
source=EdgeConnection(node_id=a.id, field=field),
99+
destination=EdgeConnection(node_id=b.id, field=field)
100100
)
101101
for field in matching_fields
102102
]
@@ -111,16 +111,15 @@ class SessionError(Exception):
111111
def invoke_all(context: CliContext):
112112
"""Runs all invocations in the specified session"""
113113
context.invoker.invoke(context.session, invoke_all=True)
114-
while not context.session.is_complete():
114+
while not context.get_session().is_complete():
115115
# Wait some time
116-
session = context.get_session()
117116
time.sleep(0.1)
118117

119118
# Print any errors
120119
if context.session.has_error():
121120
for n in context.session.errors:
122121
print(
123-
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
122+
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
124123
)
125124

126125
raise SessionError()
@@ -203,7 +202,7 @@ def invoke_cli():
203202
continue
204203

205204
# Pipe previous command output (if there was a previous command)
206-
edges = []
205+
edges: list[Edge] = list()
207206
if len(history) > 0 or current_id != start_id:
208207
from_id = (
209208
history[0] if current_id == start_id else str(current_id - 1)
@@ -225,19 +224,19 @@ def invoke_cli():
225224
matching_edges = generate_matching_edges(
226225
link_node, command.command
227226
)
228-
matching_destinations = [e[1] for e in matching_edges]
229-
edges = [e for e in edges if e[1] not in matching_destinations]
227+
matching_destinations = [e.destination for e in matching_edges]
228+
edges = [e for e in edges if e.destination not in matching_destinations]
230229
edges.extend(matching_edges)
231230

232231
if "link" in args and args["link"]:
233232
for link in args["link"]:
234-
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
233+
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
235234
edges.append(
236-
(
237-
EdgeConnection(node_id=link[1], field=link[0]),
238-
EdgeConnection(
235+
Edge(
236+
source=EdgeConnection(node_id=link[1], field=link[0]),
237+
destination=EdgeConnection(
239238
node_id=command.command.id, field=link[2]
240-
),
239+
)
241240
)
242241
)
243242

0 commit comments

Comments
 (0)