Skip to content

Commit 62de7af

Browse files
committed
load and fix wts
1 parent 1a8d7eb commit 62de7af

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

2_Latent-Shading-Gen/main_full.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def main():
4040
help='read first n rows (default: -1)')
4141
parser.add_argument('--details', type=str, default=None,
4242
help='Explaination of the run')
43+
parser.add_argument('--load_pretrained_model', type=str, default='../pretrained/net_epoch_r5_5.pth',
44+
help='Pretrained model path')
45+
4346
if ON_SERVER:
4447
parser.add_argument('--syn_data', type=str, default='/nfs/bigdisk/bsonawane/sfsnet_data/',
4548
help='Synthetic Dataset path')
@@ -71,6 +74,7 @@ def main():
7174
epochs = args.epochs
7275
model_dir = args.load_model
7376
read_first = args.read_first
77+
pretrained_model_dict = args.load_pretrained_model
7478

7579

7680
if read_first == -1:
@@ -98,12 +102,17 @@ def main():
98102
else:
99103
print('Initializing weights')
100104
sfs_net_model.apply(weights_init)
105+
sfs_net_pretrained_dict = torch.load(pretrained_model_dict)
106+
sfs_net_state_dict = sfs_net_model.state_dict()
107+
load_model_from_pretrained(sfs_net_pretrained_dict, sfs_net_state_dict)
108+
sfs_net_model.load_state_dict(sfs_net_state_dict)
109+
sfs_net_model.fix_weights()
101110

102111
os.system('mkdir -p {}'.format(args.log_dir))
103112
with open(args.log_dir+'/details.txt', 'w') as f:
104113
f.write(args.details)
105114

106-
wandb.watch(sfs_net_model)
115+
# wandb.watch(sfs_net_model)
107116

108117
# 1. Train on both Synthetic and Real (Celeba) dataset
109118
train(sfs_net_model, syn_data, celeba_data=celeba_data, read_first=read_first,\

0 commit comments

Comments
 (0)