@@ -298,7 +298,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
298298
299299 # Load Synthetic dataset
300300 train_dataset , val_dataset = get_sfsnet_dataset (syn_dir = syn_data + 'train/' , read_from_csv = syn_train_csv , read_celeba_csv = celeba_train_csv , read_first = read_first , validation_split = 2 )
301- test_dataset , _ = get_sfsnet_dataset (syn_dir = syn_data + 'test/' , read_from_csv = syn_test_csv , read_celeba_csv = celeba_test_csv , read_first = 100 , validation_split = 0 )
301+ test_dataset , _ = get_sfsnet_dataset (syn_dir = syn_data + 'test/' , read_from_csv = None , read_celeba_csv = celeba_test_csv , read_first = 100 , validation_split = 0 )
302302
303303 syn_train_dl = DataLoader (train_dataset , batch_size = batch_size , shuffle = True )
304304 syn_val_dl = DataLoader (val_dataset , batch_size = batch_size , shuffle = True )
@@ -324,16 +324,6 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
324324 os .system ('mkdir -p {}' .format (out_syn_images_dir + 'val/' ))
325325 os .system ('mkdir -p {}' .format (out_syn_images_dir + 'test/' ))
326326
327- # Create Generator and Discriminator
328-
329- if use_cuda :
330- albedo_gen_model = albedo_gen_model .cuda ()
331- albedo_dis_model = albedo_dis_model .cuda ()
332-
333- # Init gen and disc
334- albedo_gen_model .apply (weights_init )
335- albedo_dis_model .apply (weights_init )
336-
337327 # Collect model parameters
338328 model_parameters = sfs_net_model .parameters ()
339329 optimizer = torch .optim .Adam (model_parameters , lr = lr ) #, weight_decay=wt_decay)
@@ -345,10 +335,6 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
345335 g_optimizer = torch .optim .Adam (albedo_gen_model .parameters (), lr = lr )
346336 d_optimizer = torch .optim .Adam (albedo_dis_model .parameters (), lr = lr )
347337
348- if use_cuda :
349- albedo_loss = albedo_loss .cuda ()
350- recon_loss = recon_loss .cuda ()
351-
352338 lamda_recon = 1
353339 lamda_albedo = 10
354340
@@ -377,9 +363,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
377363
378364 # Apply Mask on input image
379365 # face = applyMask(face, mask)
380- predicted_normal , albedo_features , predicted_sh , shading_residual = sfs_net_model (face )
381- optimizer .zero_grad ()
382- # GAN Training
366+ # GAN Training
383367 valid = torch .ones (albedo .shape [0 ], requires_grad = False )
384368 fake = torch .zeros (albedo .shape [0 ], requires_grad = False )
385369
@@ -400,6 +384,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
400384 real_sample = real_sample .cuda ()
401385
402386 # GAN loss
387+ predicted_normal , albedo_features , predicted_sh , shading_residual = sfs_net_model (face )
403388 fake_albedo = albedo_gen_model (albedo_features )
404389 pred_fake = albedo_dis_model (fake_albedo )
405390 # print(pred_fake.shape, valid.shape)
@@ -610,15 +595,15 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
610595 wandb_log_images (wandb , real_sh_face , mask , 'Train Real SH Predicted Face' , epoch , 'Train Real SH Predicted Face' , path = file_name + '_real_sh_face.png' )
611596 wandb_log_images (wandb , syn_face , mask , 'Train Real SH GT Face' , epoch , 'Train Real SH GT Face' , path = file_name + '_syn_gt_face.png' )
612597
613- v_total , v_albedo , v_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
598+ v_total , v_albedo , v_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , syn_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
614599 out_folder = out_syn_images_dir + '/val/' , wandb = wandb )
615600 # wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
616601 print ('Val set results: Total Loss: {}, Albedo Loss: {}, Recon Loss: {}' .format (v_total , v_albedo , v_recon ))
617602
618603 # Model saving
619604 torch .save (sfs_net_model .state_dict (), model_checkpoint_dir + 'sfs_net_model.pkl' )
620605 if epoch % 5 == 0 :
621- t_total , t_albedo , t_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , albedo_dis_model , syn_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
606+ t_total , t_albedo , t_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , syn_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
622607 out_folder = out_syn_images_dir + '/test/' , wandb = wandb , suffix = 'Test' )
623608
624609 # wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})
@@ -759,7 +744,7 @@ def train_with_shading_loss(sfs_net_model, syn_data, celeba_data=None, read_firs
759744 wandb_log_images (wandb , real_sh_face , mask , 'Train Real SH Predicted Face' , epoch , 'Train Real SH Predicted Face' , path = file_name + '_real_sh_face.png' )
760745 wandb_log_images (wandb , syn_face , mask , 'Train Real SH GT Face' , epoch , 'Train Real SH GT Face' , path = file_name + '_syn_gt_face.png' )
761746
762- v_total , v_albedo , v_recon = predict_sfsnet (sfs_net_model , syn_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
747+ v_total , v_albedo , v_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , syn_val_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
763748 out_folder = out_syn_images_dir + '/val/' , wandb = wandb )
764749 wandb .log ({log_prefix + 'Val Total loss' : v_total , log_prefix + 'Val Albedo loss' : v_albedo , log_prefix + 'Val Recon loss' : v_recon })
765750
@@ -769,7 +754,7 @@ def train_with_shading_loss(sfs_net_model, syn_data, celeba_data=None, read_firs
769754 # Model saving
770755 torch .save (sfs_net_model .state_dict (), model_checkpoint_dir + 'sfs_net_model.pkl' )
771756 if epoch % 5 == 0 :
772- t_total , t_albedo , t_recon = predict_sfsnet (sfs_net_model , syn_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
757+ t_total , t_albedo , t_recon = predict_sfsnet (sfs_net_model , albedo_gen_model , syn_test_dl , train_epoch_num = epoch , use_cuda = use_cuda ,
773758 out_folder = out_syn_images_dir + '/test/' , wandb = wandb , suffix = 'Test' )
774759
775760 wandb .log ({log_prefix + 'Test Total loss' : t_total , log_prefix + 'Test Albedo loss' : t_albedo , log_prefix + 'Test Recon loss' : t_recon })
0 commit comments