Skip to content

Commit cd1561c

Browse files
committed
Support reading csv with dataset api
1 parent 8ca0468 commit cd1561c

1 file changed

Lines changed: 59 additions & 24 deletions

File tree

dense_classifier.py

Lines changed: 59 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def define_flags():
3333
"Support classification, regression")
3434
flags.DEFINE_integer("feature_size", 9, "Number of feature size")
3535
flags.DEFINE_integer("label_size", 2, "Number of label size")
36-
flags.DEFINE_string("train_file_format", "tfrecords",
37-
"Support tfrecords, csv")
38-
flags.DEFINE_string("train_file", "./data/cancer/cancer_train.csv.tfrecords",
36+
flags.DEFINE_string("file_format", "tfrecords", "Support tfrecords, csv")
37+
flags.DEFINE_string("train_files",
38+
"./data/cancer/cancer_train.csv.tfrecords",
3939
"Train files which supports glob pattern")
40-
flags.DEFINE_string("validation_file",
40+
flags.DEFINE_string("validation_files",
4141
"./data/cancer/cancer_test.csv.tfrecords",
4242
"Validate files which supports glob pattern")
4343
flags.DEFINE_string("inference_data_file", "./data/cancer/cancer_test.csv",
@@ -71,7 +71,7 @@ def define_flags():
7171
# Check parameters
7272
assert (FLAGS.mode in ["train", "inference", "savedmodel"])
7373
assert (FLAGS.scenario in ["classification", "regression"])
74-
assert (FLAGS.train_file_format in ["tfrecords", "csv"])
74+
assert (FLAGS.file_format in ["tfrecords", "csv"])
7575
assert (FLAGS.optimizer in [
7676
"sgd", "adadelta", "adagrad", "adam", "ftrl", "rmsprop"
7777
])
@@ -207,7 +207,7 @@ def parse_tfrecords_function(example_proto):
207207

208208

209209
# TODO: Legacy queue-based CSV reader; migrate remaining callers to the dataset API
210-
def read_and_decode_csv(filename_queue):
210+
def read_and_decode_csv_old(filename_queue):
211211
# Notice that it supports label in the last column only
212212
reader = tf.TextLineReader()
213213
key, value = reader.read(filename_queue)
@@ -218,6 +218,31 @@ def read_and_decode_csv(filename_queue):
218218
return label, features
219219

220220

221+
def parse_csv_function(line, feature_size=9):
  """
  Parse one CSV record into a stacked feature tensor and an int64 label.

  Notice that it supports label in the last column only; every column
  before it is parsed as a float feature.

  Args:
    line: String tensor holding one CSV record, e.g. an element of the
      tf.data.TextLineDataset this function is mapped over.
    feature_size: Number of feature columns that precede the label. The
      default of 9 keeps the column layout this function always assumed, so
      it can still be passed directly to Dataset.map.

  Returns:
    A (features, label) tuple. features is the feature columns stacked into
    a single tensor in column order; label is the last column cast to
    tf.int64.
  """
  # One default per column: this fixes how many columns are expected and
  # whether each one is parsed as a float (features) or an integer (label).
  record_defaults = [[0.0] for _ in range(feature_size)] + [[0]]

  # Decode the line into its fields
  fields = tf.decode_csv(line, record_defaults)

  # Separate the label from the features
  features = tf.stack(fields[:-1])
  label = tf.cast(fields[-1], tf.int64)

  return features, label
244+
245+
221246
def inference(inputs, input_units, output_units, is_train=True):
222247
"""
223248
Define the model by model name.
@@ -269,33 +294,43 @@ def main():
269294
train_buffer_size = FLAGS.train_batch_size * 3
270295
validation_buffer_size = FLAGS.train_batch_size * 3
271296

272-
train_filename_list = [FLAGS.train_file]
297+
train_filename_list = [filename for filename in FLAGS.train_files.split(",")]
273298
train_filename_placeholder = tf.placeholder(tf.string, shape=[None])
274-
train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
275-
train_dataset = train_dataset.map(parse_tfrecords_function).repeat(
276-
epoch_number).batch(FLAGS.train_batch_size).shuffle(
277-
buffer_size=train_buffer_size)
299+
if FLAGS.file_format == "tfrecords":
300+
train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
301+
train_dataset = train_dataset.map(parse_tfrecords_function).repeat(
302+
epoch_number).batch(FLAGS.train_batch_size).shuffle(
303+
buffer_size=train_buffer_size)
304+
elif FLAGS.file_format == "csv":
305+
# Skip the header or not
306+
train_dataset = tf.data.TextLineDataset(train_filename_placeholder)
307+
train_dataset = train_dataset.map(parse_csv_function).repeat(
308+
epoch_number).batch(FLAGS.train_batch_size).shuffle(
309+
buffer_size=train_buffer_size)
278310
train_dataset_iterator = train_dataset.make_initializable_iterator()
279311
train_features_op, train_label_op = train_dataset_iterator.get_next()
280312

281-
validation_filename_list = [FLAGS.validation_file]
313+
validation_filename_list = [
314+
filename for filename in FLAGS.validation_files.split(",")
315+
]
282316
validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])
283-
validation_dataset = tf.data.TFRecordDataset(validation_filename_placeholder)
284-
validation_dataset = validation_dataset.map(parse_tfrecords_function).repeat(
285-
epoch_number).batch(FLAGS.validation_batch_size).shuffle(
286-
buffer_size=validation_buffer_size)
317+
if FLAGS.file_format == "tfrecords":
318+
validation_dataset = tf.data.TFRecordDataset(
319+
validation_filename_placeholder)
320+
validation_dataset = validation_dataset.map(
321+
parse_tfrecords_function).repeat(epoch_number).batch(
322+
FLAGS.validation_batch_size).shuffle(
323+
buffer_size=validation_buffer_size)
324+
elif FLAGS.file_format == "csv":
325+
validation_dataset = tf.data.TextLineDataset(
326+
validation_filename_placeholder)
327+
validation_dataset = validation_dataset.map(parse_csv_function).repeat(
328+
epoch_number).batch(FLAGS.validation_batch_size).shuffle(
329+
buffer_size=validation_buffer_size)
287330
validation_dataset_iterator = validation_dataset.make_initializable_iterator(
288331
)
289332
validation_features_op, validation_label_op = validation_dataset_iterator.get_next(
290333
)
291-
"""
292-
if FLAGS.train_file_format == "tfrecords":
293-
pass
294-
#read_and_decode_function = read_and_decode_tfrecords
295-
elif FLAGS.train_file_format == "csv":
296-
pass
297-
#read_and_decode_function = read_and_decode_csv
298-
"""
299334

300335
# Step 2: Define the model
301336
input_units = FLAGS.feature_size

0 commit comments

Comments
 (0)