Skip to content

Commit 9b71597

Browse files
committed
resolve conflicts between PR invoke-ai#1108 and invoke-ai#1243
2 parents 2f1c1e7 + b1da13a commit 9b71597

17 files changed

Lines changed: 444 additions & 38 deletions

File tree

configs/models.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ stable-diffusion-1.4:
1313
width: 512
1414
height: 512
1515
default: true
16+
inpainting-1.5:
17+
description: runwayML tuned inpainting model v1.5
18+
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
19+
config: configs/stable-diffusion/v1-inpainting-inference.yaml
20+
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
21+
width: 512
22+
height: 512
1623
stable-diffusion-1.5:
1724
config: configs/stable-diffusion/v1-inference.yaml
1825
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
model:
2+
base_learning_rate: 7.5e-05
3+
target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4+
params:
5+
linear_start: 0.00085
6+
linear_end: 0.0120
7+
num_timesteps_cond: 1
8+
log_every_t: 200
9+
timesteps: 1000
10+
first_stage_key: "jpg"
11+
cond_stage_key: "txt"
12+
image_size: 64
13+
channels: 4
14+
cond_stage_trainable: false # Note: different from the one we trained before
15+
conditioning_key: hybrid # important
16+
monitor: val/loss_simple_ema
17+
scale_factor: 0.18215
18+
finetune_keys: null
19+
20+
scheduler_config: # 2500 warmup steps when resuming; 10000 if starting from scratch
21+
target: ldm.lr_scheduler.LambdaLinearScheduler
22+
params:
23+
warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25+
f_start: [ 1.e-6 ]
26+
f_max: [ 1. ]
27+
f_min: [ 1. ]
28+
29+
personalization_config:
30+
target: ldm.modules.embedding_manager.EmbeddingManager
31+
params:
32+
placeholder_strings: ["*"]
33+
initializer_words: ['face', 'man', 'photo', 'africanmale']
34+
per_image_tokens: false
35+
num_vectors_per_token: 1
36+
progressive_words: False
37+
38+
unet_config:
39+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
40+
params:
41+
image_size: 32 # unused
42+
in_channels: 9 # 4 data + 4 downscaled image + 1 mask
43+
out_channels: 4
44+
model_channels: 320
45+
attention_resolutions: [ 4, 2, 1 ]
46+
num_res_blocks: 2
47+
channel_mult: [ 1, 2, 4, 4 ]
48+
num_heads: 8
49+
use_spatial_transformer: True
50+
transformer_depth: 1
51+
context_dim: 768
52+
use_checkpoint: True
53+
legacy: False
54+
55+
first_stage_config:
56+
target: ldm.models.autoencoder.AutoencoderKL
57+
params:
58+
embed_dim: 4
59+
monitor: val/rec_loss
60+
ddconfig:
61+
double_z: true
62+
z_channels: 4
63+
resolution: 256
64+
in_channels: 3
65+
out_ch: 3
66+
ch: 128
67+
ch_mult:
68+
- 1
69+
- 2
70+
- 4
71+
- 4
72+
num_res_blocks: 2
73+
attn_resolutions: []
74+
dropout: 0.0
75+
lossconfig:
76+
target: torch.nn.Identity
77+
78+
cond_stage_config:
79+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

ldm/generate.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,10 @@ def process_image(image,seed):
421421
)
422422

