@@ -33,11 +33,11 @@ def define_flags():
3333 "Support classification, regression" )
3434 flags .DEFINE_integer ("feature_size" , 9 , "Number of feature size" )
3535 flags .DEFINE_integer ("label_size" , 2 , "Number of label size" )
36- flags .DEFINE_string ("train_file_format " , "tfrecords" ,
37- "Support tfrecords, csv" )
38- flags . DEFINE_string ( "train_file" , "./data/cancer/cancer_train.csv.tfrecords" ,
36+ flags .DEFINE_string ("file_format " , "tfrecords" , "Support tfrecords, csv" )
37+ flags . DEFINE_string ( "train_files" ,
38+ "./data/cancer/cancer_train.csv.tfrecords" ,
3939 "Train files which supports glob pattern" )
40- flags .DEFINE_string ("validation_file " ,
40+ flags .DEFINE_string ("validation_files " ,
4141 "./data/cancer/cancer_test.csv.tfrecords" ,
4242 "Validate files which supports glob pattern" )
4343 flags .DEFINE_string ("inference_data_file" , "./data/cancer/cancer_test.csv" ,
@@ -71,7 +71,7 @@ def define_flags():
7171 # Check parameters
7272 assert (FLAGS .mode in ["train" , "inference" , "savedmodel" ])
7373 assert (FLAGS .scenario in ["classification" , "regression" ])
74- assert (FLAGS .train_file_format in ["tfrecords" , "csv" ])
74+ assert (FLAGS .file_format in ["tfrecords" , "csv" ])
7575 assert (FLAGS .optimizer in [
7676 "sgd" , "adadelta" , "adagrad" , "adam" , "ftrl" , "rmsprop"
7777 ])
@@ -207,7 +207,7 @@ def parse_tfrecords_function(example_proto):
207207
208208
209209# TODO: Change for dataset api
210- def read_and_decode_csv (filename_queue ):
210+ def read_and_decode_csv_old (filename_queue ):
211211 # Notice that it supports label in the last column only
212212 reader = tf .TextLineReader ()
213213 key , value = reader .read (filename_queue )
@@ -218,6 +218,31 @@ def read_and_decode_csv(filename_queue):
218218 return label , features
219219
220220
221+ def parse_csv_function (line ):
222+ # Metadata describing the text columns
223+ COLUMNS = [
224+ "feature0" , "feature1" , "feature2" , "feature3" , "feature4" , "feature5" ,
225+ "feature6" , "feature7" , "feature8" , "label"
226+ ]
227+ FIELD_DEFAULTS = [[0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ], [0.0 ],
228+ [0.0 ], [0 ]]
229+
230+ # Decode the line into its fields
231+ fields = tf .decode_csv (line , FIELD_DEFAULTS )
232+
233+ # Pack the result into a dictionary
234+ #features = dict(zip(COLUMNS,fields))
235+
236+ # Separate the label from the features
237+ #label = features.pop("label")
238+
239+ label = fields [- 1 ]
240+ label = tf .cast (label , tf .int64 )
241+ features = tf .stack (fields [0 :- 1 ])
242+
243+ return features , label
244+
245+
221246def inference (inputs , input_units , output_units , is_train = True ):
222247 """
223248 Define the model by model name.
@@ -269,33 +294,43 @@ def main():
269294 train_buffer_size = FLAGS .train_batch_size * 3
270295 validation_buffer_size = FLAGS .train_batch_size * 3
271296
272- train_filename_list = [FLAGS .train_file ]
297+ train_filename_list = [filename for filename in FLAGS .train_files . split ( "," ) ]
273298 train_filename_placeholder = tf .placeholder (tf .string , shape = [None ])
274- train_dataset = tf .data .TFRecordDataset (train_filename_placeholder )
275- train_dataset = train_dataset .map (parse_tfrecords_function ).repeat (
276- epoch_number ).batch (FLAGS .train_batch_size ).shuffle (
277- buffer_size = train_buffer_size )
299+ if FLAGS .file_format == "tfrecords" :
300+ train_dataset = tf .data .TFRecordDataset (train_filename_placeholder )
301+ train_dataset = train_dataset .map (parse_tfrecords_function ).repeat (
302+ epoch_number ).batch (FLAGS .train_batch_size ).shuffle (
303+ buffer_size = train_buffer_size )
304+ elif FLAGS .file_format == "csv" :
305+ # Skip the header or not
306+ train_dataset = tf .data .TextLineDataset (train_filename_placeholder )
307+ train_dataset = train_dataset .map (parse_csv_function ).repeat (
308+ epoch_number ).batch (FLAGS .train_batch_size ).shuffle (
309+ buffer_size = train_buffer_size )
278310 train_dataset_iterator = train_dataset .make_initializable_iterator ()
279311 train_features_op , train_label_op = train_dataset_iterator .get_next ()
280312
281- validation_filename_list = [FLAGS .validation_file ]
313+ validation_filename_list = [
314+ filename for filename in FLAGS .validation_files .split ("," )
315+ ]
282316 validation_filename_placeholder = tf .placeholder (tf .string , shape = [None ])
283- validation_dataset = tf .data .TFRecordDataset (validation_filename_placeholder )
284- validation_dataset = validation_dataset .map (parse_tfrecords_function ).repeat (
285- epoch_number ).batch (FLAGS .validation_batch_size ).shuffle (
286- buffer_size = validation_buffer_size )
317+ if FLAGS .file_format == "tfrecords" :
318+ validation_dataset = tf .data .TFRecordDataset (
319+ validation_filename_placeholder )
320+ validation_dataset = validation_dataset .map (
321+ parse_tfrecords_function ).repeat (epoch_number ).batch (
322+ FLAGS .validation_batch_size ).shuffle (
323+ buffer_size = validation_buffer_size )
324+ elif FLAGS .file_format == "csv" :
325+ validation_dataset = tf .data .TextLineDataset (
326+ validation_filename_placeholder )
327+ validation_dataset = validation_dataset .map (parse_csv_function ).repeat (
328+ epoch_number ).batch (FLAGS .validation_batch_size ).shuffle (
329+ buffer_size = validation_buffer_size )
287330 validation_dataset_iterator = validation_dataset .make_initializable_iterator (
288331 )
289332 validation_features_op , validation_label_op = validation_dataset_iterator .get_next (
290333 )
291- """
292- if FLAGS.train_file_format == "tfrecords":
293- pass
294- #read_and_decode_function = read_and_decode_tfrecords
295- elif FLAGS.train_file_format == "csv":
296- pass
297- #read_and_decode_function = read_and_decode_csv
298- """
299334
300335 # Step 2: Define the model
301336 input_units = FLAGS .feature_size
0 commit comments