2929writes to the system log is stored in InvocationServices.performance_statistics.
3030"""
3131
32+ import psutil
3233import time
3334from abc import ABC , abstractmethod
3435from contextlib import AbstractContextManager
4243from ..invocations .baseinvocation import BaseInvocation
4344from .graph import GraphExecutionState
4445from .item_storage import ItemStorageABC
46+ from .model_manager_service import ModelManagerService
47+ from invokeai .backend .model_management .model_cache import CacheStats
48+
49+ # size of GIG in bytes
50+ GIG = 1073741824
4551
4652
4753class InvocationStatsServiceBase (ABC ):
@@ -89,6 +95,8 @@ def update_invocation_stats(
8995 invocation_type : str ,
9096 time_used : float ,
9197 vram_used : float ,
98+ ram_used : float ,
99+ ram_changed : float ,
92100 ):
93101 """
94102 Add timing information on execution of a node. Usually
@@ -97,6 +105,8 @@ def update_invocation_stats(
97105 :param invocation_type: String literal type of the node
98106 :param time_used: Time used by node's execution (sec)
99107 :param vram_used: Maximum VRAM used during execution (GB)
108+ :param ram_used: RAM currently in use by the process (GB)
109+ :param ram_changed: Change in RAM usage over course of the run (GB)
100110 """
101111 pass
102112
@@ -115,6 +125,9 @@ class NodeStats:
115125 calls : int = 0
116126 time_used : float = 0.0 # seconds
117127 max_vram : float = 0.0 # GB
128+ cache_hits : int = 0
129+ cache_misses : int = 0
130+ cache_high_watermark : int = 0
118131
119132
120133@dataclass
@@ -133,31 +146,62 @@ def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"
133146 self .graph_execution_manager = graph_execution_manager
134147 # {graph_id => NodeLog}
135148 self ._stats : Dict [str , NodeLog ] = {}
149+ self ._cache_stats : Dict [str , CacheStats ] = {}
150+ self .ram_used : float = 0.0
151+ self .ram_changed : float = 0.0
136152
class StatsContext:
    """
    Context manager that measures a single invocation and reports the
    results to the owning statistics collector.

    On entry it records the wall-clock start time, resets the CUDA peak
    memory counters (when CUDA is available), samples the process's
    resident set size, and asks the model manager to write its cache
    statistics into the collector's per-graph ``CacheStats`` object.

    On exit it reports the RAM in use and the RAM delta since entry via
    ``collector.update_mem_stats()``, then the elapsed time and peak VRAM
    via ``collector.update_invocation_stats()``. Byte counts are divided
    by ``GIG`` (1073741824) before being reported.

    NOTE(review): the collector is annotated as the abstract base, but this
    class relies on ``_cache_stats`` and ``update_mem_stats``, which are
    only visible on the concrete service -- confirm the intended type.
    """

    # Declarations only: every attribute is assigned in __init__, so no
    # class-level defaults are needed (the previous ``= None`` / ``int``
    # defaults contradicted the annotations and the values actually stored).
    invocation: BaseInvocation
    collector: "InvocationStatsServiceBase"
    graph_id: str
    start_time: float  # wall-clock timestamp returned by time.time()
    ram_used: int  # process RSS in bytes, sampled on entry
    model_manager: ModelManagerService

    def __init__(
        self,
        invocation: BaseInvocation,
        graph_id: str,
        model_manager: ModelManagerService,
        collector: "InvocationStatsServiceBase",
    ):
        """
        Initialize statistics for this run.

        :param invocation: The node whose execution is being measured
        :param graph_id: ID of the graph execution state the node belongs to
        :param model_manager: Source of model cache statistics; when None,
            cache statistics are not collected
        :param collector: Statistics service that receives the measurements
        """
        self.invocation = invocation
        self.collector = collector
        self.graph_id = graph_id
        self.start_time = 0.0
        self.ram_used = 0
        self.model_manager = model_manager

    def __enter__(self):
        """
        Start measuring.

        :return: This context object, so ``with ... as ctx`` binds it.
        :raises KeyError: If the collector holds no cache-stats entry for
            ``graph_id`` while a model manager is set.
        """
        self.start_time = time.time()
        if torch.cuda.is_available():
            # Peak VRAM reported on exit should cover this invocation only.
            torch.cuda.reset_peak_memory_stats()
        self.ram_used = psutil.Process().memory_info().rss
        if self.model_manager is not None:
            self.model_manager.collect_cache_stats(self.collector._cache_stats[self.graph_id])
        return self

    def __exit__(self, *args):
        """
        Called on exit from the context.

        Reports the measurements to the collector. Returns None, so an
        exception raised inside the ``with`` body is not suppressed.
        """
        # Read the clock first so the RAM bookkeeping below is not counted
        # as part of the node's execution time.
        elapsed = time.time() - self.start_time
        rss_now = psutil.Process().memory_info().rss
        peak_vram = torch.cuda.max_memory_allocated() / GIG if torch.cuda.is_available() else 0.0

        self.collector.update_mem_stats(
            ram_used=rss_now / GIG,
            ram_changed=(rss_now - self.ram_used) / GIG,
        )
        self.collector.update_invocation_stats(
            graph_id=self.graph_id,
            invocation_type=self.invocation.type,
            time_used=elapsed,
            vram_used=peak_vram,
        )
156199
157200 def collect_stats (
158201 self ,
159202 invocation : BaseInvocation ,
160203 graph_execution_state_id : str ,
204+ model_manager : ModelManagerService ,
161205 ) -> StatsContext :
162206 """
163207 Return a context object that will capture the statistics.
@@ -166,7 +210,8 @@ def collect_stats(
166210 """
167211 if not self ._stats .get (graph_execution_state_id ): # first time we're seeing this
168212 self ._stats [graph_execution_state_id ] = NodeLog ()
169- return self .StatsContext (invocation , graph_execution_state_id , self )
213+ self ._cache_stats [graph_execution_state_id ] = CacheStats ()
214+ return self .StatsContext (invocation , graph_execution_state_id , model_manager , self )
170215
171216 def reset_all_stats (self ):
172217 """Zero all statistics"""
@@ -179,13 +224,36 @@ def reset_stats(self, graph_execution_id: str):
179224 except KeyError :
180225 logger .warning (f"Attempted to clear statistics for unknown graph { graph_execution_id } " )
181226
182- def update_invocation_stats (self , graph_id : str , invocation_type : str , time_used : float , vram_used : float ):
def update_mem_stats(
    self,
    ram_used: float,
    ram_changed: float,
):
    """
    Record the most recent RAM measurements on the collector.

    The values are stored exactly as given; each call overwrites the
    previous pair.

    :param ram_used: RAM currently in use by the process.
    :param ram_changed: Difference in RAM usage reported by the caller.
    """
    self.ram_used, self.ram_changed = ram_used, ram_changed
240+
241+ def update_invocation_stats (
242+ self ,
243+ graph_id : str ,
244+ invocation_type : str ,
245+ time_used : float ,
246+ vram_used : float ,
247+ ):
183248 """
184249 Add timing information on execution of a node. Usually
185250 used internally.
186251 :param graph_id: ID of the graph that is currently executing
187252 :param invocation_type: String literal type of the node
188- :param time_used: Floating point seconds used by node's exection
253+ :param time_used: Time used by node's execution (sec)
254+ :param vram_used: Maximum VRAM used during execution (GB)
255+ :param ram_used: RAM currently in use by the process (GB)
256+ :param ram_changed: Change in RAM usage over course of the run (GB)
189257 """
190258 if not self ._stats [graph_id ].nodes .get (invocation_type ):
191259 self ._stats [graph_id ].nodes [invocation_type ] = NodeStats ()
@@ -197,7 +265,7 @@ def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used
197265 def log_stats (self ):
198266 """
199267 Send the statistics to the system logger at the info level.
200- Stats will only be printed if when the execution of the graph
268+ Stats will only be printed when the execution of the graph
201269 is complete.
202270 """
203271 completed = set ()
@@ -208,16 +276,30 @@ def log_stats(self):
208276
209277 total_time = 0
210278 logger .info (f"Graph stats: { graph_id } " )
211- logger .info (" Node Calls Seconds VRAM Used" )
279+ logger .info (f" { ' Node' :>30 } { ' Calls' :>7 } { ' Seconds' :>9 } { ' VRAM Used' :>10 } " )
212280 for node_type , stats in self ._stats [graph_id ].nodes .items ():
213- logger .info (f"{ node_type :<20 } { stats .calls :>5 } { stats .time_used :7.3f} s { stats .max_vram :4.2f } G" )
281+ logger .info (f"{ node_type :>30 } { stats .calls :>4 } { stats .time_used :7.3f} s { stats .max_vram :4.3f } G" )
214282 total_time += stats .time_used
215283
284+ cache_stats = self ._cache_stats [graph_id ]
285+ hwm = cache_stats .high_watermark / GIG
286+ tot = cache_stats .cache_size / GIG
287+ loaded = sum ([v for v in cache_stats .loaded_model_sizes .values ()]) / GIG
288+
216289 logger .info (f"TOTAL GRAPH EXECUTION TIME: { total_time :7.3f} s" )
290+ logger .info ("RAM used by InvokeAI process: " + "%4.2fG" % self .ram_used + f" ({ self .ram_changed :+5.3f} G)" )
291+ logger .info (f"RAM used to load models: { loaded :4.2f} G" )
217292 if torch .cuda .is_available ():
218- logger .info ("Current VRAM utilization " + "%4.2fG" % (torch .cuda .memory_allocated () / 1e9 ))
293+ logger .info ("VRAM in use: " + "%4.3fG" % (torch .cuda .memory_allocated () / GIG ))
294+ logger .info ("RAM cache statistics:" )
295+ logger .info (f" Model cache hits: { cache_stats .hits } " )
296+ logger .info (f" Model cache misses: { cache_stats .misses } " )
297+ logger .info (f" Models cached: { cache_stats .in_cache } " )
298+ logger .info (f" Models cleared from cache: { cache_stats .cleared } " )
299+ logger .info (f" Cache high water mark: { hwm :4.2f} /{ tot :4.2f} G" )
219300
220301 completed .add (graph_id )
221302
222303 for graph_id in completed :
223304 del self ._stats [graph_id ]
305+ del self ._cache_stats [graph_id ]
0 commit comments