Skip to content

Commit 1d5d187

Browse files
Lincoln Steinhipsterusername
authored andcommitted
model probe detects sdxl lora models
1 parent 1ac14a1 commit 1d5d187

5 files changed

Lines changed: 39 additions & 14 deletions

File tree

invokeai/app/invocations/sdxl.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,7 @@ def _lora_loader():
306306
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
307307
do_classifier_free_guidance = True
308308
cross_attention_kwargs = None
309-
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
310-
unet_info as unet:
311-
309+
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
312310
scheduler.set_timesteps(num_inference_steps, device=unet.device)
313311
timesteps = scheduler.timesteps
314312

invokeai/backend/model_management/lora.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def apply_lora_text_encoder(
101101
with cls.apply_lora(text_encoder, loras, "lora_te_"):
102102
yield
103103

104-
105104
@classmethod
106105
@contextmanager
107106
def apply_sdxl_lora_text_encoder(

invokeai/backend/model_management/model_probe.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,21 +315,38 @@ def get_format(self) -> str:
315315

316316
def get_base_type(self) -> BaseModelType:
317317
checkpoint = self.checkpoint
318+
319+
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
320+
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
321+
# misclassified as SD-1
322+
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
323+
if key in checkpoint and checkpoint[key].shape[0] == 320:
324+
return BaseModelType.StableDiffusion2
325+
326+
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
327+
if key in checkpoint:
328+
return BaseModelType.StableDiffusionXL
329+
318330
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
319-
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
331+
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
332+
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
333+
320334
lora_token_vector_length = (
321335
checkpoint[key1].shape[1]
322336
if key1 in checkpoint
323-
else checkpoint[key2].shape[0]
337+
else checkpoint[key2].shape[1]
324338
if key2 in checkpoint
325-
else 768
339+
else checkpoint[key3].shape[0]
340+
if key3 in checkpoint
341+
else None
326342
)
343+
327344
if lora_token_vector_length == 768:
328345
return BaseModelType.StableDiffusion1
329346
elif lora_token_vector_length == 1024:
330347
return BaseModelType.StableDiffusion2
331348
else:
332-
return None
349+
raise InvalidModelException(f"Unknown LoRA type")
333350

334351

335352
class TextualInversionCheckpointProbe(CheckpointProbeBase):

invokeai/backend/model_management/models/lora.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def convert_if_required(
8888
else:
8989
return model_path
9090

91+
9192
class LoRALayerBase:
9293
# rank: Optional[int]
9394
# alpha: Optional[float]
@@ -173,6 +174,7 @@ def to(
173174
if self.bias is not None:
174175
self.bias = self.bias.to(device=device, dtype=dtype)
175176

177+
176178
# TODO: find and debug lora/locon with bias
177179
class LoRALayer(LoRALayerBase):
178180
# up: torch.Tensor
@@ -225,6 +227,7 @@ def to(
225227
if self.mid is not None:
226228
self.mid = self.mid.to(device=device, dtype=dtype)
227229

230+
228231
class LoHALayer(LoRALayerBase):
229232
# w1_a: torch.Tensor
230233
# w1_b: torch.Tensor
@@ -292,6 +295,7 @@ def to(
292295
if self.t2 is not None:
293296
self.t2 = self.t2.to(device=device, dtype=dtype)
294297

298+
295299
class LoKRLayer(LoRALayerBase):
296300
# w1: Optional[torch.Tensor] = None
297301
# w1_a: Optional[torch.Tensor] = None
@@ -386,6 +390,7 @@ def to(
386390
if self.t2 is not None:
387391
self.t2 = self.t2.to(device=device, dtype=dtype)
388392

393+
389394
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
390395
class LoRAModelRaw: # (torch.nn.Module):
391396
_name: str
@@ -439,7 +444,7 @@ def _convert_sdxl_compvis_keys(cls, state_dict):
439444
new_state_dict = dict()
440445
for full_key, value in state_dict.items():
441446
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
442-
continue # clip same
447+
continue # clip same
443448

444449
if not full_key.startswith("lora_unet_"):
445450
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
@@ -450,7 +455,7 @@ def _convert_sdxl_compvis_keys(cls, state_dict):
450455
if src_key in SDXL_UNET_COMPVIS_MAP:
451456
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
452457
break
453-
src_key = "_".join(src_key.split('_')[:-1])
458+
src_key = "_".join(src_key.split("_")[:-1])
454459

455460
if dst_key is None:
456461
raise Exception(f"Unknown sdxl lora key - {full_key}")
@@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map():
614619

615620
return unet_conversion_map
616621

617-
#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
618-
SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
622+
623+
# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
624+
SDXL_UNET_COMPVIS_MAP = {
625+
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
626+
for sd, hf in make_sdxl_unet_conversion_map()
627+
}

scripts/probe-model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
parser.add_argument(
1010
"model_path",
1111
type=Path,
12+
nargs="+",
1213
)
1314
args = parser.parse_args()
1415

15-
info = ModelProbe().probe(args.model_path)
16-
print(info)
16+
for path in args.model_path:
17+
info = ModelProbe().probe(path)
18+
print(f"{path}: {info}")

0 commit comments

Comments
 (0)