1919from .services .events import EventServiceBase
2020from .services .model_manager_initializer import get_model_manager
2121from .services .restoration_services import RestorationServices
22- from .services .graph import EdgeConnection , GraphExecutionState
22+ from .services .graph import Edge , EdgeConnection , GraphExecutionState
2323from .services .image_storage import DiskImageStorage
2424from .services .invocation_queue import MemoryInvocationQueue
2525from .services .invocation_services import InvocationServices
@@ -77,7 +77,7 @@ def exit(*args, **kwargs):
7777
7878def generate_matching_edges (
7979 a : BaseInvocation , b : BaseInvocation
80- ) -> list [tuple [ EdgeConnection , EdgeConnection ] ]:
80+ ) -> list [Edge ]:
8181 """Generates all possible edges between two invocations"""
8282 atype = type (a )
8383 btype = type (b )
@@ -94,9 +94,9 @@ def generate_matching_edges(
9494 matching_fields = matching_fields .difference (invalid_fields )
9595
9696 edges = [
97- (
98- EdgeConnection (node_id = a .id , field = field ),
99- EdgeConnection (node_id = b .id , field = field ),
97+ Edge (
98+ source = EdgeConnection (node_id = a .id , field = field ),
99+ destination = EdgeConnection (node_id = b .id , field = field )
100100 )
101101 for field in matching_fields
102102 ]
@@ -111,16 +111,15 @@ class SessionError(Exception):
111111def invoke_all (context : CliContext ):
112112 """Runs all invocations in the specified session"""
113113 context .invoker .invoke (context .session , invoke_all = True )
114- while not context .session .is_complete ():
114+ while not context .get_session () .is_complete ():
115115 # Wait some time
116- session = context .get_session ()
117116 time .sleep (0.1 )
118117
119118 # Print any errors
120119 if context .session .has_error ():
121120 for n in context .session .errors :
122121 print (
123- f"Error in node { n } (source node { context .session .prepared_source_mapping [n ]} ): { session .errors [n ]} "
122+ f"Error in node { n } (source node { context .session .prepared_source_mapping [n ]} ): { context . session .errors [n ]} "
124123 )
125124
126125 raise SessionError ()
@@ -203,7 +202,7 @@ def invoke_cli():
203202 continue
204203
205204 # Pipe previous command output (if there was a previous command)
206- edges = []
205+ edges : list [ Edge ] = list ()
207206 if len (history ) > 0 or current_id != start_id :
208207 from_id = (
209208 history [0 ] if current_id == start_id else str (current_id - 1 )
@@ -225,19 +224,19 @@ def invoke_cli():
225224 matching_edges = generate_matching_edges (
226225 link_node , command .command
227226 )
228- matching_destinations = [e [ 1 ] for e in matching_edges ]
229- edges = [e for e in edges if e [ 1 ] not in matching_destinations ]
227+ matching_destinations = [e . destination for e in matching_edges ]
228+ edges = [e for e in edges if e . destination not in matching_destinations ]
230229 edges .extend (matching_edges )
231230
232231 if "link" in args and args ["link" ]:
233232 for link in args ["link" ]:
234- edges = [e for e in edges if e [ 1 ]. node_id != command .command .id and e [ 1 ] .field != link [2 ]]
233+ edges = [e for e in edges if e . destination . node_id != command .command .id and e . destination .field != link [2 ]]
235234 edges .append (
236- (
237- EdgeConnection (node_id = link [1 ], field = link [0 ]),
238- EdgeConnection (
235+ Edge (
236+ source = EdgeConnection (node_id = link [1 ], field = link [0 ]),
237+ destination = EdgeConnection (
239238 node_id = command .command .id , field = link [2 ]
240- ),
239+ )
241240 )
242241 )
243242
0 commit comments