@@ -40,6 +40,9 @@ def main():
4040 help = 'read first n rows (default: -1)' )
4141 parser .add_argument ('--details' , type = str , default = None ,
4242 help = 'Explaination of the run' )
43+ parser .add_argument ('--load_pretrained_model' , type = str , default = '../pretrained/net_epoch_r5_5.pth' ,
44+ help = 'Pretrained model path' )
45+
4346 if ON_SERVER :
4447 parser .add_argument ('--syn_data' , type = str , default = '/nfs/bigdisk/bsonawane/sfsnet_data/' ,
4548 help = 'Synthetic Dataset path' )
@@ -71,6 +74,7 @@ def main():
7174 epochs = args .epochs
7275 model_dir = args .load_model
7376 read_first = args .read_first
77+ pretrained_model_dict = args .load_pretrained_model
7478
7579
7680 if read_first == - 1 :
@@ -98,12 +102,17 @@ def main():
98102 else :
99103 print ('Initializing weights' )
100104 sfs_net_model .apply (weights_init )
105+ sfs_net_pretrained_dict = torch .load (pretrained_model_dict )
106+ sfs_net_state_dict = sfs_net_model .state_dict ()
107+ load_model_from_pretrained (sfs_net_pretrained_dict , sfs_net_state_dict )
108+ sfs_net_model .load_state_dict (sfs_net_state_dict )
109+ sfs_net_model .fix_weights ()
101110
102111 os .system ('mkdir -p {}' .format (args .log_dir ))
103112 with open (args .log_dir + '/details.txt' , 'w' ) as f :
104113 f .write (args .details )
105114
106- wandb .watch (sfs_net_model )
115+ # wandb.watch(sfs_net_model)
107116
108117 # 1. Train on both Synthetic and Real (Celeba) dataset
109118 train (sfs_net_model , syn_data , celeba_data = celeba_data , read_first = read_first ,\
0 commit comments