Skip to content

Commit 87151a3

Browse files
committed
gan separate trainign models
1 parent c5ba316 commit 87151a3

4 files changed

Lines changed: 30 additions & 53 deletions

File tree

6_Shading_Resdiual_GAN_Separate_Train/data_loading.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -119,25 +119,7 @@ def get_sfsnet_dataset(syn_dir=None, read_from_csv=None, read_celeba_csv=None, r
119119
face = []
120120
depth = []
121121

122-
if read_from_csv is None:
123-
for img in sorted(glob.glob(syn_dir + '*/*_albedo_*')):
124-
albedo.append(img)
125-
126-
for img in sorted(glob.glob(syn_dir + '*/*_face_*')):
127-
face.append(img)
128-
129-
for img in sorted(glob.glob(syn_dir + '*/*_normal_*')):
130-
normal.append(img)
131-
132-
for img in sorted(glob.glob(syn_dir + '*/*_depth_*')):
133-
depth.append(img)
134-
135-
for img in sorted(glob.glob(syn_dir + '*/*_mask_*')):
136-
mask.append(img)
137-
138-
for img in sorted(glob.glob(syn_dir + '*/*_light_*.txt')):
139-
sh.append(img)
140-
else:
122+
if read_from_csv is not None:
141123
df = pd.read_csv(read_from_csv)
142124
if read_first is not None and len(df) > read_first:
143125
df = df.sample(read_first, random_state=100)
@@ -155,17 +137,17 @@ def get_sfsnet_dataset(syn_dir=None, read_from_csv=None, read_celeba_csv=None, r
155137
for _, v in name_to_list.items():
156138
v[:] = [syn_dir + el for el in v]
157139

158-
# Merge Synthesized Celeba dataset for Psedo-Supervised training
159-
if read_celeba_csv is not None:
160-
df = pd.read_csv(read_celeba_csv)
161-
if read_first is not None and len(df) > read_first:
162-
df = df.sample(read_first, random_state=100)
163-
albedo += list(df['albedo'])
164-
face += list(df['face'])
165-
normal += list(df['normal'])
166-
depth += list(df['depth'])
167-
mask += list(df['mask'])
168-
sh += list(df['light'])
140+
# Merge Synthesized Celeba dataset for Psedo-Supervised training
141+
if read_celeba_csv is not None:
142+
df = pd.read_csv(read_celeba_csv)
143+
if read_first is not None and len(df) > read_first:
144+
df = df.sample(read_first, random_state=100)
145+
albedo += list(df['albedo'])
146+
face += list(df['face'])
147+
normal += list(df['normal'])
148+
depth += list(df['depth'])
149+
mask += list(df['mask'])
150+
sh += list(df['light'])
169151

170152
assert(len(albedo) == len(face) == len(normal) == len(depth) == len(mask) == len(sh))
171153
dataset_size = len(albedo)

6_Shading_Resdiual_GAN_Separate_Train/main_full.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def main():
9090
# return
9191

9292
# Init WandB for logging
93-
wandb.init(project='SfSNet-CelebA-GANLoss-Shading-Residual-PreTrained')
93+
wandb.init(project='SfSNet-CelebA-GAN-Separate-Training-Shading-Residual-PreTrained')
9494
wandb.log({'lr':lr, 'weight decay': wt_decay})
9595

9696
# Initialize models
@@ -101,6 +101,9 @@ def main():
101101

102102
if use_cuda:
103103
sfs_net_model = sfs_net_model.cuda()
104+
albedo_gen_model = albedo_gen_model.cuda()
105+
albedo_dis_model = albedo_dis_model.cuda()
106+
104107

105108
if model_dir is not None:
106109
sfs_net_model.load_state_dict(torch.load(model_dir + 'sfs_net_model.pkl'))
@@ -168,6 +171,7 @@ def main():
168171

169172
# fix the weights of albedo gen model
170173
albedo_gen_model.fix_weights()
174+
sfs_net_model.fix_new_weights()
171175

172176
train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_data=celeba_data, read_first=read_first,\
173177
batch_size=batch_size, num_epochs=epochs, log_path=log_dir+'Mix_Training/', use_cuda=use_cuda, wandb=wandb, \

6_Shading_Resdiual_GAN_Separate_Train/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,9 @@ def forward(self, x):
228228
out = self.conv3(out)
229229
return out
230230

231+
def fix_weights(self):
232+
dfs_freeze(self)
233+
231234
class LightEstimator(nn.Module):
232235
""" Estimate lighting from normal, albedo and conv features
233236
"""
@@ -333,6 +336,9 @@ def fix_weights(self):
333336
dfs_freeze(self.light_estimator_model)
334337
# Note that we are not freezing Albedo gen model
335338

339+
def fix_new_weights(self):
340+
dfs_freeze(self.albedo_residual_model_2)
341+
336342

337343
# Use following to fix weights of the model
338344
# Ref - https://discuss.pytorch.org/t/how-the-pytorch-freeze-network-in-some-layers-only-the-rest-of-the-training/7088/15

6_Shading_Resdiual_GAN_Separate_Train/train.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
298298

299299
# Load Synthetic dataset
300300
train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_data+'train/', read_from_csv=syn_train_csv, read_celeba_csv=celeba_train_csv, read_first=read_first, validation_split=2)
301-
test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_data+'test/', read_from_csv=syn_test_csv, read_celeba_csv=celeba_test_csv, read_first=100, validation_split=0)
301+
test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_data+'test/', read_from_csv=None, read_celeba_csv=celeba_test_csv, read_first=100, validation_split=0)
302302

