1+ import os
2+ import torch
3+ import numpy as np
4+ import warnings
5+
6+ pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
7+
8+ class CodeFormerRestoration ():
9+ def __init__ (self ) -> None :
10+ pass
11+
12+ def process (self , image , strength , device , seed = None , fidelity = 0.75 ):
13+ if seed is not None :
14+ print (f'>> CodeFormer - Restoring Faces for image seed:{ seed } ' )
15+ with warnings .catch_warnings ():
16+ warnings .filterwarnings ('ignore' , category = DeprecationWarning )
17+ warnings .filterwarnings ('ignore' , category = UserWarning )
18+
19+ from basicsr .utils .download_util import load_file_from_url
20+ from basicsr .utils import img2tensor , tensor2img
21+ from facexlib .utils .face_restoration_helper import FaceRestoreHelper
22+ from ldm .restoration .codeformer .codeformer_arch import CodeFormer
23+ from torchvision .transforms .functional import normalize
24+ from PIL import Image
25+
26+ cf_class = CodeFormer
27+
28+ cf = cf_class (dim_embd = 512 , codebook_size = 1024 , n_head = 8 , n_layers = 9 , connect_list = ['32' , '64' , '128' , '256' ]).to (device )
29+
30+ checkpoint_path = load_file_from_url (url = pretrained_model_url , model_dir = os .path .abspath ('ldm/restoration/codeformer/weights' ), progress = True )
31+ checkpoint = torch .load (checkpoint_path )['params_ema' ]
32+ cf .load_state_dict (checkpoint )
33+ cf .eval ()
34+
35+ image = image .convert ('RGB' )
36+
37+ face_helper = FaceRestoreHelper (upscale_factor = 1 , use_parse = True , device = device )
38+ face_helper .clean_all ()
39+ face_helper .read_image (np .array (image , dtype = np .uint8 ))
40+ face_helper .get_face_landmarks_5 (resize = 640 , eye_dist_threshold = 5 )
41+ face_helper .align_warp_face ()
42+
43+ for idx , cropped_face in enumerate (face_helper .cropped_faces ):
44+ cropped_face_t = img2tensor (cropped_face / 255. , bgr2rgb = True , float32 = True )
45+ normalize (cropped_face_t , (0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ), inplace = True )
46+ cropped_face_t = cropped_face_t .unsqueeze (0 ).to (device )
47+
48+ try :
49+ with torch .no_grad ():
50+ output = cf (cropped_face_t , w = fidelity , adain = True )[0 ]
51+ restored_face = tensor2img (output .squeeze (0 ), rgb2bgr = True , min_max = (- 1 , 1 ))
52+ del output
53+ torch .cuda .empty_cache ()
54+ except RuntimeError as error :
55+ print (f'\t Failed inference for CodeFormer: { error } .' )
56+ restored_face = cropped_face
57+
58+ restored_face = restored_face .astype ('uint8' )
59+ face_helper .add_restored_face (restored_face )
60+
61+
62+ face_helper .get_inverse_affine (None )
63+
64+ restored_img = face_helper .paste_faces_to_input_image ()
65+
66+ res = Image .fromarray (restored_img )
67+
68+ if strength < 1.0 :
69+ # Resize the image to the new image if the sizes have changed
70+ if restored_img .size != image .size :
71+ image = image .resize (res .size )
72+ res = Image .blend (image , res , strength )
73+
74+ cf = None
75+
76+ return res
0 commit comments