1616import textwrap
1717import traceback
1818import warnings
19+ import yaml
1920from argparse import Namespace
2021from pathlib import Path
2122from shutil import get_terminal_size
2526import npyscreen
2627import transformers
2728from diffusers import AutoencoderKL
29+ from diffusers .pipelines .stable_diffusion .safety_checker import StableDiffusionSafetyChecker
2830from huggingface_hub import HfFolder
2931from huggingface_hub import login as hf_hub_login
3032from omegaconf import OmegaConf
3436 CLIPSegForImageSegmentation ,
3537 CLIPTextModel ,
3638 CLIPTokenizer ,
39+ AutoFeatureExtractor ,
40+ BertTokenizerFast ,
3741)
3842import invokeai .configs as configs
3943
5862 recommended_datasets ,
5963 UserSelections ,
6064)
65+ from invokeai .backend .model_management .model_probe import (
66+ ModelProbe , ModelType , BaseModelType , SchedulerPredictionType
67+ )
6168
6269warnings .filterwarnings ("ignore" )
6370transformers .logging .set_verbosity_error ()
8188# or renaming it and then running invokeai-configure again.
8289"""
8390
84- logger = None
91+ logger = InvokeAILogger . getLogger ()
8592
8693# --------------------------------------------
8794def postscript (errors : None ):
@@ -162,75 +169,91 @@ def __call__(self, block_num, block_size, total_size):
162169# ---------------------------------------------
163170def download_with_progress_bar (model_url : str , model_dest : str , label : str = "the" ):
164171 try :
165- print (f"Installing { label } model file { model_url } ..." , end = "" , file = sys . stderr )
172+ logger . info (f"Installing { label } model file { model_url } ..." )
166173 if not os .path .exists (model_dest ):
167174 os .makedirs (os .path .dirname (model_dest ), exist_ok = True )
168175 request .urlretrieve (
169176 model_url , model_dest , ProgressBar (os .path .basename (model_dest ))
170177 )
171- print ("...downloaded successfully" , file = sys . stderr )
178+ logger . info ("...downloaded successfully" )
172179 else :
173- print ("...exists" , file = sys . stderr )
180+ logger . info ("...exists" )
174181 except Exception :
175- print ("...download failed" , file = sys . stderr )
176- print (f"Error downloading { label } model" , file = sys . stderr )
182+ logger . info ("...download failed" )
183+ logger . info (f"Error downloading { label } model" )
177184 print (traceback .format_exc (), file = sys .stderr )
178185
179186
180- # ---------------------------------------------
181- # this will preload the Bert tokenizer fles
182- def download_bert ():
183- print ("Installing bert tokenizer..." , file = sys .stderr )
184- with warnings .catch_warnings ():
185- warnings .filterwarnings ("ignore" , category = DeprecationWarning )
186- from transformers import BertTokenizerFast
187-
188- download_from_hf (BertTokenizerFast , "bert-base-uncased" )
189-
190-
191- # ---------------------------------------------
192- def download_sd1_clip ():
193- print ("Installing SD1 clip model..." , file = sys .stderr )
194- version = "openai/clip-vit-large-patch14"
195- download_from_hf (CLIPTokenizer , version )
196- download_from_hf (CLIPTextModel , version )
197-
198-
199- # ---------------------------------------------
200- def download_sd2_clip ():
201- version = "stabilityai/stable-diffusion-2"
202- print ("Installing SD2 clip model..." , file = sys .stderr )
203- download_from_hf (CLIPTokenizer , version , subfolder = "tokenizer" )
204- download_from_hf (CLIPTextModel , version , subfolder = "text_encoder" )
187+ def download_conversion_models ():
188+ target_dir = config .root_path / 'models/core/convert'
189+ kwargs = dict () # for future use
190+ try :
191+ logger .info ('Downloading core tokenizers and text encoders' )
205192
193+ # bert
194+ with warnings .catch_warnings ():
195+ warnings .filterwarnings ("ignore" , category = DeprecationWarning )
196+ bert = BertTokenizerFast .from_pretrained ("bert-base-uncased" , ** kwargs )
197+ bert .save_pretrained (target_dir / 'bert-base-uncased' , safe_serialization = True )
198+
199+ # sd-1
200+ repo_id = 'openai/clip-vit-large-patch14'
201+ download_from_hf (CLIPTokenizer , repo_id , target_dir / 'clip-vit-large-patch14' )
202+ download_from_hf (CLIPTextModel , repo_id , target_dir / 'clip-vit-large-patch14' )
203+
204+ # sd-2
205+ repo_id = "stabilityai/stable-diffusion-2"
206+ pipeline = CLIPTokenizer .from_pretrained (repo_id , subfolder = "tokenizer" , ** kwargs )
207+ pipeline .save_pretrained (target_dir / 'stable-diffusion-2-clip' / 'tokenizer' , safe_serialization = True )
208+
209+ pipeline = CLIPTextModel .from_pretrained (repo_id , subfolder = "text_encoder" , ** kwargs )
210+ pipeline .save_pretrained (target_dir / 'stable-diffusion-2-clip' / 'text_encoder' , safe_serialization = True )
211+
212+ # VAE
213+ logger .info ('Downloading stable diffusion VAE' )
214+ vae = AutoencoderKL .from_pretrained ('stabilityai/sd-vae-ft-mse' , ** kwargs )
215+ vae .save_pretrained (target_dir / 'sd-vae-ft-mse' , safe_serialization = True )
216+
217+ # safety checking
218+ logger .info ('Downloading safety checker' )
219+ repo_id = "CompVis/stable-diffusion-safety-checker"
220+ pipeline = AutoFeatureExtractor .from_pretrained (repo_id ,** kwargs )
221+ pipeline .save_pretrained (target_dir / 'stable-diffusion-safety-checker' , safe_serialization = True )
222+
223+ pipeline = StableDiffusionSafetyChecker .from_pretrained (repo_id ,** kwargs )
224+ pipeline .save_pretrained (target_dir / 'stable-diffusion-safety-checker' , safe_serialization = True )
225+ except KeyboardInterrupt :
226+ raise
227+ except Exception as e :
228+ logger .error (str (e ))
206229
207230# ---------------------------------------------
208231def download_realesrgan ():
209- print ("Installing models from RealESRGAN..." , file = sys . stderr )
232+ logger . info ("Installing models from RealESRGAN..." )
210233 model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
211234 wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
212235
213- model_dest = config .root_path / "models/realesrgan/realesr-general-x4v3.pth"
214- wdn_model_dest = config .root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
236+ model_dest = config .root_path / "models/core/upscaling/ realesrgan/realesr-general-x4v3.pth"
237+ wdn_model_dest = config .root_path / "models/core/upscaling/ realesrgan/realesr-general-wdn-x4v3.pth"
215238
216239 download_with_progress_bar (model_url , str (model_dest ), "RealESRGAN" )
217240 download_with_progress_bar (wdn_model_url , str (wdn_model_dest ), "RealESRGANwdn" )
218241
219242
220243def download_gfpgan ():
221- print ("Installing GFPGAN models..." , file = sys . stderr )
244+ logger . info ("Installing GFPGAN models..." )
222245 for model in (
223246 [
224247 "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" ,
225- "./models/gfpgan/GFPGANv1.4.pth" ,
248+ "./models/core/face_restoration/ gfpgan/GFPGANv1.4.pth" ,
226249 ],
227250 [
228251 "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" ,
229- "./models/gfpgan/weights/detection_Resnet50_Final.pth" ,
252+ "./models/core/face_restoration/ gfpgan/weights/detection_Resnet50_Final.pth" ,
230253 ],
231254 [
232255 "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth" ,
233- "./models/gfpgan/weights/parsing_parsenet.pth" ,
256+ "./models/core/face_restoration/ gfpgan/weights/parsing_parsenet.pth" ,
234257 ],
235258 ):
236259 model_url , model_dest = model [0 ], config .root_path / model [1 ]
@@ -239,70 +262,32 @@ def download_gfpgan():
239262
240263# ---------------------------------------------
241264def download_codeformer ():
242- print ("Installing CodeFormer model file..." , file = sys . stderr )
265+ logger . info ("Installing CodeFormer model file..." )
243266 model_url = (
244267 "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
245268 )
246- model_dest = config .root_path / "models/codeformer/codeformer.pth"
269+ model_dest = config .root_path / "models/core/face_restoration/ codeformer/codeformer.pth"
247270 download_with_progress_bar (model_url , str (model_dest ), "CodeFormer" )
248271
249272
250273# ---------------------------------------------
251274def download_clipseg ():
252- print ("Installing clipseg model for text-based masking..." , file = sys . stderr )
275+ logger . info ("Installing clipseg model for text-based masking..." )
253276 CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
254277 try :
255- download_from_hf (AutoProcessor , CLIPSEG_MODEL )
256- download_from_hf (CLIPSegForImageSegmentation , CLIPSEG_MODEL )
278+ download_from_hf (AutoProcessor , CLIPSEG_MODEL , config . root_path / 'models/core/misc/clipseg' )
279+ download_from_hf (CLIPSegForImageSegmentation , CLIPSEG_MODEL , 'models/core/misc/clipseg' )
257280 except Exception :
258- print ("Error installing clipseg model:" )
259- print (traceback .format_exc ())
281+ logger . info ("Error installing clipseg model:" )
282+ logger . info (traceback .format_exc ())
260283
261284
262- # -------------------------------------
263- def download_safety_checker ():
264- print ("Installing model for NSFW content detection..." , file = sys .stderr )
265- try :
266- from diffusers .pipelines .stable_diffusion .safety_checker import (
267- StableDiffusionSafetyChecker ,
268- )
269- from transformers import AutoFeatureExtractor
270- except ModuleNotFoundError :
271- print ("Error installing NSFW checker model:" )
272- print (traceback .format_exc ())
273- return
274- safety_model_id = "CompVis/stable-diffusion-safety-checker"
275- print ("AutoFeatureExtractor..." , file = sys .stderr )
276- download_from_hf (AutoFeatureExtractor , safety_model_id )
277- print ("StableDiffusionSafetyChecker..." , file = sys .stderr )
278- download_from_hf (StableDiffusionSafetyChecker , safety_model_id )
279-
280-
281- # -------------------------------------
282- def download_vaes ():
283- print ("Installing stabilityai VAE..." , file = sys .stderr )
284- try :
285- # first the diffusers version
286- repo_id = "stabilityai/sd-vae-ft-mse"
287- args = dict (
288- cache_dir = config .cache_dir ,
289- )
290- if not AutoencoderKL .from_pretrained (repo_id , ** args ):
291- raise Exception (f"download of { repo_id } failed" )
292-
293- repo_id = "stabilityai/sd-vae-ft-mse-original"
294- model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
295- # next the legacy checkpoint version
296- if not hf_download_with_resume (
297- repo_id = repo_id ,
298- model_name = model_name ,
299- model_dir = str (config .root_path / Model_dir / Weights_dir ),
300- ):
301- raise Exception (f"download of { model_name } failed" )
302- except Exception as e :
303- print (f"Error downloading StabilityAI standard VAE: { str (e )} " , file = sys .stderr )
304- print (traceback .format_exc (), file = sys .stderr )
305-
285+ def download_support_models ():
286+ download_realesrgan ()
287+ download_gfpgan ()
288+ download_codeformer ()
289+ download_clipseg ()
290+ download_conversion_models ()
306291
307292# -------------------------------------
308293def get_root (root : str = None ) -> str :
@@ -657,17 +642,13 @@ def default_user_selections(program_opts: Namespace) -> UserSelections:
657642
658643# -------------------------------------
659644def initialize_rootdir (root : Path , yes_to_all : bool = False ):
660- print ("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **" )
661-
645+ logger .info ("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **" )
662646 for name in (
663647 "models" ,
664- "configs" ,
665- "embeddings" ,
666648 "databases" ,
667- "loras" ,
668- "controlnets" ,
669649 "text-inversion-output" ,
670650 "text-inversion-training-data" ,
651+ "configs"
671652 ):
672653 os .makedirs (os .path .join (root , name ), exist_ok = True )
673654
@@ -676,6 +657,22 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
676657 if not os .path .samefile (configs_src , configs_dest ):
677658 shutil .copytree (configs_src , configs_dest , dirs_exist_ok = True )
678659
660+ dest = root / 'models'
661+ for model_base in [BaseModelType .StableDiffusion1 ,BaseModelType .StableDiffusion2 ]:
662+ for model_type in [ModelType .Pipeline , ModelType .Vae , ModelType .Lora ,
663+ ModelType .ControlNet ,ModelType .TextualInversion ]:
664+ path = dest / model_base .value / model_type .value
665+ path .mkdir (parents = True , exist_ok = True )
666+ path = dest / 'core'
667+ path .mkdir (parents = True , exist_ok = True )
668+
669+ with open (root / 'configs' / 'models.yaml' ,'w' ) as yaml_file :
670+ yaml_file .write (yaml .dump ({'__metadata__' :
671+ {'version' :'3.0.0' }
672+ }
673+ )
674+ )
675+
679676
680677# -------------------------------------
681678def run_console_ui (
@@ -837,7 +834,7 @@ def main():
837834 old_init_file = config .root_path / 'invokeai.init'
838835 new_init_file = config .root_path / 'invokeai.yaml'
839836 if old_init_file .exists () and not new_init_file .exists ():
840- print ('** Migrating invokeai.init to invokeai.yaml' )
837+ logger . info ('** Migrating invokeai.init to invokeai.yaml' )
841838 migrate_init_file (old_init_file )
842839 # Load new init file into config
843840 config .parse_args (argv = [],conf = OmegaConf .load (new_init_file ))
@@ -855,29 +852,21 @@ def main():
855852 if init_options :
856853 write_opts (init_options , new_init_file )
857854 else :
858- print (
855+ logger . info (
859856 '\n ** CANCELLED AT USER\' S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n '
860857 )
861858 sys .exit (0 )
862859
863860 if opt .skip_support_models :
864- print ( " \n ** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST ** " )
861+ logger . info ( " SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST" )
865862 else :
866- print ("\n ** CHECKING/UPDATING SUPPORT MODELS **" )
867- download_bert ()
868- download_sd1_clip ()
869- download_sd2_clip ()
870- download_realesrgan ()
871- download_gfpgan ()
872- download_codeformer ()
873- download_clipseg ()
874- download_safety_checker ()
875- download_vaes ()
863+ logger .info ("CHECKING/UPDATING SUPPORT MODELS" )
864+ download_support_models ()
876865
877866 if opt .skip_sd_weights :
878- print ("\n ** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **" )
867+ logger . info ("\n ** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **" )
879868 elif models_to_download :
880- print ("\n ** DOWNLOADING DIFFUSION WEIGHTS **" )
869+ logger . info ("\n ** DOWNLOADING DIFFUSION WEIGHTS **" )
881870 process_and_execute (opt , models_to_download )
882871
883872 postscript (errors = errors )
0 commit comments