Skip to content

Commit 5d13207

Browse files
committed
webui: support cancelation
1 parent dae2b26 commit 5d13207

4 files changed

Lines changed: 73 additions & 39 deletions

File tree

ldm/dream/server.py

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
import os
55
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
66
from ldm.dream.pngwriter import PngWriter
7+
from threading import Event
8+
9+
class CanceledException(Exception):
10+
pass
711

812
class DreamServer(BaseHTTPRequestHandler):
913
model = None
14+
canceled = Event()
1015

1116
def do_GET(self):
1217
if self.path == "/":
@@ -25,6 +30,12 @@ def do_GET(self):
2530
'gfpgan_model_exists': gfpgan_model_exists
2631
}
2732
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
33+
elif self.path == "/cancel":
34+
self.canceled.set()
35+
self.send_response(200)
36+
self.send_header("Content-type", "application/json")
37+
self.end_headers()
38+
self.wfile.write(bytes('{}', 'utf8'))
2839
else:
2940
path = "." + self.path
3041
cwd = os.path.realpath(os.getcwd())
@@ -67,6 +78,7 @@ def do_POST(self):
6778
progress_images = 'progress_images' in post_data
6879
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
6980

81+
self.canceled.clear()
7082
print(f"Request to generate with prompt: {prompt}")
7183
# In order to handle upscaled images, the PngWriter needs to maintain state
7284
# across images generated by each call to prompt2img(), so we define it in
@@ -121,6 +133,9 @@ def image_done(image, seed, upscaled=False):
121133
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
122134
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
123135
def image_progress(sample, step):
136+
if self.canceled.is_set():
137+
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
138+
raise CanceledException
124139
url = None
125140
# since rendering images is moderately expensive, only render every 5th image
126141
# and don't bother with the last one, since it'll render anyway
@@ -133,41 +148,46 @@ def image_progress(sample, step):
133148
{'event':'step', 'step':step, 'url': url}
134149
) + '\n',"utf-8"))
135150

136-
if initimg is None:
137-
# Run txt2img
138-
self.model.prompt2image(prompt,
139-
iterations=iterations,
140-
cfg_scale = cfgscale,
141-
width = width,
142-
height = height,
143-
seed = seed,
144-
steps = steps,
145-
gfpgan_strength = gfpgan_strength,
146-
upscale = upscale,
147-
sampler_name = sampler_name,
148-
step_callback=image_progress,
149-
image_callback=image_done)
150-
else:
151-
# Decode initimg as base64 to temp file
152-
with open("./img2img-tmp.png", "wb") as f:
153-
initimg = initimg.split(",")[1] # Ignore mime type
154-
f.write(base64.b64decode(initimg))
155-
156-
# Run img2img
157-
self.model.prompt2image(prompt,
158-
init_img = "./img2img-tmp.png",
159-
iterations = iterations,
160-
cfg_scale = cfgscale,
161-
seed = seed,
162-
steps = steps,
163-
sampler_name = sampler_name,
164-
gfpgan_strength=gfpgan_strength,
165-
upscale = upscale,
166-
step_callback=image_progress,
167-
image_callback=image_done)
168-
169-
# Remove the temp file
170-
os.remove("./img2img-tmp.png")
151+
try:
152+
if initimg is None:
153+
# Run txt2img
154+
self.model.prompt2image(prompt,
155+
iterations=iterations,
156+
cfg_scale = cfgscale,
157+
width = width,
158+
height = height,
159+
seed = seed,
160+
steps = steps,
161+
gfpgan_strength = gfpgan_strength,
162+
upscale = upscale,
163+
sampler_name = sampler_name,
164+
step_callback=image_progress,
165+
image_callback=image_done)
166+
else:
167+
# Decode initimg as base64 to temp file
168+
with open("./img2img-tmp.png", "wb") as f:
169+
initimg = initimg.split(",")[1] # Ignore mime type
170+
f.write(base64.b64decode(initimg))
171+
172+
try:
173+
# Run img2img
174+
self.model.prompt2image(prompt,
175+
init_img = "./img2img-tmp.png",
176+
iterations = iterations,
177+
cfg_scale = cfgscale,
178+
seed = seed,
179+
steps = steps,
180+
sampler_name = sampler_name,
181+
gfpgan_strength=gfpgan_strength,
182+
upscale = upscale,
183+
step_callback=image_progress,
184+
image_callback=image_done)
185+
finally:
186+
# Remove the temp file
187+
os.remove("./img2img-tmp.png")
188+
except CanceledException:
189+
print(f"Canceled.")
190+
return
171191

172192
print(f"Prompt generated!")
173193

static/dream_web/index.css

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,7 @@ label {
7474
width: 30vh;
7575
height: 30vh;
7676
}
77+
#cancel-button {
78+
cursor: pointer;
79+
color: red;
80+
}

static/dream_web/index.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ <h2 id="header">Stable Diffusion Dream Server</h2>
8787
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
8888
<div id="progress-section">
8989
<progress id="progress-bar" value="0" max="1"></progress>
90+
<span id="cancel-button" title="Cancel">&#10006;</span>
9091
<br>
9192
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'></img>
9293
<div id="scaling-inprocess-message">

static/dream_web/index.js

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,26 @@ async function generateSubmit(form) {
8989
for (let event of value.split('\n').filter(e => e !== '')) {
9090
const data = JSON.parse(event);
9191

92-
if (data.event == 'result') {
92+
if (data.event === 'result') {
9393
noOutputs = false;
9494
document.querySelector("#no-results-message")?.remove();
9595
appendOutput(data.files[0],data.files[1],data.config);
9696
progressEle.setAttribute('value', 0);
9797
progressEle.setAttribute('max', formData.steps);
9898
progressImageEle.src = BLANK_IMAGE_URL;
99-
} else if (data.event == 'upscaling-started') {
99+
} else if (data.event === 'upscaling-started') {
100100
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
101101
document.getElementById("scaling-inprocess-message").style.display = "block";
102-
} else if (data.event == 'upscaling-done') {
102+
} else if (data.event === 'upscaling-done') {
103103
document.getElementById("scaling-inprocess-message").style.display = "none";
104-
} else if (data.event == 'step') {
104+
} else if (data.event === 'step') {
105105
progressEle.setAttribute('value', data.step);
106106
if (data.url) {
107107
progressImageEle.src = data.url;
108108
}
109+
} else if (data.event === 'canceled') {
110+
// avoid alerting as if this were an error case
111+
noOutputs = false;
109112
}
110113
}
111114
}
@@ -144,6 +147,12 @@ window.onload = () => {
144147
});
145148
loadFields(document.querySelector("#generate-form"));
146149

150+
document.querySelector('#cancel-button').addEventListener('click', () => {
151+
fetch('/cancel').catch(e => {
152+
console.error(e);
153+
});
154+
});
155+
147156
if (!config.gfpgan_model_exists) {
148157
document.querySelector("#gfpgan").style.display = 'none';
149158
}

0 commit comments

Comments
 (0)