Skip to content

Commit 1382453

Browse files
committed
Shading Residual: Add Albedo residual 2 and read_csv optional data in data loading
1 parent f182840 commit 1382453

2 files changed

Lines changed: 88 additions & 37 deletions

File tree

3_Shading_Residual/data_loading.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -119,25 +119,7 @@ def get_sfsnet_dataset(syn_dir=None, read_from_csv=None, read_celeba_csv=None, r
119119
face = []
120120
depth = []
121121

122-
if read_from_csv is None:
123-
for img in sorted(glob.glob(syn_dir + '*/*_albedo_*')):
124-
albedo.append(img)
125-
126-
for img in sorted(glob.glob(syn_dir + '*/*_face_*')):
127-
face.append(img)
128-
129-
for img in sorted(glob.glob(syn_dir + '*/*_normal_*')):
130-
normal.append(img)
131-
132-
for img in sorted(glob.glob(syn_dir + '*/*_depth_*')):
133-
depth.append(img)
134-
135-
for img in sorted(glob.glob(syn_dir + '*/*_mask_*')):
136-
mask.append(img)
137-
138-
for img in sorted(glob.glob(syn_dir + '*/*_light_*.txt')):
139-
sh.append(img)
140-
else:
122+
if read_from_csv is not None:
141123
df = pd.read_csv(read_from_csv)
142124
if read_first is not None and len(df) > read_first:
143125
df = df.sample(read_first)
@@ -155,17 +137,17 @@ def get_sfsnet_dataset(syn_dir=None, read_from_csv=None, read_celeba_csv=None, r
155137
for _, v in name_to_list.items():
156138
v[:] = [syn_dir + el for el in v]
157139

158-
# Merge Synthesized Celeba dataset for Psedo-Supervised training
159-
if read_celeba_csv is not None:
160-
df = pd.read_csv(read_celeba_csv)
161-
if read_first is not None and len(df) > read_first:
162-
df = df.sample(read_first)
163-
albedo += list(df['albedo'])
164-
face += list(df['face'])
165-
normal += list(df['normal'])
166-
depth += list(df['depth'])
167-
mask += list(df['mask'])
168-
sh += list(df['light'])
140+
# Merge Synthesized Celeba dataset for Psedo-Supervised training
141+
if read_celeba_csv is not None:
142+
df = pd.read_csv(read_celeba_csv)
143+
if read_first is not None and len(df) > read_first:
144+
df = df.sample(read_first)
145+
albedo += list(df['albedo'])
146+
face += list(df['face'])
147+
normal += list(df['normal'])
148+
depth += list(df['depth'])
149+
mask += list(df['mask'])
150+
sh += list(df['light'])
169151

170152
assert(len(albedo) == len(face) == len(normal) == len(depth) == len(mask) == len(sh))
171153
dataset_size = len(albedo)

3_Shading_Residual/models.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def __init__(self):
278278
self.normal_gen_model = NormalGenerationNet()
279279
self.albedo_residual_model = AlbedoResidualBlock()
280280
self.albedo_gen_model = AlbedoGenerationNet()
281+
self.albedo_residual_model_2 = AlbedoResidualBlock()
281282
self.light_estimator_model = LightEstimator()
282283
self.shading_residual_model = ShadingResidualEstimator()
283284

@@ -298,22 +299,25 @@ def forward(self, face):
298299

299300
# 3 a. Generate Normal
300301
predicted_normal = self.normal_gen_model(out_normal_features)
301-
# 3 b. Generate Albedo
302-
predicted_albedo = self.albedo_gen_model(out_albedo_features)
303-
# 3 c. Estimate lighting
302+
# 3 b. Estimate lighting
304303
# First, concat conv, normal and albedo features over channels dimension
305304
all_features = torch.cat((out_features, out_normal_features, out_albedo_features), dim=1)
306305
# Predict SH
307306
predicted_sh = self.light_estimator_model(all_features)
308307

309-
# 4. Generate shading
308+
# 4. Collect new albedo features
309+
out_albedo_features_2 = self.albedo_residual_model_2(out_features)
310+
new_features = torch.cat((out_features, out_normal_features, out_albedo_features_2), dim=1)
311+
predicted_albedo = self.albedo_gen_model(out_albedo_features_2)
312+
313+
# 5. Generate shading
310314
out_shading = get_shading(predicted_normal, predicted_sh)
311315

312-
# 5. Get Shading Residual
313-
shading_residual = self.shading_residual_model(all_features)
316+
# 6. Get Shading Residual
317+
shading_residual = self.shading_residual_model(new_features)
314318
updated_shading = out_shading + shading_residual
315319

316-
# 6. Reconstruction of image
320+
# 7. Reconstruction of image
317321
out_recon = reconstruct_image(updated_shading, predicted_albedo)
318322

319323
return predicted_normal, predicted_albedo, predicted_sh, out_shading, shading_residual, updated_shading, out_recon
@@ -509,6 +513,71 @@ def load_model_from_pretrained(src_model, dst_model):
509513
dst_model['albedo_gen_model.conv2.1.running_var'] = src_model['aconv2.conv.1.running_var']
510514
dst_model['albedo_gen_model.conv3.weight'] = src_model['aout.weight']
511515
dst_model['albedo_gen_model.conv3.bias'] = src_model['aout.bias']
516+
dst_model['albedo_residual_model_2.block1.res.0.weight'] = src_model['ares1.res.0.weight']
517+
dst_model['albedo_residual_model_2.block1.res.0.bias'] = src_model['ares1.res.0.bias']
518+
dst_model['albedo_residual_model_2.block1.res.0.running_mean'] = src_model['ares1.res.0.running_mean']
519+
dst_model['albedo_residual_model_2.block1.res.0.running_var'] = src_model['ares1.res.0.running_var']
520+
dst_model['albedo_residual_model_2.block1.res.2.weight'] = src_model['ares1.res.2.weight']
521+
dst_model['albedo_residual_model_2.block1.res.2.bias'] = src_model['ares1.res.2.bias']
522+
dst_model['albedo_residual_model_2.block1.res.3.weight'] = src_model['ares1.res.3.weight']
523+
dst_model['albedo_residual_model_2.block1.res.3.bias'] = src_model['ares1.res.3.bias']
524+
dst_model['albedo_residual_model_2.block1.res.3.running_mean'] = src_model['ares1.res.3.running_mean']
525+
dst_model['albedo_residual_model_2.block1.res.3.running_var'] = src_model['ares1.res.3.running_var']
526+
dst_model['albedo_residual_model_2.block1.res.5.weight'] = src_model['ares1.res.5.weight']
527+
dst_model['albedo_residual_model_2.block1.res.5.bias'] = src_model['ares1.res.5.bias']
528+
dst_model['albedo_residual_model_2.block2.res.0.weight'] = src_model['ares2.res.0.weight']
529+
dst_model['albedo_residual_model_2.block2.res.0.bias'] = src_model['ares2.res.0.bias']
530+
dst_model['albedo_residual_model_2.block2.res.0.running_mean'] = src_model['ares2.res.0.running_mean']
531+
dst_model['albedo_residual_model_2.block2.res.0.running_var'] = src_model['ares2.res.0.running_var']
532+
dst_model['albedo_residual_model_2.block2.res.2.weight'] = src_model['ares2.res.2.weight']
533+
dst_model['albedo_residual_model_2.block2.res.2.bias'] = src_model['ares2.res.2.bias']
534+
dst_model['albedo_residual_model_2.block2.res.3.weight'] = src_model['ares2.res.3.weight']
535+
dst_model['albedo_residual_model_2.block2.res.3.bias'] = src_model['ares2.res.3.bias']
536+
dst_model['albedo_residual_model_2.block2.res.3.running_mean'] = src_model['ares2.res.3.running_mean']
537+
dst_model['albedo_residual_model_2.block2.res.3.running_var'] = src_model['ares2.res.3.running_var']
538+
dst_model['albedo_residual_model_2.block2.res.5.weight'] = src_model['ares2.res.5.weight']
539+
dst_model['albedo_residual_model_2.block2.res.5.bias'] = src_model['ares2.res.5.bias']
540+
dst_model['albedo_residual_model_2.block3.res.0.weight'] = src_model['ares3.res.0.weight']
541+
dst_model['albedo_residual_model_2.block3.res.0.bias'] = src_model['ares3.res.0.bias']
542+
dst_model['albedo_residual_model_2.block3.res.0.running_mean'] = src_model['ares3.res.0.running_mean']
543+
dst_model['albedo_residual_model_2.block3.res.0.running_var'] = src_model['ares3.res.0.running_var']
544+
dst_model['albedo_residual_model_2.block3.res.2.weight'] = src_model['ares3.res.2.weight']
545+
dst_model['albedo_residual_model_2.block3.res.2.bias'] = src_model['ares3.res.2.bias']
546+
dst_model['albedo_residual_model_2.block3.res.3.weight'] = src_model['ares3.res.3.weight']
547+
dst_model['albedo_residual_model_2.block3.res.3.bias'] = src_model['ares3.res.3.bias']
548+
dst_model['albedo_residual_model_2.block3.res.3.running_mean'] = src_model['ares3.res.3.running_mean']
549+
dst_model['albedo_residual_model_2.block3.res.3.running_var'] = src_model['ares3.res.3.running_var']
550+
dst_model['albedo_residual_model_2.block3.res.5.weight'] = src_model['ares3.res.5.weight']
551+
dst_model['albedo_residual_model_2.block3.res.5.bias'] = src_model['ares3.res.5.bias']
552+
dst_model['albedo_residual_model_2.block4.res.0.weight'] = src_model['ares4.res.0.weight']
553+
dst_model['albedo_residual_model_2.block4.res.0.bias'] = src_model['ares4.res.0.bias']
554+
dst_model['albedo_residual_model_2.block4.res.0.running_mean'] = src_model['ares4.res.0.running_mean']
555+
dst_model['albedo_residual_model_2.block4.res.0.running_var'] = src_model['ares4.res.0.running_var']
556+
dst_model['albedo_residual_model_2.block4.res.2.weight'] = src_model['ares4.res.2.weight']
557+
dst_model['albedo_residual_model_2.block4.res.2.bias'] = src_model['ares4.res.2.bias']
558+
dst_model['albedo_residual_model_2.block4.res.3.weight'] = src_model['ares4.res.3.weight']
559+
dst_model['albedo_residual_model_2.block4.res.3.bias'] = src_model['ares4.res.3.bias']
560+
dst_model['albedo_residual_model_2.block4.res.3.running_mean'] = src_model['ares4.res.3.running_mean']
561+
dst_model['albedo_residual_model_2.block4.res.3.running_var'] = src_model['ares4.res.3.running_var']
562+
dst_model['albedo_residual_model_2.block4.res.5.weight'] = src_model['ares4.res.5.weight']
563+
dst_model['albedo_residual_model_2.block4.res.5.bias'] = src_model['ares4.res.5.bias']
564+
dst_model['albedo_residual_model_2.block5.res.0.weight'] = src_model['ares5.res.0.weight']
565+
dst_model['albedo_residual_model_2.block5.res.0.bias'] = src_model['ares5.res.0.bias']
566+
dst_model['albedo_residual_model_2.block5.res.0.running_mean'] = src_model['ares5.res.0.running_mean']
567+
dst_model['albedo_residual_model_2.block5.res.0.running_var'] = src_model['ares5.res.0.running_var']
568+
dst_model['albedo_residual_model_2.block5.res.2.weight'] = src_model['ares5.res.2.weight']
569+
dst_model['albedo_residual_model_2.block5.res.2.bias'] = src_model['ares5.res.2.bias']
570+
dst_model['albedo_residual_model_2.block5.res.3.weight'] = src_model['ares5.res.3.weight']
571+
dst_model['albedo_residual_model_2.block5.res.3.bias'] = src_model['ares5.res.3.bias']
572+
dst_model['albedo_residual_model_2.block5.res.3.running_mean'] = src_model['ares5.res.3.running_mean']
573+
dst_model['albedo_residual_model_2.block5.res.3.running_var'] = src_model['ares5.res.3.running_var']
574+
dst_model['albedo_residual_model_2.block5.res.5.weight'] = src_model['ares5.res.5.weight']
575+
dst_model['albedo_residual_model_2.block5.res.5.bias'] = src_model['ares5.res.5.bias']
576+
dst_model['albedo_residual_model_2.bn1.weight'] = src_model['areso.0.weight']
577+
dst_model['albedo_residual_model_2.bn1.bias'] = src_model['areso.0.bias']
578+
dst_model['albedo_residual_model_2.bn1.running_mean'] = src_model['areso.0.running_mean']
579+
dst_model['albedo_residual_model_2.bn1.running_var'] = src_model['areso.0.running_var']
580+
512581
dst_model['light_estimator_model.conv1.0.weight'] = src_model['lconv.conv.0.weight']
513582
dst_model['light_estimator_model.conv1.0.bias'] = src_model['lconv.conv.0.bias']
514583
dst_model['light_estimator_model.conv1.1.weight'] = src_model['lconv.conv.1.weight']

0 commit comments

Comments
 (0)