@@ -68,8 +68,8 @@ def predict_celeba(sfs_net_model, dl, train_epoch_num = 0,
6868 # return average loss over dataset
6969 return tloss / len_dl
7070
71- def predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , dl , gan_real_dl , train_epoch_num = 0 ,
72- use_cuda = False , out_folder = None , wandb = None , suffix = 'Val' ):
71+ def predict_sfsnet_gan (sfs_net_model , albedo_gen_model , albedo_dis_model , dl , gan_real_dl , train_epoch_num = 0 ,
72+ use_cuda = False , out_folder = None , wandb = None , suffix = 'GAN Val' ):
7373
7474 # debugging flag to dump image
7575 fix_bix_dump = 0
@@ -193,6 +193,95 @@ def predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, dl, gan_re
193193 # return average loss over dataset
194194 return tloss / len_dl , aloss / len_dl , rloss / len_dl , ganloss / len_dl , disloss / len_dl
195195
196+ def predict_sfsnet (sfs_net_model , albedo_gen_model , dl , train_epoch_num = 0 ,
197+ use_cuda = False , out_folder = None , wandb = None , suffix = 'Val' ):
198+
199+ # debugging flag to dump image
200+ fix_bix_dump = 0
201+
202+ albedo_loss = nn .SmoothL1Loss () #nn.L1Loss()
203+ recon_loss = nn .SmoothL1Loss () #nn.L1Loss()
204+
205+ lamda_recon = 1
206+ lamda_albedo = 1
207+
208+ if use_cuda :
209+ albedo_loss = albedo_loss .cuda ()
210+ recon_loss = recon_loss .cuda ()
211+
212+ tloss = 0 # Total loss
213+ aloss = 0 # Albedo loss
214+ rloss = 0 # Reconstruction loss
215+
216+ for bix , data in enumerate (dl ):
217+ albedo , normal , mask , sh , face = data
218+ if use_cuda :
219+ albedo = albedo .cuda ()
220+ normal = normal .cuda ()
221+ mask = mask .cuda ()
222+ sh = sh .cuda ()
223+ face = face .cuda ()
224+
225+ # Apply Mask on input image
226+ # face = applyMask(face, mask)
227+ # predicted_face == reconstruction
228+ # predicted_normal, predicted_albedo, predicted_sh, predicted_shading, shading_residual, updated_shading, predicted_face = sfs_net_model(face)
229+
230+ # Apply Mask on input image
231+ # face = applyMask(face, mask)
232+ predicted_normal , albedo_features , predicted_sh , shading_residual = sfs_net_model (face )
233+
234+ fake_albedo = albedo_gen_model (albedo_features )
235+
236+ out_shading = get_shading (predicted_normal , predicted_sh )
237+ updated_shading = out_shading + shading_residual
238+ out_recon = reconstruct_image (updated_shading , fake_albedo )
239+
240+ # albedo recon loss
241+ current_albedo_loss = albedo_loss (fake_albedo , albedo )
242+ current_recon_loss = recon_loss (out_recon , face )
243+
244+ total_loss = lamda_albedo * current_albedo_loss + lamda_recon * current_recon_loss
245+
246+ # Logging for display and debugging purposes
247+ tloss += total_loss .item ()
248+ # nloss += current_normal_loss.item()
249+ aloss += current_albedo_loss .item ()
250+ # shloss += current_sh_loss.item()
251+ rloss += current_recon_loss .item ()
252+
253+ if bix == fix_bix_dump :
254+ # save predictions in log folder
255+ file_name = out_folder + suffix + '_' + str (train_epoch_num ) + '_' + str (fix_bix_dump )
256+ # log images
257+ # save_p_normal = get_normal_in_range(predicted_normal)
258+ save_gt_normal = get_normal_in_range (normal )
259+ save_p_normal = predicted_normal
260+
261+ wandb_log_images (wandb , save_p_normal , mask , suffix + ' Predicted Normal' , train_epoch_num , suffix + ' Predicted Normal' , path = file_name + '_predicted_normal.png' )
262+ wandb_log_images (wandb , fake_albedo , mask , suffix + ' Predicted Albedo' , train_epoch_num , suffix + ' Predicted Albedo' , path = file_name + '_predicted_albedo.png' )
263+ wandb_log_images (wandb , out_shading , mask , suffix + ' Predicted Shading' , train_epoch_num , suffix + ' Predicted Shading' , path = file_name + '_predicted_shading.png' , denormalize = False )
264+ wandb_log_images (wandb , shading_residual , mask , suffix + ' Predicted Shading Residual' , train_epoch_num , suffix + ' Predicted Shading Residual' , path = file_name + '_predicted_residual_shading.png' , denormalize = False )
265+ wandb_log_images (wandb , updated_shading , mask , suffix + ' Predicted Updated Shading' , train_epoch_num , suffix + ' Predicted Updated Shading' , path = file_name + '_predicted_updated_shading.png' , denormalize = False )
266+ wandb_log_images (wandb , out_recon , mask , suffix + ' Predicted face' , train_epoch_num , suffix + ' Predicted face' , path = file_name + '_predicted_face.png' , denormalize = False )
267+ wandb_log_images (wandb , face , mask , suffix + ' Ground Truth' , train_epoch_num , suffix + ' Ground Truth' , path = file_name + '_gt_face.png' )
268+ wandb_log_images (wandb , save_gt_normal , mask , suffix + ' Ground Truth Normal' , train_epoch_num , suffix + ' Ground Normal' , path = file_name + '_gt_normal.png' )
269+ wandb_log_images (wandb , albedo , mask , suffix + ' Ground Truth Albedo' , train_epoch_num , suffix + ' Ground Albedo' , path = file_name + '_gt_albedo.png' )
270+ # Get face with real SH
271+ real_sh_face = sfs_net_model .get_face (sh , predicted_normal , fake_albedo )
272+ wandb_log_images (wandb , real_sh_face , mask , 'Val Real SH Predicted Face' , train_epoch_num , 'Val Real SH Predicted Face' , path = file_name + '_real_sh_face.png' )
273+ syn_face = sfs_net_model .get_face (sh , normal , albedo )
274+ wandb_log_images (wandb , syn_face , mask , 'Val Real SH GT Face' , train_epoch_num , 'Val Real SH GT Face' , path = file_name + '_syn_gt_face.png' )
275+
276+ # TODO:
277+ # Dump SH as CSV or TXT file
278+
279+ len_dl = len (dl )
280+ wandb .log ({suffix + ' Total loss' : tloss / len_dl , suffix + 'Albedo loss' : aloss / len_dl , suffix + 'Recon loss' : rloss / len_dl }, step = train_epoch_num )
281+
282+ # return average loss over dataset
283+ return tloss / len_dl , aloss / len_dl , rloss / len_dl
284+
196285def gan_based_train (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_data , celeba_data = None , read_first = None ,
197286 batch_size = 10 , num_epochs = 10 , log_path = './results/metadata/' , use_cuda = False , wandb = None ,
198287 lr = 0.01 , wt_decay = 0.005 ):
@@ -342,7 +431,6 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
342431
343432 loss_d .backward ()
344433 d_optimizer .step ()
345- optimizer .step ()
346434
347435 # Logging for display and debugging purposes
348436 tloss += total_loss .item ()
@@ -384,7 +472,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
384472 wandb_log_images (wandb , real_sh_face , mask , 'Train Real SH Predicted Face' , epoch , 'Train Real SH Predicted Face' , path = file_name + '_real_sh_face.png' )
385473 wandb_log_images (wandb , syn_face , mask , 'Train Real SH GT Face' , epoch , 'Train Real SH GT Face' , path = file_name + '_syn_gt_face.png' )
386474
387- v_total , v_albedo , v_recon , v_gloss , v_dloss = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_val_dl , gan_real_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
475+ v_total , v_albedo , v_recon , v_gloss , v_dloss = predict_sfsnet_gan (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_val_dl , gan_real_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
388476 out_folder = out_syn_images_dir + '/val/' , wandb = wandb , suffix = 'GAN Val' )
389477 # wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
390478
@@ -394,11 +482,9 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
394482 # Model saving
395483 torch .save (sfs_net_model .state_dict (), model_checkpoint_dir + 'sfs_net_model.pkl' )
396484 if epoch % 5 == 0 :
397- t_total , t_albedo , t_recon , t_gloss , t_dloss = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_test_dl , gan_real_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
485+ t_total , t_albedo , t_recon , t_gloss , t_dloss = predict_sfsnet_gan (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_test_dl , gan_real_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
398486 out_folder = out_syn_images_dir + '/test/' , wandb = wandb , suffix = 'GAN Test' )
399487
400- # wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})
401-
402488 print ('Test-set results: Total Loss: {}, Albedo Loss: {}, Gan Loss: {}, Dis Loss: {} \n ' .format (t_total , t_albedo , t_gloss , t_dloss ))
403489
404490def train (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_data , celeba_data = None , read_first = None ,
@@ -423,15 +509,6 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
423509 syn_val_dl = DataLoader (val_dataset , batch_size = batch_size , shuffle = True )
424510 syn_test_dl = DataLoader (test_dataset , batch_size = batch_size , shuffle = True )
425511
426- gan_real_train_dataset , gan_real_val_dataset = get_sfsnet_dataset (syn_dir = syn_data + 'train/' , read_from_csv = syn_train_csv , read_celeba_csv = None , read_first = read_first , validation_split = 2 )
427- gan_real_test_dataset , _ = get_sfsnet_dataset (syn_dir = syn_data + 'test/' , read_from_csv = syn_test_csv , read_celeba_csv = None , read_first = read_first , validation_split = 0 )
428-
429- gan_real_train_dl = DataLoader (gan_real_train_dataset , batch_size = batch_size , shuffle = True )
430- train_real_gan_iter = iter (gan_real_train_dl )
431-
432- gan_real_val_dl = DataLoader (gan_real_val_dataset , batch_size = batch_size , shuffle = True )
433- gan_real_test_dl = DataLoader (gan_real_test_dataset , batch_size = batch_size , shuffle = True )
434-
435512 print ('Synthetic dataset: Train data: ' , len (syn_train_dl ), ' Val data: ' , len (syn_val_dl ), ' Test data: ' , len (syn_test_dl ))
436513
437514 model_checkpoint_dir = log_path + 'checkpoints/'
@@ -448,32 +525,24 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
448525 optimizer = torch .optim .Adam (model_parameters , lr = lr ) #, weight_decay=wt_decay)
449526 albedo_loss = nn .SmoothL1Loss () #nn.L1Loss()
450527 recon_loss = nn .SmoothL1Loss () #nn.L1Loss()
451- gan_loss = torch .nn .MSELoss ()
452-
453- # Collect and initialize gen-dis optimizers
454- g_optimizer = torch .optim .Adam (albedo_gen_model .parameters (), lr = lr )
455- d_optimizer = torch .optim .Adam (albedo_dis_model .parameters (), lr = lr )
456528
457529 if use_cuda :
458530 albedo_loss = albedo_loss .cuda ()
459531 recon_loss = recon_loss .cuda ()
460532
461533 lamda_recon = 1
462- lamda_albedo = 10
534+ lamda_albedo = 1
463535
464536 if use_cuda :
465537 albedo_loss = albedo_loss .cuda ()
466538 recon_loss = recon_loss .cuda ()
467- gan_loss = gan_loss .cuda ()
468539
469540 syn_train_len = len (syn_train_dl )
470541
471542 for epoch in range (1 , num_epochs + 1 ):
472543 tloss = 0 # Total loss
473544 aloss = 0 # Albedo loss
474545 rloss = 0 # Reconstruction loss
475- ganloss = 0 # Gan Loss
476- disloss = 0 # Dis Loss
477546
478547 for bix , data in enumerate (syn_train_dl ):
479548 albedo , normal , mask , sh , face = data
@@ -541,15 +610,15 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
541610 wandb_log_images (wandb , real_sh_face , mask , 'Train Real SH Predicted Face' , epoch , 'Train Real SH Predicted Face' , path = file_name + '_real_sh_face.png' )
542611 wandb_log_images (wandb , syn_face , mask , 'Train Real SH GT Face' , epoch , 'Train Real SH GT Face' , path = file_name + '_syn_gt_face.png' )
543612
544- v_total , v_albedo , v_recon , v_gloss , v_dloss = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_val_dl , gan_real_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
613+ v_total , v_albedo , v_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
545614 out_folder = out_syn_images_dir + '/val/' , wandb = wandb )
546615 # wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
547616 print ('Val set results: Total Loss: {}, Albedo Loss: {}, Recon Loss: {}' .format (v_total , v_albedo , v_recon ))
548617
549618 # Model saving
550619 torch .save (sfs_net_model .state_dict (), model_checkpoint_dir + 'sfs_net_model.pkl' )
551620 if epoch % 5 == 0 :
552- t_total , t_albedo , t_recon , t_gloss , t_dloss = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_test_dl , gan_real_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
621+ t_total , t_albedo , t_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
553622 out_folder = out_syn_images_dir + '/test/' , wandb = wandb , suffix = 'Test' )
554623
555624 # wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})
0 commit comments