Skip to content

Commit e147379

Browse files
authored
Merge branch 'main' into main
2 parents 584b513 + 5a82138 commit e147379

19 files changed

Lines changed: 255 additions & 90 deletions

File tree

invokeai/app/invocations/compel.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,13 +108,14 @@ def _lora_loader():
108108
print(f'Warn: trigger: "{trigger}" not found')
109109

110110
with (
111-
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
112111
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
113112
tokenizer,
114113
ti_manager,
115114
),
116115
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
117116
text_encoder_info as text_encoder,
117+
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
118+
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
118119
):
119120
compel = Compel(
120121
tokenizer=tokenizer,
@@ -229,13 +230,14 @@ def _lora_loader():
229230
print(f'Warn: trigger: "{trigger}" not found')
230231

231232
with (
232-
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
233233
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
234234
tokenizer,
235235
ti_manager,
236236
),
237237
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
238238
text_encoder_info as text_encoder,
239+
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
240+
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
239241
):
240242
compel = Compel(
241243
tokenizer=tokenizer,

invokeai/app/invocations/latent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,9 +710,10 @@ def _lora_loader():
710710
)
711711
with (
712712
ExitStack() as exit_stack,
713-
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
714713
set_seamless(unet_info.context.model, self.unet.seamless_axes),
715714
unet_info as unet,
715+
# Apply the LoRA after unet has been moved to its target device for faster patching.
716+
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
716717
):
717718
latents = latents.to(device=unet.device, dtype=unet.dtype)
718719
if noise is not None:

invokeai/app/services/config/config_default.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
ram: 13.5
4646
vram: 0.25
4747
lazy_offload: true
48+
log_memory_usage: false
4849
Device:
4950
device: auto
5051
precision: auto
@@ -261,6 +262,7 @@ class InvokeAIAppConfig(InvokeAISettings):
261262
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
262263
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
263264
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
265+
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
264266

265267
# DEVICE
266268
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)

invokeai/backend/model_management/lora.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import copy
3+
import pickle
44
from contextlib import contextmanager
55
from pathlib import Path
66
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -54,24 +54,6 @@ def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tup
5454

5555
return (module_key, module)
5656

