Skip to content

Commit 76633f5

Browse files
ebrlstein
authored andcommitted
(config) make user aware of any problems downloading models
also implement a generic way of reporting issues at the end of installation
1 parent ed61943 commit 76633f5

1 file changed

Lines changed: 22 additions & 7 deletions

File tree

scripts/configure_invokeai.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from omegaconf import OmegaConf
1919
from huggingface_hub import HfFolder, hf_hub_url
2020
from pathlib import Path
21+
from typing import Union
2122
from getpass_asterisk import getpass_asterisk
2223
from transformers import CLIPTokenizer, CLIPTextModel
2324
from ldm.invoke.globals import Globals
@@ -62,9 +63,9 @@ def introduction():
6263
)
6364

6465
#--------------------------------------------
65-
def postscript():
66-
print(
67-
'''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
66+
def postscript(errors: None):
67+
if not any(errors):
68+
message='''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
6869
Web version:
6970
python scripts/invoke.py --web (connect to http://localhost:9090)
7071
Command-line version:
@@ -77,7 +78,14 @@ def postscript():
7778
7879
Have fun!
7980
'''
80-
)
81+
82+
else:
83+
message=f"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
84+
for err in errors:
85+
message += f"\t - {err}\n"
86+
message += "Please check the logs above and correct any issues."
87+
88+
print(message)
8189

8290
#---------------------------------------------
8391
def yes_or_no(prompt:str, default_yes=True):
@@ -521,6 +529,7 @@ def download_safety_checker():
521529
print('...success',file=sys.stderr)
522530

523531
#-------------------------------------
532+
def download_weights(opt:dict) -> Union[str, None]:
524533
# Authenticate to Huggingface using environment variables.
525534
# If successful, authentication will persist for either interactive or non-interactive use.
526535
# Default env var expected by HuggingFace is HUGGING_FACE_HUB_TOKEN.
@@ -537,7 +546,8 @@ def download_safety_checker():
537546
return
538547
else:
539548
print('** Cannot download models because no Hugging Face access token could be found. Please re-run without --yes')
540-
return
549+
return "could not download model weights from Huggingface due to missing or invalid access token"
550+
541551
else:
542552
choice = user_wants_to_download_weights()
543553

@@ -558,6 +568,8 @@ def download_safety_checker():
558568
print('\n** DOWNLOADING WEIGHTS **')
559569
successfully_downloaded = download_weight_datasets(models, access_token)
560570
update_config_file(successfully_downloaded,opt)
571+
if len(successfully_downloaded) < len(models):
572+
return "some of the model weights downloads were not successful"
561573

562574
#-------------------------------------
563575
def get_root(root:str=None)->str:
@@ -746,9 +758,12 @@ def main():
746758
or not os.path.exists(os.path.join(Globals.root,'configs/stable-diffusion/v1-inference.yaml')):
747759
initialize_rootdir(Globals.root,opt.yes_to_all)
748760

761+
# Optimistically try to download all required assets. If any errors occur, add them and proceed anyway.
762+
errors=set()
763+
749764
if opt.interactive:
750765
print('** DOWNLOADING DIFFUSION WEIGHTS **')
751-
download_weights(opt)
766+
errors.add(download_weights(opt))
752767
print('\n** DOWNLOADING SUPPORT MODELS **')
753768
download_bert()
754769
download_clip()
@@ -757,7 +772,7 @@ def main():
757772
download_codeformer()
758773
download_clipseg()
759774
download_safety_checker()
760-
postscript()
775+
postscript(errors=errors)
761776
except KeyboardInterrupt:
762777
print('\nGoodbye! Come back soon.')
763778
except Exception as e:

0 commit comments

Comments
 (0)