Skip to content

Commit bc423b7

Browse files
committed
Fix shading training
1 parent 0c59c9b commit bc423b7

1 file changed

Lines changed: 96 additions & 27 deletions

File tree

  • 6_Shading_Resdiual_GAN_Separate_Train

6_Shading_Resdiual_GAN_Separate_Train/train.py

Lines changed: 96 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def predict_celeba(sfs_net_model, dl, train_epoch_num = 0,
6868
# return average loss over dataset
6969
return tloss / len_dl
7070

71-
def predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, dl, gan_real_dl, train_epoch_num = 0,
72-
use_cuda = False, out_folder = None, wandb = None, suffix = 'Val'):
71+
def predict_sfsnet_gan(sfs_net_model, albedo_gen_model, albedo_dis_model, dl, gan_real_dl, train_epoch_num = 0,
72+
use_cuda = False, out_folder = None, wandb = None, suffix = 'GAN Val'):
7373

7474
# debugging flag to dump image
7575
fix_bix_dump = 0
@@ -193,6 +193,95 @@ def predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, dl, gan_re
193193
# return average loss over dataset
194194
return tloss / len_dl, aloss / len_dl, rloss / len_dl, ganloss/len_dl, disloss / len_dl
195195

196+
def predict_sfsnet(sfs_net_model, albedo_gen_model, dl, train_epoch_num = 0,
197+
use_cuda = False, out_folder = None, wandb = None, suffix = 'Val'):
198+
199+
# debugging flag to dump image
200+
fix_bix_dump = 0
201+
202+
albedo_loss = nn.SmoothL1Loss() #nn.L1Loss()
203+
recon_loss = nn.SmoothL1Loss() #nn.L1Loss()
204+
205+
lamda_recon = 1
206+
lamda_albedo = 1
207+
208+
if use_cuda:
209+
albedo_loss = albedo_loss.cuda()
210+
recon_loss = recon_loss.cuda()
211+
212+
tloss = 0 # Total loss
213+
aloss = 0 # Albedo loss
214+
rloss = 0 # Reconstruction loss
215+
216+
for bix, data in enumerate(dl):
217+
albedo, normal, mask, sh, face = data
218+
if use_cuda:
219+
albedo = albedo.cuda()
220+
normal = normal.cuda()
221+
mask = mask.cuda()
222+
sh = sh.cuda()
223+
face = face.cuda()
224+
225+
# Apply Mask on input image
226+
# face = applyMask(face, mask)
227+
# predicted_face == reconstruction
228+
# predicted_normal, predicted_albedo, predicted_sh, predicted_shading, shading_residual, updated_shading, predicted_face = sfs_net_model(face)
229+
230+
# Apply Mask on input image
231+
# face = applyMask(face, mask)
232+
predicted_normal, albedo_features, predicted_sh, shading_residual = sfs_net_model(face)
233+
234+
fake_albedo = albedo_gen_model(albedo_features)
235+
236+
out_shading = get_shading(predicted_normal, predicted_sh)
237+
updated_shading = out_shading + shading_residual
238+
out_recon = reconstruct_image(updated_shading, fake_albedo)
239+
240+
# albedo recon loss
241+
current_albedo_loss = albedo_loss(fake_albedo, albedo)
242+
current_recon_loss = recon_loss(out_recon, face)
243+
244+
total_loss = lamda_albedo * current_albedo_loss + lamda_recon * current_recon_loss
245+
246+
# Logging for display and debugging purposes
247+
tloss += total_loss.item()
248+
# nloss += current_normal_loss.item()
249+
aloss += current_albedo_loss.item()
250+
# shloss += current_sh_loss.item()
251+
rloss += current_recon_loss.item()
252+
253+
if bix == fix_bix_dump:
254+
# save predictions in log folder
255+
file_name = out_folder + suffix + '_' + str(train_epoch_num) + '_' + str(fix_bix_dump)
256+
# log images
257+
# save_p_normal = get_normal_in_range(predicted_normal)
258+
save_gt_normal = get_normal_in_range(normal)
259+
save_p_normal = predicted_normal
260+
261+
wandb_log_images(wandb, save_p_normal, mask, suffix+' Predicted Normal', train_epoch_num, suffix+' Predicted Normal', path=file_name + '_predicted_normal.png')
262+
wandb_log_images(wandb, fake_albedo, mask, suffix +' Predicted Albedo', train_epoch_num, suffix+' Predicted Albedo', path=file_name + '_predicted_albedo.png')
263+
wandb_log_images(wandb, out_shading, mask, suffix+' Predicted Shading', train_epoch_num, suffix+' Predicted Shading', path=file_name + '_predicted_shading.png', denormalize=False)
264+
wandb_log_images(wandb, shading_residual, mask, suffix+' Predicted Shading Residual', train_epoch_num, suffix+' Predicted Shading Residual', path=file_name + '_predicted_residual_shading.png', denormalize=False)
265+
wandb_log_images(wandb, updated_shading, mask, suffix+' Predicted Updated Shading', train_epoch_num, suffix+' Predicted Updated Shading', path=file_name + '_predicted_updated_shading.png', denormalize=False)
266+
wandb_log_images(wandb, out_recon, mask, suffix+' Predicted face', train_epoch_num, suffix+' Predicted face', path=file_name + '_predicted_face.png', denormalize=False)
267+
wandb_log_images(wandb, face, mask, suffix+' Ground Truth', train_epoch_num, suffix+' Ground Truth', path=file_name + '_gt_face.png')
268+
wandb_log_images(wandb, save_gt_normal, mask, suffix+' Ground Truth Normal', train_epoch_num, suffix+' Ground Normal', path=file_name + '_gt_normal.png')
269+
wandb_log_images(wandb, albedo, mask, suffix+' Ground Truth Albedo', train_epoch_num, suffix+' Ground Albedo', path=file_name + '_gt_albedo.png')
270+
# Get face with real SH
271+
real_sh_face = sfs_net_model.get_face(sh, predicted_normal, fake_albedo)
272+
wandb_log_images(wandb, real_sh_face, mask, 'Val Real SH Predicted Face', train_epoch_num, 'Val Real SH Predicted Face', path=file_name + '_real_sh_face.png')
273+
syn_face = sfs_net_model.get_face(sh, normal, albedo)
274+
wandb_log_images(wandb, syn_face, mask, 'Val Real SH GT Face', train_epoch_num, 'Val Real SH GT Face', path=file_name + '_syn_gt_face.png')
275+
276+
# TODO:
277+
# Dump SH as CSV or TXT file
278+
279+
len_dl = len(dl)
280+
wandb.log({suffix+' Total loss': tloss/len_dl, suffix+'Albedo loss': aloss/len_dl, suffix + 'Recon loss': rloss/len_dl}, step=train_epoch_num)
281+
282+
# return average loss over dataset
283+
return tloss / len_dl, aloss / len_dl, rloss / len_dl
284+
196285
def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_data=None, read_first=None,
197286
batch_size = 10, num_epochs = 10, log_path = './results/metadata/', use_cuda=False, wandb=None,
198287
lr = 0.01, wt_decay=0.005):
@@ -342,7 +431,6 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
342431

