Skip to content

Commit ada7399

Browse files
author
Lincoln Stein
committed
rewrite of widget display - marshalling needs rewrite
1 parent 5c74045 commit ada7399

7 files changed

Lines changed: 480 additions & 471 deletions

File tree

invokeai/backend/install/invokeai_configure.py

Lines changed: 99 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import textwrap
1717
import traceback
1818
import warnings
19+
import yaml
1920
from argparse import Namespace
2021
from pathlib import Path
2122
from shutil import get_terminal_size
@@ -25,6 +26,7 @@
2526
import npyscreen
2627
import transformers
2728
from diffusers import AutoencoderKL
29+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
2830
from huggingface_hub import HfFolder
2931
from huggingface_hub import login as hf_hub_login
3032
from omegaconf import OmegaConf
@@ -34,6 +36,8 @@
3436
CLIPSegForImageSegmentation,
3537
CLIPTextModel,
3638
CLIPTokenizer,
39+
AutoFeatureExtractor,
40+
BertTokenizerFast,
3741
)
3842
import invokeai.configs as configs
3943

@@ -58,6 +62,9 @@
5862
recommended_datasets,
5963
UserSelections,
6064
)
65+
from invokeai.backend.model_management.model_probe import (
66+
ModelProbe, ModelType, BaseModelType, SchedulerPredictionType
67+
)
6168