57-
@staticmethod
58-
def _lora_forward_hook(
59-
applied_loras: List[Tuple[LoRAModel, float]],
60-
layer_name: str,
61-
):
62-
def lora_forward(module, input_h, output):
63-
if len(applied_loras) == 0:
64-
return output
65-
66-
for lora, weight in applied_loras:
67-
layer = lora.layers.get(layer_name, None)
68-
if layer is None:
69-
continue
70-
output += layer.forward(module, input_h, weight)
71-
return output
72-
73-
return lora_forward
74-
7557
@classmethod
7658
@contextmanager
7759
def apply_lora_unet(
@@ -129,21 +111,40 @@ def apply_lora(
129111
if not layer_key.startswith(prefix):
130112
continue
131113

114+
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
115+
# should be improved in the following ways:
116+
# 1. The key mapping could be pre-computed more efficiently. This would save time every time a
117+
# LoRA model is applied.
118+
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
119+
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
120+
# weights to have valid keys.
132121
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
122+
123+
# All of the LoRA weight calculations will be done on the same device as the module weight.
124+
# (Performance will be best if this is a CUDA device.)
125+
device = module.weight.device
126+
dtype = module.weight.dtype
127+
133128
if module_key not in original_weights:
134129
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
135130

136-
# enable autocast to calc fp16 loras on cpu
137-
# with torch.autocast(device_type="cpu"):
138-
layer.to(dtype=torch.float32)
139131
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
140-
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
132+
133+
# We intentionally move to the target device first, then cast. Experimentally, this was found to
134+
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
135+
# same thing in a single call to '.to(...)'.
136+
layer.to(device=device)
137+
layer.to(dtype=torch.float32)
138+
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
139+
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
140+
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
141+
layer.to(device="cpu")
141142

142143
if module.weight.shape != layer_weight.shape:
143144
# TODO: debug on lycoris
144145
layer_weight = layer_weight.reshape(module.weight.shape)
145146

146-
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
147+
module.weight += layer_weight.to(dtype=dtype)
147148

148149
yield # wait for context manager exit
149150

@@ -164,7 +165,13 @@ def apply_ti(
164165
new_tokens_added = None
165166

166167
try:
167-
ti_tokenizer = copy.deepcopy(tokenizer)
168+
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
169+
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
170+
# exiting this `apply_ti(...)` context manager.
171+
#
172+
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
173+
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
174+
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
168175
ti_manager = TextualInversionManager(ti_tokenizer)
169176
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
170177

@@ -196,7 +203,9 @@ def _get_trigger(ti_name, index):
196203

197204
if model_embeddings.weight.data[token_id].shape != embedding.shape:
198205
raise ValueError(
199-
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
206+
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
207+
f" {embedding.shape[0]}, but the current model has token dimension"
208+
f" {model_embeddings.weight.data[token_id].shape[0]}."
200209
)
201210

202211
model_embeddings.weight.data[token_id] = embedding.to(
@@ -257,7 +266,8 @@ def from_checkpoint(
257266
if "string_to_param" in state_dict:
258267
if len(state_dict["string_to_param"]) > 1:
259268
print(
260-
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
269+
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
270+
" token will be used."
261271
)
262272

263273
result.embedding = next(iter(state_dict["string_to_param"].values()))
@@ -435,7 +445,13 @@ def apply_ti(
435445
orig_embeddings = None
436446

437447
try:
438-
ti_tokenizer = copy.deepcopy(tokenizer)
448+
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
449+
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
450+
# exiting this `apply_ti(...)` context manager.
451+
#
452+
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
453+
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
454+
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
439455
ti_manager = TextualInversionManager(ti_tokenizer)
440456

441457
def _get_trigger(ti_name, index):
@@ -470,7 +486,9 @@ def _get_trigger(ti_name, index):
470486

471487
if embeddings[token_id].shape != embedding.shape:
472488
raise ValueError(
473-
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
489+
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
490+
f" {embedding.shape[0]}, but the current model has token dimension"
491+
f" {embeddings[token_id].shape[0]}."
474492
)
475493

476494
embeddings[token_id] = embedding

invokeai/backend/model_management/memory_snapshot.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def capture(cls, run_garbage_collector: bool = True):
6464
return cls(process_ram, vram, malloc_info)
6565

6666

67-
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
67+
def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
6868
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
6969

7070
def get_msg_line(prefix: str, val1: int, val2: int):
@@ -73,6 +73,9 @@ def get_msg_line(prefix: str, val1: int, val2: int):
7373

7474
msg = ""
7575

76+
if snapshot_1 is None or snapshot_2 is None:
77+
return msg
78+
7679
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
7780

7881
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:

invokeai/backend/model_management/model_cache.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(
117117
lazy_offloading: bool = True,
118118
sha_chunksize: int = 16777216,
119119
logger: types.ModuleType = logger,
120+
log_memory_usage: bool = False,
120121
):
121122
"""
122123
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@@ -126,6 +127,10 @@ def __init__(
126127
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
127128
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
128129
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
130+
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
131+
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
132+
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
133+
behaviour.
129134
"""
130135
self.model_infos: Dict[str, ModelBase] = dict()
131136
# allow lazy offloading only when vram cache enabled
@@ -137,13 +142,19 @@ def __init__(
137142
self.storage_device: torch.device = storage_device
138143
self.sha_chunksize = sha_chunksize
139144
self.logger = logger
145+
self._log_memory_usage = log_memory_usage
140146

141147
# used for stats collection
142148
self.stats = None
143149

144150
self._cached_models = dict()
145151
self._cache_stack = list()
146152

153+
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
154+
if self._log_memory_usage:
155+
return MemorySnapshot.capture()
156+
return None
157+
147158
def get_key(
148159
self,
149160
model_path: str,
@@ -223,10 +234,10 @@ def get_model(
223234

224235
# Load the model from disk and capture a memory snapshot before/after.
225236
start_load_time = time.time()
226-
snapshot_before = MemorySnapshot.capture()
237+
snapshot_before = self._capture_memory_snapshot()
227238
with skip_torch_weight_init():
228239
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
229-
snapshot_after = MemorySnapshot.capture()
240+
snapshot_after = self._capture_memory_snapshot()
230241
end_load_time = time.time()
231242

232243
self_reported_model_size_after_load = model_info.get_size(submodel)
@@ -275,9 +286,9 @@ def _move_model_to_device(self, key: str, target_device: torch.device):
275286
return
276287

277288
start_model_to_time = time.time()
278-
snapshot_before = MemorySnapshot.capture()
289+
snapshot_before = self._capture_memory_snapshot()
279290
cache_entry.model.to(target_device)
280-
snapshot_after = MemorySnapshot.capture()
291+
snapshot_after = self._capture_memory_snapshot()
281292
end_model_to_time = time.time()
282293
self.logger.debug(
283294
f"Moved model '{key}' from {source_device} to"
@@ -286,7 +297,12 @@ def _move_model_to_device(self, key: str, target_device: torch.device):
286297
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
287298
)
288299

289-
if snapshot_before.vram is not None and snapshot_after.vram is not None:
300+
if (
301+
snapshot_before is not None
302+
and snapshot_after is not None
303+
and snapshot_before.vram is not None
304+
and snapshot_after.vram is not None
305+
):
290306
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
291307

292308
# If the estimated model size does not match the change in VRAM, log a warning.
@@ -422,12 +438,17 @@ def _make_cache_room(self, model_size):
422438
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
423439

424440
pos = 0
441+
models_cleared = 0
425442
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
426443
model_key = self._cache_stack[pos]
427444
cache_entry = self._cached_models[model_key]
428445

429446
refs = sys.getrefcount(cache_entry.model)
430447

448+
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
449+
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
450+
# https://docs.python.org/3/library/gc.html#gc.get_referrers
451+
431452
# manualy clear local variable references of just finished function calls
432453
# for some reason python don't want to collect it even by gc.collect() immidiately
433454
if refs > 2:
@@ -453,15 +474,16 @@ def _make_cache_room(self, model_size):
453474
f" refs: {refs}"
454475
)
455476

456-
# 2 refs:
477+
# Expected refs:
457478
# 1 from cache_entry
458479
# 1 from getrefcount function
459480
# 1 from onnx runtime object
460-
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
481+
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
461482
self.logger.debug(
462483
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
463484
)
464485
current_size -= cache_entry.size
486+
models_cleared += 1
465487
if self.stats:
466488
self.stats.cleared += 1
467489
del self._cache_stack[pos]
@@ -471,7 +493,20 @@ def _make_cache_room(self, model_size):
471493
else:
472494
pos += 1
473495

474-
gc.collect()
496+
if models_cleared > 0:
497+
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
498+
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
499+
# is high even if no garbage gets collected.)
500+
#
501+
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
502+
# - If models had to be cleared, it's a signal that we are close to our memory limit.
503+
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
504+
# collected.
505+
#
506+
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
507+
# immediately when their reference count hits 0.
508+
gc.collect()
509+
475510
torch.cuda.empty_cache()
476511
if choose_torch_device() == torch.device("mps"):
477512
mps.empty_cache()
@@ -491,7 +526,6 @@ def _offload_unlocked_models(self, size_needed: int = 0):
491526
vram_in_use = torch.cuda.memory_allocated()
492527
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
493528

494-
gc.collect()
495529
torch.cuda.empty_cache()
496530
if choose_torch_device() == torch.device("mps"):
497531
mps.empty_cache()

invokeai/backend/model_management/model_load_optimizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
1717
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
1818
monkey-patches common torch layers to skip the weight initialization step.
1919
"""
20-
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
20+
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
2121
saved_functions = [m.reset_parameters for m in torch_modules]
2222

2323
try:

invokeai/backend/model_management/model_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,7 @@ def __init__(
351351
precision=precision,
352352
sequential_offload=sequential_offload,
353353
logger=logger,
354+
log_memory_usage=self.app_config.log_memory_usage,
354355
)
355356

356357
self._read_models(config)

0 commit comments

Comments
 (0)