@@ -278,6 +278,7 @@ def __init__(self):
278278 self .normal_gen_model = NormalGenerationNet ()
279279 self .albedo_residual_model = AlbedoResidualBlock ()
280280 self .albedo_gen_model = AlbedoGenerationNet ()
281+ self .albedo_residual_model_2 = AlbedoResidualBlock ()
281282 self .light_estimator_model = LightEstimator ()
282283 self .shading_residual_model = ShadingResidualEstimator ()
283284
@@ -298,22 +299,25 @@ def forward(self, face):
298299
299300 # 3 a. Generate Normal
300301 predicted_normal = self .normal_gen_model (out_normal_features )
301- # 3 b. Generate Albedo
302- predicted_albedo = self .albedo_gen_model (out_albedo_features )
303- # 3 c. Estimate lighting
302+ # 3 b. Estimate lighting
304303 # First, concat conv, normal and albedo features over channels dimension
305304 all_features = torch .cat ((out_features , out_normal_features , out_albedo_features ), dim = 1 )
306305 # Predict SH
307306 predicted_sh = self .light_estimator_model (all_features )
308307
309- # 4. Generate shading
308+ # 4. Collect new albedo features
309+ out_albedo_features_2 = self .albedo_residual_model_2 (out_features )
310+ new_features = torch .cat ((out_features , out_normal_features , out_albedo_features_2 ), dim = 1 )
311+ predicted_albedo = self .albedo_gen_model (out_albedo_features_2 )
312+
313+ # 5. Generate shading
310314 out_shading = get_shading (predicted_normal , predicted_sh )
311315
312- # 5 . Get Shading Residual
313- shading_residual = self .shading_residual_model (all_features )
316+ # 6 . Get Shading Residual
317+ shading_residual = self .shading_residual_model (new_features )
314318 updated_shading = out_shading + shading_residual
315319
316- # 6 . Reconstruction of image
320+ # 7 . Reconstruction of image
317321 out_recon = reconstruct_image (updated_shading , predicted_albedo )
318322
319323 return predicted_normal , predicted_albedo , predicted_sh , out_shading , shading_residual , updated_shading , out_recon
@@ -509,6 +513,71 @@ def load_model_from_pretrained(src_model, dst_model):
509513 dst_model ['albedo_gen_model.conv2.1.running_var' ] = src_model ['aconv2.conv.1.running_var' ]
510514 dst_model ['albedo_gen_model.conv3.weight' ] = src_model ['aout.weight' ]
511515 dst_model ['albedo_gen_model.conv3.bias' ] = src_model ['aout.bias' ]
516+ dst_model ['albedo_residual_model_2.block1.res.0.weight' ] = src_model ['ares1.res.0.weight' ]
517+ dst_model ['albedo_residual_model_2.block1.res.0.bias' ] = src_model ['ares1.res.0.bias' ]
518+ dst_model ['albedo_residual_model_2.block1.res.0.running_mean' ] = src_model ['ares1.res.0.running_mean' ]
519+ dst_model ['albedo_residual_model_2.block1.res.0.running_var' ] = src_model ['ares1.res.0.running_var' ]
520+ dst_model ['albedo_residual_model_2.block1.res.2.weight' ] = src_model ['ares1.res.2.weight' ]
521+ dst_model ['albedo_residual_model_2.block1.res.2.bias' ] = src_model ['ares1.res.2.bias' ]
522+ dst_model ['albedo_residual_model_2.block1.res.3.weight' ] = src_model ['ares1.res.3.weight' ]
523+ dst_model ['albedo_residual_model_2.block1.res.3.bias' ] = src_model ['ares1.res.3.bias' ]
524+ dst_model ['albedo_residual_model_2.block1.res.3.running_mean' ] = src_model ['ares1.res.3.running_mean' ]
525+ dst_model ['albedo_residual_model_2.block1.res.3.running_var' ] = src_model ['ares1.res.3.running_var' ]
526+ dst_model ['albedo_residual_model_2.block1.res.5.weight' ] = src_model ['ares1.res.5.weight' ]
527+ dst_model ['albedo_residual_model_2.block1.res.5.bias' ] = src_model ['ares1.res.5.bias' ]
528+ dst_model ['albedo_residual_model_2.block2.res.0.weight' ] = src_model ['ares2.res.0.weight' ]
529+ dst_model ['albedo_residual_model_2.block2.res.0.bias' ] = src_model ['ares2.res.0.bias' ]
530+ dst_model ['albedo_residual_model_2.block2.res.0.running_mean' ] = src_model ['ares2.res.0.running_mean' ]
531+ dst_model ['albedo_residual_model_2.block2.res.0.running_var' ] = src_model ['ares2.res.0.running_var' ]
532+ dst_model ['albedo_residual_model_2.block2.res.2.weight' ] = src_model ['ares2.res.2.weight' ]
533+ dst_model ['albedo_residual_model_2.block2.res.2.bias' ] = src_model ['ares2.res.2.bias' ]
534+ dst_model ['albedo_residual_model_2.block2.res.3.weight' ] = src_model ['ares2.res.3.weight' ]
535+ dst_model ['albedo_residual_model_2.block2.res.3.bias' ] = src_model ['ares2.res.3.bias' ]
536+ dst_model ['albedo_residual_model_2.block2.res.3.running_mean' ] = src_model ['ares2.res.3.running_mean' ]
537+ dst_model ['albedo_residual_model_2.block2.res.3.running_var' ] = src_model ['ares2.res.3.running_var' ]
538+ dst_model ['albedo_residual_model_2.block2.res.5.weight' ] = src_model ['ares2.res.5.weight' ]
539+ dst_model ['albedo_residual_model_2.block2.res.5.bias' ] = src_model ['ares2.res.5.bias' ]
540+ dst_model ['albedo_residual_model_2.block3.res.0.weight' ] = src_model ['ares3.res.0.weight' ]
541+ dst_model ['albedo_residual_model_2.block3.res.0.bias' ] = src_model ['ares3.res.0.bias' ]
542+ dst_model ['albedo_residual_model_2.block3.res.0.running_mean' ] = src_model ['ares3.res.0.running_mean' ]
543+ dst_model ['albedo_residual_model_2.block3.res.0.running_var' ] = src_model ['ares3.res.0.running_var' ]
544+ dst_model ['albedo_residual_model_2.block3.res.2.weight' ] = src_model ['ares3.res.2.weight' ]
545+ dst_model ['albedo_residual_model_2.block3.res.2.bias' ] = src_model ['ares3.res.2.bias' ]
546+ dst_model ['albedo_residual_model_2.block3.res.3.weight' ] = src_model ['ares3.res.3.weight' ]
547+ dst_model ['albedo_residual_model_2.block3.res.3.bias' ] = src_model ['ares3.res.3.bias' ]
548+ dst_model ['albedo_residual_model_2.block3.res.3.running_mean' ] = src_model ['ares3.res.3.running_mean' ]
549+ dst_model ['albedo_residual_model_2.block3.res.3.running_var' ] = src_model ['ares3.res.3.running_var' ]
550+ dst_model ['albedo_residual_model_2.block3.res.5.weight' ] = src_model ['ares3.res.5.weight' ]
551+ dst_model ['albedo_residual_model_2.block3.res.5.bias' ] = src_model ['ares3.res.5.bias' ]
552+ dst_model ['albedo_residual_model_2.block4.res.0.weight' ] = src_model ['ares4.res.0.weight' ]
553+ dst_model ['albedo_residual_model_2.block4.res.0.bias' ] = src_model ['ares4.res.0.bias' ]
554+ dst_model ['albedo_residual_model_2.block4.res.0.running_mean' ] = src_model ['ares4.res.0.running_mean' ]
555+ dst_model ['albedo_residual_model_2.block4.res.0.running_var' ] = src_model ['ares4.res.0.running_var' ]
556+ dst_model ['albedo_residual_model_2.block4.res.2.weight' ] = src_model ['ares4.res.2.weight' ]
557+ dst_model ['albedo_residual_model_2.block4.res.2.bias' ] = src_model ['ares4.res.2.bias' ]
558+ dst_model ['albedo_residual_model_2.block4.res.3.weight' ] = src_model ['ares4.res.3.weight' ]
559+ dst_model ['albedo_residual_model_2.block4.res.3.bias' ] = src_model ['ares4.res.3.bias' ]
560+ dst_model ['albedo_residual_model_2.block4.res.3.running_mean' ] = src_model ['ares4.res.3.running_mean' ]
561+ dst_model ['albedo_residual_model_2.block4.res.3.running_var' ] = src_model ['ares4.res.3.running_var' ]
562+ dst_model ['albedo_residual_model_2.block4.res.5.weight' ] = src_model ['ares4.res.5.weight' ]
563+ dst_model ['albedo_residual_model_2.block4.res.5.bias' ] = src_model ['ares4.res.5.bias' ]
564+ dst_model ['albedo_residual_model_2.block5.res.0.weight' ] = src_model ['ares5.res.0.weight' ]
565+ dst_model ['albedo_residual_model_2.block5.res.0.bias' ] = src_model ['ares5.res.0.bias' ]
566+ dst_model ['albedo_residual_model_2.block5.res.0.running_mean' ] = src_model ['ares5.res.0.running_mean' ]
567+ dst_model ['albedo_residual_model_2.block5.res.0.running_var' ] = src_model ['ares5.res.0.running_var' ]
568+ dst_model ['albedo_residual_model_2.block5.res.2.weight' ] = src_model ['ares5.res.2.weight' ]
569+ dst_model ['albedo_residual_model_2.block5.res.2.bias' ] = src_model ['ares5.res.2.bias' ]
570+ dst_model ['albedo_residual_model_2.block5.res.3.weight' ] = src_model ['ares5.res.3.weight' ]
571+ dst_model ['albedo_residual_model_2.block5.res.3.bias' ] = src_model ['ares5.res.3.bias' ]
572+ dst_model ['albedo_residual_model_2.block5.res.3.running_mean' ] = src_model ['ares5.res.3.running_mean' ]
573+ dst_model ['albedo_residual_model_2.block5.res.3.running_var' ] = src_model ['ares5.res.3.running_var' ]
574+ dst_model ['albedo_residual_model_2.block5.res.5.weight' ] = src_model ['ares5.res.5.weight' ]
575+ dst_model ['albedo_residual_model_2.block5.res.5.bias' ] = src_model ['ares5.res.5.bias' ]
576+ dst_model ['albedo_residual_model_2.bn1.weight' ] = src_model ['areso.0.weight' ]
577+ dst_model ['albedo_residual_model_2.bn1.bias' ] = src_model ['areso.0.bias' ]
578+ dst_model ['albedo_residual_model_2.bn1.running_mean' ] = src_model ['areso.0.running_mean' ]
579+ dst_model ['albedo_residual_model_2.bn1.running_var' ] = src_model ['areso.0.running_var' ]
580+
512581 dst_model ['light_estimator_model.conv1.0.weight' ] = src_model ['lconv.conv.0.weight' ]
513582 dst_model ['light_estimator_model.conv1.0.bias' ] = src_model ['lconv.conv.0.bias' ]
514583 dst_model ['light_estimator_model.conv1.1.weight' ] = src_model ['lconv.conv.1.weight' ]
0 commit comments