423423
# TODO: Hacky selection of operation to perform. Needs to be refactored.
424-
if (init_image is not None) and (mask_image is not None):
424+
if self.sampler.conditioning_key() in ('hybrid','concat'):
425+
print(f'** Inpainting model detected. Will try it! **')
426+
generator = self._make_omnibus()
427+
elif (init_image is not None) and (mask_image is not None):
425428
generator = self._make_inpaint()
426429
elif (embiggen != None or embiggen_tiles != None):
427430
generator = self._make_embiggen()
@@ -677,6 +680,7 @@ def _make_images(
677680

678681
return init_image,init_mask
679682

683+
# TODO: the _make_*() methods below repeat the same pattern; factor them into a shared make_func() helper
680684
def _make_base(self):
681685
if not self.generators.get('base'):
682686
from ldm.invoke.generator import Generator
@@ -687,6 +691,7 @@ def _make_img2img(self):
687691
if not self.generators.get('img2img'):
688692
from ldm.invoke.generator.img2img import Img2Img
689693
self.generators['img2img'] = Img2Img(self.model, self.precision)
694+
self.generators['img2img'].free_gpu_mem = self.free_gpu_mem
690695
return self.generators['img2img']
691696

692697
def _make_embiggen(self):
@@ -715,6 +720,15 @@ def _make_inpaint(self):
715720
self.generators['inpaint'] = Inpaint(self.model, self.precision)
716721
return self.generators['inpaint']
717722

723+
# "omnibus" supports the runwayML custom inpainting model, which does
724+
# txt2img, img2img and inpainting using slight variations on the same code
725+
def _make_omnibus(self):
726+
if not self.generators.get('omnibus'):
727+
from ldm.invoke.generator.omnibus import Omnibus
728+
self.generators['omnibus'] = Omnibus(self.model, self.precision)
729+
self.generators['omnibus'].free_gpu_mem = self.free_gpu_mem
730+
return self.generators['omnibus']
731+
718732
def load_model(self):
719733
'''
720734
preload model identified in self.model_name

ldm/invoke/args.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def parse_cmd(self,cmd_string):
181181
switches_started = False
182182

183183
for element in elements:
184-
if element[0] == '-' and not switches_started:
184+
if len(element) == 0: # empty prompt
185+
pass
186+
elif element[0] == '-' and not switches_started:
185187
switches_started = True
186188
if switches_started:
187189
switches.append(element)

ldm/invoke/conditioning.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
123123
else:
124124
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
125125

126-
127126
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
127+
conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
128128
return (
129129
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
130130
cross_attention_control_args=cac_args
@@ -166,4 +166,25 @@ def get_tokens_length(model, fragments: list[Fragment]):
166166
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
167167
return sum([len(x) for x in tokens])
168168

169+
def flatten_hybrid_conditioning(uncond, cond):
170+
'''
171+
Handles both forms of conditioning: a plain tensor (as used by
172+
cross attention), which is returned unchanged, and a dict (as used
173+
by 'hybrid'), whose entries are concatenated with those of uncond.
174+
'''
175+
if isinstance(cond, dict):
176+
assert isinstance(uncond, dict)
177+
cond_in = dict()
178+
for k in cond:
179+
if isinstance(cond[k], list):
180+
cond_in[k] = [
181+
torch.cat([uncond[k][i], cond[k][i]])
182+
for i in range(len(cond[k]))
183+
]
184+
else:
185+
cond_in[k] = torch.cat([uncond[k], cond[k]])
186+
return cond_in
187+
else:
188+
return cond
169189

190+

ldm/invoke/generator/base.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import random
88
import os
9+
import traceback
910
from tqdm import tqdm, trange
1011
from PIL import Image, ImageFilter
1112
from einops import rearrange, repeat
@@ -43,14 +44,15 @@ def set_variation(self, seed, variation_amount, with_variations):
4344
self.variation_amount = variation_amount
4445
self.with_variations = with_variations
4546

46-
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
47+
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
4748
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
4849
safety_checker:dict=None,
4950
**kwargs):
5051
scope = choose_autocast(self.precision)
5152
self.safety_checker = safety_checker
5253
make_image = self.get_make_image(
5354
prompt,
55+
sampler = sampler,
5456
init_image = init_image,
5557
width = width,
5658
height = height,
@@ -59,12 +61,14 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
5961
perlin = perlin,
6062
**kwargs
6163
)
62-
6364
results = []
6465
seed = seed if seed is not None else self.new_seed()
6566
first_seed = seed
6667
seed, initial_noise = self.generate_initial_noise(seed, width, height)
67-
with scope(self.model.device.type), self.model.ema_scope():
68+
69+
# There used to be an additional self.model.ema_scope() here, but it breaks
70+
# the inpainting-1.5 model. It is not clear what it did.
71+
with scope(self.model.device.type):
6872
for n in trange(iterations, desc='Generating'):
6973
x_T = None
7074
if self.variation_amount > 0:
@@ -79,7 +83,8 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
7983
try:
8084
x_T = self.get_noise(width,height)
8185
except:
82-
pass
86+
print('** An error occurred while getting initial noise **')
87+
print(traceback.format_exc())
8388

8489
image = make_image(x_T)
8590

@@ -95,10 +100,10 @@ def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
95100

96101
return results
97102

98-
def sample_to_image(self,samples):
103+
def sample_to_image(self,samples)->Image.Image:
99104
"""
100-
Returns a function returning an image derived from the prompt and the initial image
101-
Return value depends on the seed at the time you call it
105+
Given samples returned from a sampler, converts
106+
them into a PIL Image
102107
"""
103108
x_samples = self.model.decode_first_stage(samples)
104109
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

ldm/invoke/generator/img2img.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class Img2Img(Generator):
1616
def __init__(self, model, precision):
1717
super().__init__(model, precision)
18-
self.init_latent = None # by get_noise()
18+
self.init_latent = None # by get_noise()
1919

2020
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
2121
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,**kwargs):
@@ -80,7 +80,10 @@ def get_noise(self,width,height):
8080

8181
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
8282
image = np.array(image).astype(np.float32) / 255.0
83-
image = image[None].transpose(0, 3, 1, 2)
83+
if len(image.shape) == 2: # 'L' image, as in a mask
84+
image = image[None,None]
85+
else: # 'RGB' image
86+
image = image[None].transpose(0, 3, 1, 2)
8487
image = torch.from_numpy(image)
8588
if normalize:
8689
image = 2.0 * image - 1.0

0 commit comments

Comments
 (0)