Skip to content

Commit f3292a6

Browse files
Implement CodeFormer Face Restoration (invoke-ai#669)
* Implement CodeFormer Face Restoration * fix codeformer model destination path Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 062f3e8 commit f3292a6

8 files changed

Lines changed: 871 additions & 5 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# ignore default image save location and model symbolic link
22
outputs/
33
models/ldm/stable-diffusion-v1/model.ckpt
4+
ldm/restoration/codeformer/weights
45

56
# ignore a directory which serves as a place for initial images
67
inputs/

docs/features/UPSCALE.md

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,39 @@ the base images.
9797
If you wish to stop during the image generation but want to upscale or face restore a particular
9898
generated image, pass it again with the same prompt and generated seed along with the `-U` and `-G`
9999
prompt arguments to perform those actions.
100+
101+
## CodeFormer Support
102+
103+
This repo also allows you to perform face restoration using
104+
[CodeFormer](https://github.com/sczhou/CodeFormer).
105+
106+
In order to setup CodeFormer to work, you need to download the models like with GFPGAN. You can do
107+
this either by running `preload_models.py` or by manually downloading the
108+
[model file](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth) and
109+
saving it to `ldm/restoration/codeformer/weights` folder.
110+
111+
You can use `-ft` prompt argument to swap between CodeFormer and the default GFPGAN. The above
112+
mentioned `-G` prompt argument will allow you to control the strength of the restoration effect.
113+
114+
### **Usage:**
115+
116+
The following command will perform face restoration with CodeFormer instead of the default gfpgan.
117+
118+
`<prompt> -G 0.8 -ft codeformer`
119+
120+
**Other Options:**
121+
122+
- `-cf` - cf or CodeFormer Fidelity takes values between `0` and `1`. 0 produces high quality
123+
results but low accuracy and 1 produces lower quality results but higher accuacy to your original
124+
face.
125+
126+
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
127+
that is closely matching to the input face.
128+
129+
`<prompt> -G 1.0 -ft codeformer -cf 0.9`
130+
131+
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
132+
that is the best restoration possible. This may deviate slightly from the original face. This is an
133+
excellent option to use in situations when there is very little facial data to work with.
134+
135+
`<prompt> -G 1.0 -ft codeformer -cf 0.1`

ldm/dream/args.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,13 +516,26 @@ def _create_dream_cmd_parser(self):
516516
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
517517
default=0.75,
518518
)
519+
postprocessing_group.add_argument(
520+
'-ft',
521+
'--facetool',
522+
type=str,
523+
help='Select the face restoration AI to use: gfpgan, codeformer',
524+
)
519525
postprocessing_group.add_argument(
520526
'-G',
521527
'--gfpgan_strength',
522528
type=float,
523529
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
524530
default=0,
525531
)
532+
postprocessing_group.add_argument(
533+
'-cf',
534+
'--codeformer_fidelity',
535+
type=float,
536+
help='Takes values between 0 and 1. 0 produces high quality but low accuracy. 1 produces high accuracy but low quality.',
537+
default=0.75
538+
)
526539
postprocessing_group.add_argument(
527540
'-U',
528541
'--upscale',

ldm/generate.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def prompt2image(
227227
embiggen = None,
228228
embiggen_tiles = None,
229229
# these are specific to GFPGAN/ESRGAN
230+
facetool = None,
230231
gfpgan_strength = 0,
232+
codeformer_fidelity = None,
231233
save_original = False,
232234
upscale = None,
233235
# Set this True to handle KeyboardInterrupt internally
@@ -373,7 +375,9 @@ def process_image(image,seed):
373375
if upscale is not None or gfpgan_strength > 0:
374376
self.upscale_and_reconstruct(results,
375377
upscale = upscale,
378+
facetool = facetool,
376379
strength = gfpgan_strength,
380+
codeformer_fidelity = codeformer_fidelity,
377381
save_original = save_original,
378382
image_callback = image_callback)
379383

@@ -507,15 +511,20 @@ def correct_colors(self,
507511

508512
def upscale_and_reconstruct(self,
509513
image_list,
514+
facetool = 'gfpgan',
510515
upscale = None,
511516
strength = 0.0,
517+
codeformer_fidelity = 0.75,
512518
save_original = False,
513519
image_callback = None):
514520
try:
515521
if upscale is not None:
516522
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
517523
if strength > 0:
518-
from ldm.gfpgan.gfpgan_tools import run_gfpgan
524+
if facetool == 'codeformer':
525+
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
526+
else:
527+
from ldm.gfpgan.gfpgan_tools import run_gfpgan
519528
except (ModuleNotFoundError, ImportError):
520529
print(traceback.format_exc(), file=sys.stderr)
521530
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
@@ -534,9 +543,12 @@ def upscale_and_reconstruct(self,
534543
seed,
535544
)
536545
if strength > 0:
537-
image = run_gfpgan(
538-
image, strength, seed, 1
539-
)
546+
if facetool == 'codeformer':
547+
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
548+
else:
549+
image = run_gfpgan(
550+
image, strength, seed, 1
551+
)
540552
except Exception as e:
541553
print(
542554
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
import torch
3+
import numpy as np
4+
import warnings
5+
6+
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
7+
8+
class CodeFormerRestoration():
9+
def __init__(self) -> None:
10+
pass
11+
12+
def process(self, image, strength, device, seed=None, fidelity=0.75):
13+
if seed is not None:
14+
print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
15+
with warnings.catch_warnings():
16+
warnings.filterwarnings('ignore', category=DeprecationWarning)
17+
warnings.filterwarnings('ignore', category=UserWarning)
18+
19+
from basicsr.utils.download_util import load_file_from_url
20+
from basicsr.utils import img2tensor, tensor2img
21+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
22+
from ldm.restoration.codeformer.codeformer_arch import CodeFormer
23+
from torchvision.transforms.functional import normalize
24+
from PIL import Image
25+
26+
cf_class = CodeFormer
27+
28+
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
29+
30+
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/restoration/codeformer/weights'), progress=True)
31+
checkpoint = torch.load(checkpoint_path)['params_ema']
32+
cf.load_state_dict(checkpoint)
33+
cf.eval()
34+
35+
image = image.convert('RGB')
36+
37+
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
38+
face_helper.clean_all()
39+
face_helper.read_image(np.array(image, dtype=np.uint8))
40+
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
41+
face_helper.align_warp_face()
42+
43+
for idx, cropped_face in enumerate(face_helper.cropped_faces):
44+
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
45+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
46+
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
47+
48+
try:
49+
with torch.no_grad():
50+
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
51+
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
52+
del output
53+
torch.cuda.empty_cache()
54+
except RuntimeError as error:
55+
print(f'\tFailed inference for CodeFormer: {error}.')
56+
restored_face = cropped_face
57+
58+
restored_face = restored_face.astype('uint8')
59+
face_helper.add_restored_face(restored_face)
60+
61+
62+
face_helper.get_inverse_affine(None)
63+
64+
restored_img = face_helper.paste_faces_to_input_image()
65+
66+
res = Image.fromarray(restored_img)
67+
68+
if strength < 1.0:
69+
# Resize the image to the new image if the sizes have changed
70+
if restored_img.size != image.size:
71+
image = image.resize(res.size)
72+
res = Image.blend(image, res, strength)
73+
74+
cf = None
75+
76+
return res

0 commit comments

Comments
 (0)