66import numpy as np
77import random
88import os
9+ import traceback
910from tqdm import tqdm , trange
1011from PIL import Image , ImageFilter
1112from einops import rearrange , repeat
@@ -43,14 +44,15 @@ def set_variation(self, seed, variation_amount, with_variations):
4344 self .variation_amount = variation_amount
4445 self .with_variations = with_variations
4546
46- def generate (self ,prompt ,init_image ,width ,height ,iterations = 1 ,seed = None ,
47+ def generate (self ,prompt ,init_image ,width ,height ,sampler , iterations = 1 ,seed = None ,
4748 image_callback = None , step_callback = None , threshold = 0.0 , perlin = 0.0 ,
4849 safety_checker :dict = None ,
4950 ** kwargs ):
5051 scope = choose_autocast (self .precision )
5152 self .safety_checker = safety_checker
5253 make_image = self .get_make_image (
5354 prompt ,
55+ sampler = sampler ,
5456 init_image = init_image ,
5557 width = width ,
5658 height = height ,
@@ -59,12 +61,14 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
5961 perlin = perlin ,
6062 ** kwargs
6163 )
62-
6364 results = []
6465 seed = seed if seed is not None else self .new_seed ()
6566 first_seed = seed
6667 seed , initial_noise = self .generate_initial_noise (seed , width , height )
67- with scope (self .model .device .type ), self .model .ema_scope ():
68+
69+ # There used to be an additional self.model.ema_scope() here, but it breaks
70+ # the inpaint-1.5 model. Not sure what it did.... ?
71+ with scope (self .model .device .type ):
6872 for n in trange (iterations , desc = 'Generating' ):
6973 x_T = None
7074 if self .variation_amount > 0 :
@@ -79,7 +83,8 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
7983 try :
8084 x_T = self .get_noise (width ,height )
8185 except :
82- pass
86+ print ('** An error occurred while getting initial noise **' )
87+ print (traceback .format_exc ())
8388
8489 image = make_image (x_T )
8590
@@ -95,10 +100,10 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
95100
96101 return results
97102
98- def sample_to_image (self ,samples ):
103+ def sample_to_image (self ,samples )-> Image . Image :
99104 """
100- Returns a function returning an image derived from the prompt and the initial image
101- Return value depends on the seed at the time you call it
105+ Given samples returned from a sampler, converts
106+ it into a PIL Image
102107 """
103108 x_samples = self .model .decode_first_stage (samples )
104109 x_samples = torch .clamp ((x_samples + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
0 commit comments