6269
warnings.filterwarnings("ignore")
6370
transformers.logging.set_verbosity_error()
@@ -81,7 +88,7 @@
8188
# or renaming it and then running invokeai-configure again.
8289
"""
8390

84-
logger=None
91+
logger=InvokeAILogger.getLogger()
8592

8693
# --------------------------------------------
8794
def postscript(errors: None):
@@ -162,75 +169,91 @@ def __call__(self, block_num, block_size, total_size):
162169
# ---------------------------------------------
163170
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
164171
try:
165-
print(f"Installing {label} model file {model_url}...", end="", file=sys.stderr)
172+
logger.info(f"Installing {label} model file {model_url}...")
166173
if not os.path.exists(model_dest):
167174
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
168175
request.urlretrieve(
169176
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
170177
)
171-
print("...downloaded successfully", file=sys.stderr)
178+
logger.info("...downloaded successfully")
172179
else:
173-
print("...exists", file=sys.stderr)
180+
logger.info("...exists")
174181
except Exception:
175-
print("...download failed", file=sys.stderr)
176-
print(f"Error downloading {label} model", file=sys.stderr)
182+
logger.info("...download failed")
183+
logger.info(f"Error downloading {label} model")
177184
print(traceback.format_exc(), file=sys.stderr)
178185

179186

180-
# ---------------------------------------------
181-
# this will preload the Bert tokenizer fles
182-
def download_bert():
183-
print("Installing bert tokenizer...", file=sys.stderr)
184-
with warnings.catch_warnings():
185-
warnings.filterwarnings("ignore", category=DeprecationWarning)
186-
from transformers import BertTokenizerFast
187-
188-
download_from_hf(BertTokenizerFast, "bert-base-uncased")
189-
190-
191-
# ---------------------------------------------
192-
def download_sd1_clip():
193-
print("Installing SD1 clip model...", file=sys.stderr)
194-
version = "openai/clip-vit-large-patch14"
195-
download_from_hf(CLIPTokenizer, version)
196-
download_from_hf(CLIPTextModel, version)
197-
198-
199-
# ---------------------------------------------
200-
def download_sd2_clip():
201-
version = "stabilityai/stable-diffusion-2"
202-
print("Installing SD2 clip model...", file=sys.stderr)
203-
download_from_hf(CLIPTokenizer, version, subfolder="tokenizer")
204-
download_from_hf(CLIPTextModel, version, subfolder="text_encoder")
187+
def download_conversion_models():
188+
target_dir = config.root_path / 'models/core/convert'
189+
kwargs = dict() # for future use
190+
try:
191+
logger.info('Downloading core tokenizers and text encoders')
205192

193+
# bert
194+
with warnings.catch_warnings():
195+
warnings.filterwarnings("ignore", category=DeprecationWarning)
196+
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
197+
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
198+
199+
# sd-1
200+
repo_id = 'openai/clip-vit-large-patch14'
201+
download_from_hf(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
202+
download_from_hf(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
203+
204+
# sd-2
205+
repo_id = "stabilityai/stable-diffusion-2"
206+
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
207+
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
208+
209+
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
210+
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
211+
212+
# VAE
213+
logger.info('Downloading stable diffusion VAE')
214+
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
215+
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
216+
217+
# safety checking
218+
logger.info('Downloading safety checker')
219+
repo_id = "CompVis/stable-diffusion-safety-checker"
220+
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
221+
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
222+
223+
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
224+
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
225+
except KeyboardInterrupt:
226+
raise
227+
except Exception as e:
228+
logger.error(str(e))
206229

207230
# ---------------------------------------------
208231
def download_realesrgan():
209-
print("Installing models from RealESRGAN...", file=sys.stderr)
232+
logger.info("Installing models from RealESRGAN...")
210233
model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
211234
wdn_model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth"
212235

213-
model_dest = config.root_path / "models/realesrgan/realesr-general-x4v3.pth"
214-
wdn_model_dest = config.root_path / "models/realesrgan/realesr-general-wdn-x4v3.pth"
236+
model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-x4v3.pth"
237+
wdn_model_dest = config.root_path / "models/core/upscaling/realesrgan/realesr-general-wdn-x4v3.pth"
215238

216239
download_with_progress_bar(model_url, str(model_dest), "RealESRGAN")
217240
download_with_progress_bar(wdn_model_url, str(wdn_model_dest), "RealESRGANwdn")
218241

219242

220243
def download_gfpgan():
221-
print("Installing GFPGAN models...", file=sys.stderr)
244+
logger.info("Installing GFPGAN models...")
222245
for model in (
223246
[
224247
"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
225-
"./models/gfpgan/GFPGANv1.4.pth",
248+
"./models/core/face_restoration/gfpgan/GFPGANv1.4.pth",
226249
],
227250
[
228251
"https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
229-
"./models/gfpgan/weights/detection_Resnet50_Final.pth",
252+
"./models/core/face_restoration/gfpgan/weights/detection_Resnet50_Final.pth",
230253
],
231254
[
232255
"https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
233-
"./models/gfpgan/weights/parsing_parsenet.pth",
256+
"./models/core/face_restoration/gfpgan/weights/parsing_parsenet.pth",
234257
],
235258
):
236259
model_url, model_dest = model[0], config.root_path / model[1]
@@ -239,70 +262,32 @@ def download_gfpgan():
239262

240263
# ---------------------------------------------
241264
def download_codeformer():
242-
print("Installing CodeFormer model file...", file=sys.stderr)
265+
logger.info("Installing CodeFormer model file...")
243266
model_url = (
244267
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
245268
)
246-
model_dest = config.root_path / "models/codeformer/codeformer.pth"
269+
model_dest = config.root_path / "models/core/face_restoration/codeformer/codeformer.pth"
247270
download_with_progress_bar(model_url, str(model_dest), "CodeFormer")
248271

249272

250273
# ---------------------------------------------
251274
def download_clipseg():
252-
print("Installing clipseg model for text-based masking...", file=sys.stderr)
275+
logger.info("Installing clipseg model for text-based masking...")
253276
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
254277
try:
255-
download_from_hf(AutoProcessor, CLIPSEG_MODEL)
256-
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL)
278+
download_from_hf(AutoProcessor, CLIPSEG_MODEL, config.root_path / 'models/core/misc/clipseg')
279+
download_from_hf(CLIPSegForImageSegmentation, CLIPSEG_MODEL,'models/core/misc/clipseg')
257280
except Exception:
258-
print("Error installing clipseg model:")
259-
print(traceback.format_exc())
281+
logger.info("Error installing clipseg model:")
282+
logger.info(traceback.format_exc())
260283

261284

262-
# -------------------------------------
263-
def download_safety_checker():
264-
print("Installing model for NSFW content detection...", file=sys.stderr)
265-
try:
266-
from diffusers.pipelines.stable_diffusion.safety_checker import (
267-
StableDiffusionSafetyChecker,
268-
)
269-
from transformers import AutoFeatureExtractor
270-
except ModuleNotFoundError:
271-
print("Error installing NSFW checker model:")
272-
print(traceback.format_exc())
273-
return
274-
safety_model_id = "CompVis/stable-diffusion-safety-checker"
275-
print("AutoFeatureExtractor...", file=sys.stderr)
276-
download_from_hf(AutoFeatureExtractor, safety_model_id)
277-
print("StableDiffusionSafetyChecker...", file=sys.stderr)
278-
download_from_hf(StableDiffusionSafetyChecker, safety_model_id)
279-
280-
281-
# -------------------------------------
282-
def download_vaes():
283-
print("Installing stabilityai VAE...", file=sys.stderr)
284-
try:
285-
# first the diffusers version
286-
repo_id = "stabilityai/sd-vae-ft-mse"
287-
args = dict(
288-
cache_dir=config.cache_dir,
289-
)
290-
if not AutoencoderKL.from_pretrained(repo_id, **args):
291-
raise Exception(f"download of {repo_id} failed")
292-
293-
repo_id = "stabilityai/sd-vae-ft-mse-original"
294-
model_name = "vae-ft-mse-840000-ema-pruned.ckpt"
295-
# next the legacy checkpoint version
296-
if not hf_download_with_resume(
297-
repo_id=repo_id,
298-
model_name=model_name,
299-
model_dir=str(config.root_path / Model_dir / Weights_dir),
300-
):
301-
raise Exception(f"download of {model_name} failed")
302-
except Exception as e:
303-
print(f"Error downloading StabilityAI standard VAE: {str(e)}", file=sys.stderr)
304-
print(traceback.format_exc(), file=sys.stderr)
305-
285+
def download_support_models():
286+
download_realesrgan()
287+
download_gfpgan()
288+
download_codeformer()
289+
download_clipseg()
290+
download_conversion_models()
306291

307292
# -------------------------------------
308293
def get_root(root: str = None) -> str:
@@ -657,17 +642,13 @@ def default_user_selections(program_opts: Namespace) -> UserSelections:
657642

658643
# -------------------------------------
659644
def initialize_rootdir(root: Path, yes_to_all: bool = False):
660-
print("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
661-
645+
logger.info("** INITIALIZING INVOKEAI RUNTIME DIRECTORY **")
662646
for name in (
663647
"models",
664-
"configs",
665-
"embeddings",
666648
"databases",
667-
"loras",
668-
"controlnets",
669649
"text-inversion-output",
670650
"text-inversion-training-data",
651+
"configs"
671652
):
672653
os.makedirs(os.path.join(root, name), exist_ok=True)
673654

@@ -676,6 +657,22 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
676657
if not os.path.samefile(configs_src, configs_dest):
677658
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
678659

660+
dest = root / 'models'
661+
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
662+
for model_type in [ModelType.Pipeline, ModelType.Vae, ModelType.Lora,
663+
ModelType.ControlNet,ModelType.TextualInversion]:
664+
path = dest / model_base.value / model_type.value
665+
path.mkdir(parents=True, exist_ok=True)
666+
path = dest / 'core'
667+
path.mkdir(parents=True, exist_ok=True)
668+
669+
with open(root / 'configs' / 'models.yaml','w') as yaml_file:
670+
yaml_file.write(yaml.dump({'__metadata__':
671+
{'version':'3.0.0'}
672+
}
673+
)
674+
)
675+
679676

680677
# -------------------------------------
681678
def run_console_ui(
@@ -837,7 +834,7 @@ def main():
837834
old_init_file = config.root_path / 'invokeai.init'
838835
new_init_file = config.root_path / 'invokeai.yaml'
839836
if old_init_file.exists() and not new_init_file.exists():
840-
print('** Migrating invokeai.init to invokeai.yaml')
837+
logger.info('** Migrating invokeai.init to invokeai.yaml')
841838
migrate_init_file(old_init_file)
842839
# Load new init file into config
843840
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
@@ -855,29 +852,21 @@ def main():
855852
if init_options:
856853
write_opts(init_options, new_init_file)
857854
else:
858-
print(
855+
logger.info(
859856
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
860857
)
861858
sys.exit(0)
862859

863860
if opt.skip_support_models:
864-
print("\n** SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST **")
861+
logger.info("SKIPPING SUPPORT MODEL DOWNLOADS PER USER REQUEST")
865862
else:
866-
print("\n** CHECKING/UPDATING SUPPORT MODELS **")
867-
download_bert()
868-
download_sd1_clip()
869-
download_sd2_clip()
870-
download_realesrgan()
871-
download_gfpgan()
872-
download_codeformer()
873-
download_clipseg()
874-
download_safety_checker()
875-
download_vaes()
863+
logger.info("CHECKING/UPDATING SUPPORT MODELS")
864+
download_support_models()
876865

877866
if opt.skip_sd_weights:
878-
print("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
867+
logger.info("\n** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
879868
elif models_to_download:
880-
print("\n** DOWNLOADING DIFFUSION WEIGHTS **")
869+
logger.info("\n** DOWNLOADING DIFFUSION WEIGHTS **")
881870
process_and_execute(opt, models_to_download)
882871

883872
postscript(errors=errors)

0 commit comments

Comments
 (0)