Skip to content

Commit e7db6d8

Browse files
committed
Fix ckpt and vae conversion, migrate script, remove sd2-base
1 parent a6af7e8 commit e7db6d8

13 files changed

Lines changed: 540 additions & 378 deletions

File tree

invokeai/app/invocations/model.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
4343
#fmt: on
4444

4545

46-
class ModelLoaderInvocation(BaseInvocation):
46+
class SD1ModelLoaderInvocation(BaseInvocation):
4747
"""Loading submodels of selected model."""
4848

49-
type: Literal["model_loader"] = "model_loader"
49+
type: Literal["sd1_model_loader"] = "sd1_model_loader"
5050

5151
model_name: str = Field(default="", description="Model to load")
5252
# TODO: precision?
@@ -64,7 +64,110 @@ class Config(InvocationConfig):
6464

6565
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
6666

67-
base_model = BaseModelType.StableDiffusion1_5 # TODO:
67+
base_model = BaseModelType.StableDiffusion1 # TODO:
68+
69+
# TODO: not found exceptions
70+
if not context.services.model_manager.model_exists(
71+
model_name=self.model_name,
72+
base_model=base_model,
73+
model_type=ModelType.Pipeline,
74+
):
75+
raise Exception(f"Unkown model name: {self.model_name}!")
76+
77+
"""
78+
if not context.services.model_manager.model_exists(
79+
model_name=self.model_name,
80+
model_type=SDModelType.Diffusers,
81+
submodel=SDModelType.Tokenizer,
82+
):
83+
raise Exception(
84+
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
85+
)
86+
87+
if not context.services.model_manager.model_exists(
88+
model_name=self.model_name,
89+
model_type=SDModelType.Diffusers,
90+
submodel=SDModelType.TextEncoder,
91+
):
92+
raise Exception(
93+
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
94+
)
95+
96+
if not context.services.model_manager.model_exists(
97+
model_name=self.model_name,
98+
model_type=SDModelType.Diffusers,
99+
submodel=SDModelType.UNet,
100+
):
101+
raise Exception(
102+
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
103+
)
104+
"""
105+
106+
107+
return ModelLoaderOutput(
108+
unet=UNetField(
109+
unet=ModelInfo(
110+
model_name=self.model_name,
111+
base_model=base_model,
112+
model_type=ModelType.Pipeline,
113+
submodel=SubModelType.UNet,
114+
),
115+
scheduler=ModelInfo(
116+
model_name=self.model_name,
117+
base_model=base_model,
118+
model_type=ModelType.Pipeline,
119+
submodel=SubModelType.Scheduler,
120+
),
121+
loras=[],
122+
),
123+
clip=ClipField(
124+
tokenizer=ModelInfo(
125+
model_name=self.model_name,
126+
base_model=base_model,
127+
model_type=ModelType.Pipeline,
128+
submodel=SubModelType.Tokenizer,
129+
),
130+
text_encoder=ModelInfo(
131+
model_name=self.model_name,
132+
base_model=base_model,
133+
model_type=ModelType.Pipeline,
134+
submodel=SubModelType.TextEncoder,
135+
),
136+
loras=[],
137+
),
138+
vae=VaeField(
139+
vae=ModelInfo(
140+
model_name=self.model_name,
141+
base_model=base_model,
142+
model_type=ModelType.Pipeline,
143+
submodel=SubModelType.Vae,
144+
),
145+
)
146+
)
147+
148+
# TODO: optimize(less code copy)
149+
class SD2ModelLoaderInvocation(BaseInvocation):
150+
"""Loading submodels of selected model."""
151+
152+
type: Literal["sd2_model_loader"] = "sd2_model_loader"
153+
154+
model_name: str = Field(default="", description="Model to load")
155+
# TODO: precision?
156+
157+
# Schema customisation
158+
class Config(InvocationConfig):
159+
schema_extra = {
160+
"ui": {
161+
"tags": ["model", "loader"],
162+
"type_hints": {
163+
"model_name": "model" # TODO: rename to model_name?
164+
}
165+
},
166+
}
167+
168+
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
169+
170+
base_model = BaseModelType.StableDiffusion2 # TODO:
68171

69172
# TODO: not found exceptions
70173
if not context.services.model_manager.model_exists(

0 commit comments

Comments
 (0)