2020from multiprocessing .connection import Connection , Pipe
2121from pathlib import Path
2222from shutil import get_terminal_size
23+ from typing import Optional
2324
2425import npyscreen
2526import torch
@@ -630,21 +631,23 @@ def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None)
630631 return _ask_user_for_pt_cmdline (model_path )
631632
632633
633- def _ask_user_for_pt_cmdline (model_path : Path ) -> SchedulerPredictionType :
634+ def _ask_user_for_pt_cmdline (model_path : Path ) -> Optional [ SchedulerPredictionType ] :
634635 choices = [SchedulerPredictionType .Epsilon , SchedulerPredictionType .VPrediction , None ]
635636 print (
636637 f"""
637- Please select the type of the V2 checkpoint named { model_path .name } :
638- [1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base)
639- [2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768)
640- [3] Skip this model and come back later.
638+ Please select the scheduler prediction type of the checkpoint named { model_path .name } :
639+ [1] "epsilon" - most v1.5 models and v2 models trained on 512 pixel images
640+ [2] "vprediction" - v2 models trained on 768 pixel images and a few v1.5 models
641+ [3] Accept the best guess; you can fix it in the Web UI later
641642"""
642643 )
643644 choice = None
644645 ok = False
645646 while not ok :
646647 try :
647- choice = input ("select> " ).strip ()
648+ choice = input ("select [3]> " ).strip ()
649+ if not choice :
650+ return None
648651 choice = choices [int (choice ) - 1 ]
649652 ok = True
650653 except (ValueError , IndexError ):
@@ -655,22 +658,18 @@ def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType:
655658
656659
657660def _ask_user_for_pt_tui (model_path : Path , tui_conn : Connection ) -> SchedulerPredictionType :
658- try :
659- tui_conn .send_bytes (f"*need v2 config for:{ model_path } " .encode ("utf-8" ))
660- # note that we don't do any status checking here
661- response = tui_conn .recv_bytes ().decode ("utf-8" )
662- if response is None :
663- return None
664- elif response == "epsilon" :
665- return SchedulerPredictionType .epsilon
666- elif response == "v" :
667- return SchedulerPredictionType .VPrediction
668- elif response == "abort" :
669- logger .info ("Conversion aborted" )
670- return None
671- else :
672- return response
673- except Exception :
661+ tui_conn .send_bytes (f"*need v2 config for:{ model_path } " .encode ("utf-8" ))
662+ # note that we don't do any status checking here
663+ response = tui_conn .recv_bytes ().decode ("utf-8" )
664+ if response is None :
665+ return None
666+ elif response == "epsilon" :
667+ return SchedulerPredictionType .epsilon
668+ elif response == "v" :
669+ return SchedulerPredictionType .VPrediction
670+ elif response == "guess" :
671+ return None
672+ else :
674673 return None
675674
676675
0 commit comments