@@ -31,6 +31,9 @@ def define_flags():
3131 flags .DEFINE_boolean ("resume_from_checkpoint" , True , "Resume or not" )
3232 flags .DEFINE_string ("scenario" , "classification" ,
3333 "Support classification, regression" )
34+ flags .DEFINE_string (
35+ "loss" , "sparse_cross_entropy" ,
36+ "Support sparse_cross_entropy, cross_entropy, mean_square" )
3437 flags .DEFINE_integer ("feature_size" , 9 , "Number of feature size" )
3538 flags .DEFINE_integer ("label_size" , 2 , "Number of label size" )
3639 flags .DEFINE_string ("file_format" , "tfrecords" , "Support tfrecords, csv" )
@@ -47,8 +50,10 @@ def define_flags():
4750 flags .DEFINE_string ("optimizer" , "adagrad" ,
4851 "Support sgd, adadelta, adagrad, adam, ftrl, rmsprop" )
4952 flags .DEFINE_float ("learning_rate" , 0.01 , "Learning rate" )
50- flags .DEFINE_string ("model" , "dnn" ,
51- "Support dnn, lr, wide_and_deep, customized, cnn" )
53+ flags .DEFINE_string (
54+ "model" , "dnn" ,
55+ "Support dnn, lr, wide_and_deep, customized, cnn, lstm, bidirectional_lstm, gru"
56+ )
5257 flags .DEFINE_string ("dnn_struct" , "128 32 8" , "DNN struct" )
5358 flags .DEFINE_integer ("epoch_number" , 100 , "Number of epoches" )
5459 flags .DEFINE_integer ("train_batch_size" , 64 , "Batch size" )
@@ -71,11 +76,17 @@ def define_flags():
7176 # Check parameters
7277 assert (FLAGS .mode in ["train" , "inference" , "savedmodel" ])
7378 assert (FLAGS .scenario in ["classification" , "regression" ])
79+ assert (FLAGS .loss in [
80+ "sparse_cross_entropy" , "cross_entropy" , "mean_square"
81+ ])
7482 assert (FLAGS .file_format in ["tfrecords" , "csv" ])
7583 assert (FLAGS .optimizer in [
7684 "sgd" , "adadelta" , "adagrad" , "adam" , "ftrl" , "rmsprop"
7785 ])
78- assert (FLAGS .model in ["dnn" , "lr" , "wide_and_deep" , "customized" , "cnn" ])
86+ assert (FLAGS .model in [
87+ "dnn" , "lr" , "wide_and_deep" , "customized" , "cnn" , "customized_cnn" ,
88+ "lstm" , "bidirectional_lstm" , "gru"
89+ ])
7990
8091 # Print flags
8192 parameter_value_map = {}
@@ -206,36 +217,22 @@ def parse_tfrecords_function(example_proto):
206217 return parsed_features ["features" ], parsed_features ["label" ]
207218
208219
209- # TODO: Change for dataset api
210- def read_and_decode_csv_old (filename_queue ):
211- # Notice that it supports label in the last column only
212- reader = tf .TextLineReader ()
213- key , value = reader .read (filename_queue )
214- record_defaults = [[1.0 ] for i in range (FLAGS .feature_size )] + [[0 ]]
215- columns = tf .decode_csv (value , record_defaults = record_defaults )
216- label = columns [- 1 ]
217- features = tf .stack (columns [0 :- 1 ])
218- return label , features
219-
220-
221220def parse_csv_function (line ):
222- # Metadata describing the text columns
223- COLUMNS = [
224- "feature0" , "feature1" , "feature2" , "feature3" , "feature4" , "feature5" ,
225- "feature6" , "feature7" , "feature8" , "label"
226- ]
221+ """
222+ Decode CSV for Dataset.
223+
224+ Args:
225+ line: One line data of the CSV.
226+
227+ Return:
228+ The op of features and labels
229+ """
230+
227231 FIELD_DEFAULTS = [[0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ],
228232 [0.0 ], [0 ]]
229233
230- # Decode the line into its fields
231234 fields = tf .decode_csv (line , FIELD_DEFAULTS )
232235
233- # Pack the result into a dictionary
234- #features = dict(zip(COLUMNS,fields))
235-
236- # Separate the label from the features
237- #label = features.pop("label")
238-
239236 label = fields [- 1 ]
240237 label = tf .cast (label , tf .int64 )
241238 features = tf .stack (fields [0 :- 1 ])
@@ -266,6 +263,18 @@ def inference(inputs, input_units, output_units, is_train=True):
266263 elif FLAGS .model == "cnn" :
267264 return model .cnn_inference (inputs , input_units , output_units , is_train ,
268265 FLAGS )
266+ elif FLAGS .model == "customized_cnn" :
267+ return model .customized_cnn_inference (inputs , input_units , output_units ,
268+ is_train , FLAGS )
269+ elif FLAGS .model == "lstm" :
270+ return model .lstm_inference (inputs , input_units , output_units , is_train ,
271+ FLAGS )
272+ elif FLAGS .model == "bidirectional_lstm" :
273+ return model .bidirectional_lstm_inference (inputs , input_units ,
274+ output_units , is_train , FLAGS )
275+ elif FLAGS .model == "gru" :
276+ return model .gru_inference (inputs , input_units , output_units , is_train ,
277+ FLAGS )
269278
270279
271280logging .basicConfig (level = logging .INFO )
@@ -337,11 +346,19 @@ def main():
337346 output_units = FLAGS .label_size
338347 logits = inference (train_features_op , input_units , output_units , True )
339348
340- if FLAGS .scenario == "classification " :
349+ if FLAGS .loss == "sparse_cross_entropy " :
341350 cross_entropy = tf .nn .sparse_softmax_cross_entropy_with_logits (
342351 logits = logits , labels = train_label_op )
343352 loss = tf .reduce_mean (cross_entropy , name = "loss" )
344- elif FLAGS .scenario == "regression" :
353+ elif FLAGS .loss == "cross_entropy" :
354+
355+ #train_label_op =
356+ #validation_label_op =
357+
358+ cross_entropy = tf .nn .softmax_cross_entropy_with_logits (
359+ logits = logits , labels = train_label_op )
360+ loss = tf .reduce_mean (cross_entropy , name = "loss" )
361+ elif FLAGS .loss == "mean_square" :
345362 msl = tf .square (logits - train_label_op , name = "msl" )
346363 loss = tf .reduce_mean (msl , name = "loss" )
347364
0 commit comments