303303
syn_train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
304304
syn_val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
@@ -324,16 +324,6 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
324324
os.system('mkdir -p {}'.format(out_syn_images_dir + 'val/'))
325325
os.system('mkdir -p {}'.format(out_syn_images_dir + 'test/'))
326326

327-
# Create Generator and Discriminator
328-
329-
if use_cuda:
330-
albedo_gen_model = albedo_gen_model.cuda()
331-
albedo_dis_model = albedo_dis_model.cuda()
332-
333-
# Init gen and disc
334-
albedo_gen_model.apply(weights_init)
335-
albedo_dis_model.apply(weights_init)
336-
337327
# Collect model parameters
338328
model_parameters = sfs_net_model.parameters()
339329
optimizer = torch.optim.Adam(model_parameters, lr=lr) #, weight_decay=wt_decay)
@@ -345,10 +335,6 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
345335
g_optimizer = torch.optim.Adam(albedo_gen_model.parameters(), lr=lr)
346336
d_optimizer = torch.optim.Adam(albedo_dis_model.parameters(), lr=lr)
347337

348-
if use_cuda:
349-
albedo_loss = albedo_loss.cuda()
350-
recon_loss = recon_loss.cuda()
351-
352338
lamda_recon = 1
353339
lamda_albedo = 10
354340

@@ -377,9 +363,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
377363

378364
# Apply Mask on input image
379365
# face = applyMask(face, mask)
380-
predicted_normal, albedo_features, predicted_sh, shading_residual = sfs_net_model(face)
381-
optimizer.zero_grad()
382-
# GAN Training
366+
# GAN Training
383367
valid = torch.ones(albedo.shape[0], requires_grad = False)
384368
fake = torch.zeros(albedo.shape[0], requires_grad = False)
385369

@@ -400,6 +384,7 @@ def gan_based_train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data,
400384
real_sample = real_sample.cuda()
401385

402386
# GAN loss
387+
predicted_normal, albedo_features, predicted_sh, shading_residual = sfs_net_model(face)
403388
fake_albedo = albedo_gen_model(albedo_features)
404389
pred_fake = albedo_dis_model(fake_albedo)
405390
# print(pred_fake.shape, valid.shape)
@@ -610,15 +595,15 @@ def train(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_data, celeba_da
610595
wandb_log_images(wandb, real_sh_face, mask, 'Train Real SH Predicted Face', epoch, 'Train Real SH Predicted Face', path=file_name + '_real_sh_face.png')
611596
wandb_log_images(wandb, syn_face, mask, 'Train Real SH GT Face', epoch, 'Train Real SH GT Face', path=file_name + '_syn_gt_face.png')
612597

613-
v_total, v_albedo, v_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
598+
v_total, v_albedo, v_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, syn_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
614599
out_folder=out_syn_images_dir+'/val/', wandb=wandb)
615600
# wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
616601
print('Val set results: Total Loss: {}, Albedo Loss: {}, Recon Loss: {}'.format(v_total, v_albedo, v_recon))
617602

618603
# Model saving
619604
torch.save(sfs_net_model.state_dict(), model_checkpoint_dir + 'sfs_net_model.pkl')
620605
if epoch % 5 == 0:
621-
t_total, t_albedo, t_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, albedo_dis_model, syn_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
606+
t_total, t_albedo, t_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, syn_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
622607
out_folder=out_syn_images_dir + '/test/', wandb=wandb, suffix='Test')
623608

624609
# wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})
@@ -759,7 +744,7 @@ def train_with_shading_loss(sfs_net_model, syn_data, celeba_data=None, read_firs
759744
wandb_log_images(wandb, real_sh_face, mask, 'Train Real SH Predicted Face', epoch, 'Train Real SH Predicted Face', path=file_name + '_real_sh_face.png')
760745
wandb_log_images(wandb, syn_face, mask, 'Train Real SH GT Face', epoch, 'Train Real SH GT Face', path=file_name + '_syn_gt_face.png')
761746

762-
v_total, v_albedo, v_recon = predict_sfsnet(sfs_net_model, syn_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
747+
v_total, v_albedo, v_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, syn_val_dl, train_epoch_num=epoch, use_cuda=use_cuda,
763748
out_folder=out_syn_images_dir+'/val/', wandb=wandb)
764749
wandb.log({log_prefix + 'Val Total loss': v_total, log_prefix + 'Val Albedo loss': v_albedo, log_prefix + 'Val Recon loss': v_recon})
765750

@@ -769,7 +754,7 @@ def train_with_shading_loss(sfs_net_model, syn_data, celeba_data=None, read_firs
769754
# Model saving
770755
torch.save(sfs_net_model.state_dict(), model_checkpoint_dir + 'sfs_net_model.pkl')
771756
if epoch % 5 == 0:
772-
t_total, t_albedo, t_recon = predict_sfsnet(sfs_net_model, syn_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
757+
t_total, t_albedo, t_recon = predict_sfsnet(sfs_net_model, albedo_gen_model, syn_test_dl, train_epoch_num=epoch, use_cuda=use_cuda,
773758
out_folder=out_syn_images_dir + '/test/', wandb=wandb, suffix='Test')
774759

775760
wandb.log({log_prefix+'Test Total loss': t_total, log_prefix+'Test Albedo loss': t_albedo, log_prefix+'Test Recon loss': t_recon})

0 commit comments

Comments
 (0)