44import os
55from http .server import BaseHTTPRequestHandler , ThreadingHTTPServer
66from ldm .dream .pngwriter import PngWriter
7+ from threading import Event
8+
9+ class CanceledException (Exception ):
10+ pass
711
812class DreamServer (BaseHTTPRequestHandler ):
913 model = None
14+ canceled = Event ()
1015
1116 def do_GET (self ):
1217 if self .path == "/" :
@@ -25,6 +30,12 @@ def do_GET(self):
2530 'gfpgan_model_exists' : gfpgan_model_exists
2631 }
2732 self .wfile .write (bytes ("let config = " + json .dumps (config ) + ";\n " , "utf-8" ))
33+ elif self .path == "/cancel" :
34+ self .canceled .set ()
35+ self .send_response (200 )
36+ self .send_header ("Content-type" , "application/json" )
37+ self .end_headers ()
38+ self .wfile .write (bytes ('{}' , 'utf8' ))
2839 else :
2940 path = "." + self .path
3041 cwd = os .path .realpath (os .getcwd ())
@@ -67,6 +78,7 @@ def do_POST(self):
6778 progress_images = 'progress_images' in post_data
6879 seed = self .model .seed if int (post_data ['seed' ]) == - 1 else int (post_data ['seed' ])
6980
81+ self .canceled .clear ()
7082 print (f"Request to generate with prompt: { prompt } " )
7183 # In order to handle upscaled images, the PngWriter needs to maintain state
7284 # across images generated by each call to prompt2img(), so we define it in
@@ -121,6 +133,9 @@ def image_done(image, seed, upscaled=False):
121133 # it doesn't need to know if batch_size > 1, just if this is _part of a batch_
122134 step_writer = PngWriter ('./outputs/intermediates/' , prompt , 2 )
123135 def image_progress (sample , step ):
136+ if self .canceled .is_set ():
137+ self .wfile .write (bytes (json .dumps ({'event' :'canceled' }) + '\n ' , 'utf-8' ))
138+ raise CanceledException
124139 url = None
125140 # since rendering images is moderately expensive, only render every 5th image
126141 # and don't bother with the last one, since it'll render anyway
@@ -133,41 +148,46 @@ def image_progress(sample, step):
133148 {'event' :'step' , 'step' :step , 'url' : url }
134149 ) + '\n ' ,"utf-8" ))
135150
136- if initimg is None :
137- # Run txt2img
138- self .model .prompt2image (prompt ,
139- iterations = iterations ,
140- cfg_scale = cfgscale ,
141- width = width ,
142- height = height ,
143- seed = seed ,
144- steps = steps ,
145- gfpgan_strength = gfpgan_strength ,
146- upscale = upscale ,
147- sampler_name = sampler_name ,
148- step_callback = image_progress ,
149- image_callback = image_done )
150- else :
151- # Decode initimg as base64 to temp file
152- with open ("./img2img-tmp.png" , "wb" ) as f :
153- initimg = initimg .split ("," )[1 ] # Ignore mime type
154- f .write (base64 .b64decode (initimg ))
155-
156- # Run img2img
157- self .model .prompt2image (prompt ,
158- init_img = "./img2img-tmp.png" ,
159- iterations = iterations ,
160- cfg_scale = cfgscale ,
161- seed = seed ,
162- steps = steps ,
163- sampler_name = sampler_name ,
164- gfpgan_strength = gfpgan_strength ,
165- upscale = upscale ,
166- step_callback = image_progress ,
167- image_callback = image_done )
168-
169- # Remove the temp file
170- os .remove ("./img2img-tmp.png" )
151+ try :
152+ if initimg is None :
153+ # Run txt2img
154+ self .model .prompt2image (prompt ,
155+ iterations = iterations ,
156+ cfg_scale = cfgscale ,
157+ width = width ,
158+ height = height ,
159+ seed = seed ,
160+ steps = steps ,
161+ gfpgan_strength = gfpgan_strength ,
162+ upscale = upscale ,
163+ sampler_name = sampler_name ,
164+ step_callback = image_progress ,
165+ image_callback = image_done )
166+ else :
167+ # Decode initimg as base64 to temp file
168+ with open ("./img2img-tmp.png" , "wb" ) as f :
169+ initimg = initimg .split ("," )[1 ] # Ignore mime type
170+ f .write (base64 .b64decode (initimg ))
171+
172+ try :
173+ # Run img2img
174+ self .model .prompt2image (prompt ,
175+ init_img = "./img2img-tmp.png" ,
176+ iterations = iterations ,
177+ cfg_scale = cfgscale ,
178+ seed = seed ,
179+ steps = steps ,
180+ sampler_name = sampler_name ,
181+ gfpgan_strength = gfpgan_strength ,
182+ upscale = upscale ,
183+ step_callback = image_progress ,
184+ image_callback = image_done )
185+ finally :
186+ # Remove the temp file
187+ os .remove ("./img2img-tmp.png" )
188+ except CanceledException :
189+ print (f"Canceled." )
190+ return
171191
172192 print (f"Prompt generated!" )
173193
0 commit comments