343432
loss_d.backward()
344433
d_optimizer.step()
345-
optimizer.step()
346434

347435
# Logging for display and debugging purposes
348436
tloss += total_loss.item()
@@ -384,7 +472,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
384472
wandb_log_images(wandb, real_sh_face, mask, 'Train Real SH Predicted Face', epoch, 'Train Real SH Predicted Face', path=file_name + '_real_sh_face.png')
385473
wandb_log_images(wandb, syn_face, mask, 'Train Real SH GT Face', epoch, 'Train Real SH GT Face', path=file_name + '_syn_gt_face.png')
386474

387-
v_total, v_albedo, v_recon, v_gloss, v_dloss = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_val_dl, gan_real_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
475+
v_total, v_albedo, v_recon, v_gloss, v_dloss = predict_sfsnet_gan(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_val_dl, gan_real_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
388476
out_folder=out_syn_images_dir+'/val/', wandb=wandb, suffix='GAN Val')
389477
# wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
390478

@@ -394,11 +482,9 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
394482
# Model saving
395483
torch.save(sfs_net_model.state_dict(), model_checkpoint_dir + 'sfs_net_model.pkl')
396484
if epoch % 5 == 0:
397-
t_total, t_albedo, t_recon, t_gloss, t_dloss = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_test_dl, gan_real_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
485+
t_total, t_albedo, t_recon, t_gloss, t_dloss = predict_sfsnet_gan(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_test_dl, gan_real_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
398486
out_folder=out_syn_images_dir + '/test/', wandb=wandb, suffix='GAN Test')
399487

