@@ -88,6 +88,7 @@ def convert_if_required(
8888 else :
8989 return model_path
9090
91+
9192class LoRALayerBase :
9293 # rank: Optional[int]
9394 # alpha: Optional[float]
@@ -173,6 +174,7 @@ def to(
173174 if self .bias is not None :
174175 self .bias = self .bias .to (device = device , dtype = dtype )
175176
177+
176178# TODO: find and debug lora/locon with bias
177179class LoRALayer (LoRALayerBase ):
178180 # up: torch.Tensor
@@ -225,6 +227,7 @@ def to(
225227 if self .mid is not None :
226228 self .mid = self .mid .to (device = device , dtype = dtype )
227229
230+
228231class LoHALayer (LoRALayerBase ):
229232 # w1_a: torch.Tensor
230233 # w1_b: torch.Tensor
@@ -292,6 +295,7 @@ def to(
292295 if self .t2 is not None :
293296 self .t2 = self .t2 .to (device = device , dtype = dtype )
294297
298+
295299class LoKRLayer (LoRALayerBase ):
296300 # w1: Optional[torch.Tensor] = None
297301 # w1_a: Optional[torch.Tensor] = None
@@ -386,6 +390,7 @@ def to(
386390 if self .t2 is not None :
387391 self .t2 = self .t2 .to (device = device , dtype = dtype )
388392
393+
389394# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
390395class LoRAModelRaw : # (torch.nn.Module):
391396 _name : str
@@ -439,7 +444,7 @@ def _convert_sdxl_compvis_keys(cls, state_dict):
439444 new_state_dict = dict ()
440445 for full_key , value in state_dict .items ():
441446 if full_key .startswith ("lora_te1_" ) or full_key .startswith ("lora_te2_" ):
442- continue # clip same
447+ continue # clip same
443448
444449 if not full_key .startswith ("lora_unet_" ):
445450 raise NotImplementedError (f"Unknown prefix for sdxl lora key - { full_key } " )
@@ -450,7 +455,7 @@ def _convert_sdxl_compvis_keys(cls, state_dict):
450455 if src_key in SDXL_UNET_COMPVIS_MAP :
451456 dst_key = SDXL_UNET_COMPVIS_MAP [src_key ]
452457 break
453- src_key = "_" .join (src_key .split ('_' )[:- 1 ])
458+ src_key = "_" .join (src_key .split ("_" )[:- 1 ])
454459
455460 if dst_key is None :
456461 raise Exception (f"Unknown sdxl lora key - { full_key } " )
@@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map():
614619
615620 return unet_conversion_map
616621
617- #_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
618- SDXL_UNET_COMPVIS_MAP = {f"{ sd } " .rstrip ("." ).replace ("." , "_" ): f"{ hf } " .rstrip ("." ).replace ("." , "_" ) for sd , hf in make_sdxl_unet_conversion_map ()}
622+
623+ # _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
624+ SDXL_UNET_COMPVIS_MAP = {
625+ f"{ sd } " .rstrip ("." ).replace ("." , "_" ): f"{ hf } " .rstrip ("." ).replace ("." , "_" )
626+ for sd , hf in make_sdxl_unet_conversion_map ()
627+ }
0 commit comments