@@ -43,10 +43,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
4343 #fmt: on
4444
4545
46- class ModelLoaderInvocation (BaseInvocation ):
46+ class SD1ModelLoaderInvocation (BaseInvocation ):
4747 """Loading submodels of selected model."""
4848
49- type : Literal ["model_loader " ] = "model_loader "
49+ type : Literal ["sd1_model_loader " ] = "sd1_model_loader "
5050
5151 model_name : str = Field (default = "" , description = "Model to load" )
5252 # TODO: precision?
@@ -64,7 +64,110 @@ class Config(InvocationConfig):
6464
6565 def invoke (self , context : InvocationContext ) -> ModelLoaderOutput :
6666
67- base_model = BaseModelType .StableDiffusion1_5 # TODO:
67+ base_model = BaseModelType .StableDiffusion1 # TODO:
68+
69+ # TODO: not found exceptions
70+ if not context .services .model_manager .model_exists (
71+ model_name = self .model_name ,
72+ base_model = base_model ,
73+ model_type = ModelType .Pipeline ,
74+ ):
75+ raise Exception (f"Unkown model name: { self .model_name } !" )
76+
77+ """
78+ if not context.services.model_manager.model_exists(
79+ model_name=self.model_name,
80+ model_type=SDModelType.Diffusers,
81+ submodel=SDModelType.Tokenizer,
82+ ):
83+ raise Exception(
84+ f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
85+ )
86+
87+ if not context.services.model_manager.model_exists(
88+ model_name=self.model_name,
89+ model_type=SDModelType.Diffusers,
90+ submodel=SDModelType.TextEncoder,
91+ ):
92+ raise Exception(
93+ f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
94+ )
95+
96+ if not context.services.model_manager.model_exists(
97+ model_name=self.model_name,
98+ model_type=SDModelType.Diffusers,
99+ submodel=SDModelType.UNet,
100+ ):
101+ raise Exception(
102+ f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
103+ )
104+ """
105+
106+
107+ return ModelLoaderOutput (
108+ unet = UNetField (
109+ unet = ModelInfo (
110+ model_name = self .model_name ,
111+ base_model = base_model ,
112+ model_type = ModelType .Pipeline ,
113+ submodel = SubModelType .UNet ,
114+ ),
115+ scheduler = ModelInfo (
116+ model_name = self .model_name ,
117+ base_model = base_model ,
118+ model_type = ModelType .Pipeline ,
119+ submodel = SubModelType .Scheduler ,
120+ ),
121+ loras = [],
122+ ),
123+ clip = ClipField (
124+ tokenizer = ModelInfo (
125+ model_name = self .model_name ,
126+ base_model = base_model ,
127+ model_type = ModelType .Pipeline ,
128+ submodel = SubModelType .Tokenizer ,
129+ ),
130+ text_encoder = ModelInfo (
131+ model_name = self .model_name ,
132+ base_model = base_model ,
133+ model_type = ModelType .Pipeline ,
134+ submodel = SubModelType .TextEncoder ,
135+ ),
136+ loras = [],
137+ ),
138+ vae = VaeField (
139+ vae = ModelInfo (
140+ model_name = self .model_name ,
141+ base_model = base_model ,
142+ model_type = ModelType .Pipeline ,
143+ submodel = SubModelType .Vae ,
144+ ),
145+ )
146+ )
147+
148+ # TODO: optimize(less code copy)
149+ class SD2ModelLoaderInvocation (BaseInvocation ):
150+ """Loading submodels of selected model."""
151+
152+ type : Literal ["sd2_model_loader" ] = "sd2_model_loader"
153+
154+ model_name : str = Field (default = "" , description = "Model to load" )
155+ # TODO: precision?
156+
157+ # Schema customisation
158+ class Config (InvocationConfig ):
159+ schema_extra = {
160+ "ui" : {
161+ "tags" : ["model" , "loader" ],
162+ "type_hints" : {
163+ "model_name" : "model" # TODO: rename to model_name?
164+ }
165+ },
166+ }
167+
168+ def invoke (self , context : InvocationContext ) -> ModelLoaderOutput :
169+
170+ base_model = BaseModelType .StableDiffusion2 # TODO:
68171
69172 # TODO: not found exceptions
70173 if not context .services .model_manager .model_exists (
0 commit comments