400-
# wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})
401-
402488
print('Test-set results: Total Loss: {}, Albedo Loss: {}, Gan Loss: {}, Dis Loss: {} \n'.format(t_total, t_albedo, t_gloss, t_dloss))
403489

404490
def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_data=None, read_first=None,
@@ -423,15 +509,6 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
423509
syn_val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
424510
syn_test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
425511

426-
gan_real_train_dataset, gan_real_val_dataset = get_sfsnet_dataset(syn_dir=syn_data+'train/', read_from_csv=syn_train_csv, read_celeba_csv=None, read_first=read_first, validation_split=2)
427-
gan_real_test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_data+'test/', read_from_csv=syn_test_csv, read_celeba_csv=None, read_first=read_first, validation_split=0)
428-
429-
gan_real_train_dl = DataLoader(gan_real_train_dataset, batch_size=batch_size, shuffle=True)
430-
train_real_gan_iter = iter(gan_real_train_dl)
431-
432-
gan_real_val_dl = DataLoader(gan_real_val_dataset, batch_size=batch_size, shuffle=True)
433-
gan_real_test_dl = DataLoader(gan_real_test_dataset, batch_size=batch_size, shuffle=True)
434-
435512
print('Synthetic dataset: Train data: ', len(syn_train_dl), ' Val data: ', len(syn_val_dl), ' Test data: ', len(syn_test_dl))
436513

437514
model_checkpoint_dir = log_path + 'checkpoints/'
@@ -448,32 +525,24 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
448525
optimizer = torch.optim.Adam(model_parameters, lr=lr) #, weight_decay=wt_decay)
449526
albedo_loss = nn.SmoothL1Loss() #nn.L1Loss()
450527
recon_loss = nn.SmoothL1Loss() #nn.L1Loss()
451-
gan_loss = torch.nn.MSELoss()
452-
453-
# Collect and initialize gen-dis optimizers
454-
g_optimizer = torch.optim.Adam(albedo_gen_model.parameters(), lr=lr)
455-
d_optimizer = torch.optim.Adam(albedo_dis_model.parameters(), lr=lr)
456528

457529
if use_cuda:
458530
albedo_loss = albedo_loss.cuda()
459531
recon_loss = recon_loss.cuda()
460532

461533
lamda_recon = 1
462-
lamda_albedo = 10
534+
lamda_albedo = 1
463535

464536
if use_cuda:
465537
albedo_loss = albedo_loss.cuda()
466538
recon_loss = recon_loss.cuda()
467-
gan_loss = gan_loss.cuda()
468539

469540
syn_train_len = len(syn_train_dl)
470541

471542
for epoch in range(1, num_epochs+1):
472543
tloss = 0 # Total loss
473544
aloss = 0 # Albedo loss
474545
rloss = 0 # Reconstruction loss
475-
ganloss = 0 # Gan Loss
476-
disloss = 0 # Dis Loss
477546

478547
for bix, data in enumerate(syn_train_dl):
479548
albedo, normal, mask, sh, face = data
@@ -541,15 +610,15 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
541610
wandb_log_images(wandb, real_sh_face, mask, 'Train Real SH Predicted Face', epoch, 'Train Real SH Predicted Face', path=file_name + '_real_sh_face.png')
542611
wandb_log_images(wandb, syn_face, mask, 'Train Real SH GT Face', epoch, 'Train Real SH GT Face', path=file_name + '_syn_gt_face.png')
543612

544-
v_total, v_albedo, v_recon, v_gloss, v_dloss = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_val_dl, gan_real_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
613+
v_total, v_albedo, v_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
545614
out_folder=out_syn_images_dir+'/val/', wandb=wandb)
546615
# wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
547616
print('Val set results: Total Loss: {}, Albedo Loss: {}, Recon Loss: {}'.format(v_total, v_albedo, v_recon))
548617

549618
# Model saving
550619
torch.save(sfs_net_model.state_dict(), model_checkpoint_dir + 'sfs_net_model.pkl')
551620
if epoch % 5 == 0:
552-
t_total, t_albedo, t_recon, t_gloss, t_dloss = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_test_dl, gan_real_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
621+
t_total, t_albedo, t_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
553622
out_folder=out_syn_images_dir + '/test/', wandb=wandb, suffix='Test')
554623

555624
# wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})

0 commit comments

Comments
 (0)