Skip to content

Commit 463cad0

Browse files
committed
Support lstm, bidirectional-lstm, gru models
1 parent cd1561c commit 463cad0

2 files changed

Lines changed: 175 additions & 30 deletions

File tree

dense_classifier.py

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def define_flags():
3131
flags.DEFINE_boolean("resume_from_checkpoint", True, "Resume or not")
3232
flags.DEFINE_string("scenario", "classification",
3333
"Support classification, regression")
34+
flags.DEFINE_string(
35+
"loss", "sparse_cross_entropy",
36+
"Support sparse_cross_entropy, cross_entropy, mean_square")
3437
flags.DEFINE_integer("feature_size", 9, "Number of feature size")
3538
flags.DEFINE_integer("label_size", 2, "Number of label size")
3639
flags.DEFINE_string("file_format", "tfrecords", "Support tfrecords, csv")
@@ -47,8 +50,10 @@ def define_flags():
4750
flags.DEFINE_string("optimizer", "adagrad",
4851
"Support sgd, adadelta, adagrad, adam, ftrl, rmsprop")
4952
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
50-
flags.DEFINE_string("model", "dnn",
51-
"Support dnn, lr, wide_and_deep, customized, cnn")
53+
flags.DEFINE_string(
54+
"model", "dnn",
55+
"Support dnn, lr, wide_and_deep, customized, cnn, lstm, bidirectional_lstm, gru"
56+
)
5257
flags.DEFINE_string("dnn_struct", "128 32 8", "DNN struct")
5358
flags.DEFINE_integer("epoch_number", 100, "Number of epoches")
5459
flags.DEFINE_integer("train_batch_size", 64, "Batch size")
@@ -71,11 +76,17 @@ def define_flags():
7176
# Check parameters
7277
assert (FLAGS.mode in ["train", "inference", "savedmodel"])
7378
assert (FLAGS.scenario in ["classification", "regression"])
79+
assert (FLAGS.loss in [
80+
"sparse_cross_entropy", "cross_entropy", "mean_square"
81+
])
7482
assert (FLAGS.file_format in ["tfrecords", "csv"])
7583
assert (FLAGS.optimizer in [
7684
"sgd", "adadelta", "adagrad", "adam", "ftrl", "rmsprop"
7785
])
78-
assert (FLAGS.model in ["dnn", "lr", "wide_and_deep", "customized", "cnn"])
86+
assert (FLAGS.model in [
87+
"dnn", "lr", "wide_and_deep", "customized", "cnn", "customized_cnn",
88+
"lstm", "bidirectional_lstm", "gru"
89+
])
7990

8091
# Print flags
8192
parameter_value_map = {}
@@ -206,36 +217,22 @@ def parse_tfrecords_function(example_proto):
206217
return parsed_features["features"], parsed_features["label"]
207218

208219

209-
# TODO: Change for dataset api
210-
def read_and_decode_csv_old(filename_queue):
211-
# Notice that it supports label in the last column only
212-
reader = tf.TextLineReader()
213-
key, value = reader.read(filename_queue)
214-
record_defaults = [[1.0] for i in range(FLAGS.feature_size)] + [[0]]
215-
columns = tf.decode_csv(value, record_defaults=record_defaults)
216-
label = columns[-1]
217-
features = tf.stack(columns[0:-1])
218-
return label, features
219-
220-
221220
def parse_csv_function(line):
222-
# Metadata describing the text columns
223-
COLUMNS = [
224-
"feature0", "feature1", "feature2", "feature3", "feature4", "feature5",
225-
"feature6", "feature7", "feature8", "label"
226-
]
221+
"""
222+
Decode CSV for Dataset.
223+
224+
Args:
225+
line: One line data of the CSV.
226+
227+
Return:
228+
The op of features and labels
229+
"""
230+
227231
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0],
228232
[0.0], [0]]
229233

230-
# Decode the line into its fields
231234
fields = tf.decode_csv(line, FIELD_DEFAULTS)
232235

233-
# Pack the result into a dictionary
234-
#features = dict(zip(COLUMNS,fields))
235-
236-
# Separate the label from the features
237-
#label = features.pop("label")
238-
239236
label = fields[-1]
240237
label = tf.cast(label, tf.int64)
241238
features = tf.stack(fields[0:-1])
@@ -266,6 +263,18 @@ def inference(inputs, input_units, output_units, is_train=True):
266263
elif FLAGS.model == "cnn":
267264
return model.cnn_inference(inputs, input_units, output_units, is_train,
268265
FLAGS)
266+
elif FLAGS.model == "customized_cnn":
267+
return model.customized_cnn_inference(inputs, input_units, output_units,
268+
is_train, FLAGS)
269+
elif FLAGS.model == "lstm":
270+
return model.lstm_inference(inputs, input_units, output_units, is_train,
271+
FLAGS)
272+
elif FLAGS.model == "bidirectional_lstm":
273+
return model.bidirectional_lstm_inference(inputs, input_units,
274+
output_units, is_train, FLAGS)
275+
elif FLAGS.model == "gru":
276+
return model.gru_inference(inputs, input_units, output_units, is_train,
277+
FLAGS)
269278

270279

271280
logging.basicConfig(level=logging.INFO)
@@ -337,11 +346,19 @@ def main():
337346
output_units = FLAGS.label_size
338347
logits = inference(train_features_op, input_units, output_units, True)
339348

340-
if FLAGS.scenario == "classification":
349+
if FLAGS.loss == "sparse_cross_entropy":
341350
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
342351
logits=logits, labels=train_label_op)
343352
loss = tf.reduce_mean(cross_entropy, name="loss")
344-
elif FLAGS.scenario == "regression":
353+
elif FLAGS.loss == "cross_entropy":
354+
355+
#train_label_op =
356+
#validation_label_op =
357+
358+
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
359+
logits=logits, labels=train_label_op)
360+
loss = tf.reduce_mean(cross_entropy, name="loss")
361+
elif FLAGS.loss == "mean_square":
345362
msl = tf.square(logits - train_label_op, name="msl")
346363
loss = tf.reduce_mean(msl, name="loss")
347364

model.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,44 @@ def cnn_inference(inputs, input_units, output_units, is_train=True,
129129
Define the CNN model.
130130
"""
131131

132+
# [BATCH_SIZE, 9] -> [BATCH_SIZE, 3, 3, 1]
133+
inputs = tf.reshape(inputs, [-1, 3, 3, 1])
134+
135+
# [BATCH_SIZE, 3, 3, 1] -> [BATCH_SIZE, 3, 3, 8]
136+
with tf.variable_scope("conv_0"):
137+
weights = tf.get_variable(
138+
"weights", [3, 3, 1, 8], initializer=tf.random_normal_initializer())
139+
bias = tf.get_variable(
140+
"bias", [8], initializer=tf.random_normal_initializer())
141+
142+
layer = tf.nn.conv2d(inputs, weights, strides=[1, 1, 1, 1], padding="SAME")
143+
layer = tf.nn.bias_add(layer, bias)
144+
layer = tf.nn.relu(layer)
145+
146+
# [BATCH_SIZE, 3, 3, 8] -> [BATCH_SIZE, 3 * 3 * 8]
147+
layer = tf.reshape(layer, [-1, 3 * 3 * 8])
148+
149+
# [BATCH_SIZE, 3 * 3 * 8] -> [BATCH_SIZE, LABEL_SIZE]
150+
with tf.variable_scope("output_layer"):
151+
weights = tf.get_variable(
152+
"weights", [3 * 3 * 8, FLAGS.label_size],
153+
initializer=tf.random_normal_initializer())
154+
bias = tf.get_variable(
155+
"bias", [FLAGS.label_size], initializer=tf.random_normal_initializer())
156+
layer = tf.add(tf.matmul(layer, weights), bias)
157+
158+
return layer
159+
160+
161+
def customized_cnn_inference(inputs,
162+
input_units,
163+
output_units,
164+
is_train=True,
165+
FLAGS=None):
166+
"""
167+
Define the CNN model.
168+
"""
169+
132170
# TODO: Change if validate_batch_size is different
133171
# [BATCH_SIZE, 512 * 512 * 1] -> [BATCH_SIZE, 512, 512, 1]
134172
inputs = tf.reshape(inputs, [FLAGS.train_batch_size, 512, 512, 1])
@@ -187,6 +225,96 @@ def cnn_inference(inputs, input_units, output_units, is_train=True,
187225
return layer
188226

189227

228+
def lstm_inference(inputs,
                   input_units,
                   output_units,
                   is_train=True,
                   FLAGS=None):
    """
    Define the LSTM model.

    Each example is viewed as a sequence of 3 steps with 3 values per step,
    fed through a single LSTM layer, and the output of the last step is
    projected to the logits.

    Args:
      inputs: The op of the input features. It is reshaped to [-1, 3, 3],
        so 9 values per example are expected.
      input_units: Not used by this model.
      output_units: The number of columns of the returned logits.
      is_train: Not used by this model.
      FLAGS: Not used by this model.

    Return:
      The op of the logits, [BATCH_SIZE, output_units].
    """

    hidden_size = 128
    num_steps = 3
    step_width = 3

    # Output projection applied to the last step of the RNN.
    # NOTE(review): these are created with tf.Variable, not tf.get_variable,
    # so presumably they are not shared under variable-scope reuse -- confirm
    # against the caller.
    projection = tf.Variable(tf.random_normal([hidden_size, output_units]))
    offset = tf.Variable(tf.random_normal([output_units]))

    # [BATCH_SIZE, 9] -> [BATCH_SIZE, 3, 3]
    sequence = tf.reshape(inputs, [-1, num_steps, step_width])

    # [BATCH_SIZE, 3, 3] -> list of 3 tensors of [BATCH_SIZE, 3]
    steps = tf.unstack(sequence, num=num_steps, axis=1)

    # The cell emits 128 values per step
    cell = tf.contrib.rnn.BasicLSTMCell(hidden_size, forget_bias=1.0)

    # step_outputs is a list of 3 tensors of [BATCH_SIZE, 128]
    step_outputs, _ = tf.contrib.rnn.static_rnn(cell, steps, dtype=tf.float32)

    # Last step: [BATCH_SIZE, 128] -> [BATCH_SIZE, output_units]
    last_output = step_outputs[-1]
    return tf.matmul(last_output, projection) + offset
256+
257+
258+
def bidirectional_lstm_inference(inputs,
                                 input_units,
                                 output_units,
                                 is_train=True,
                                 FLAGS=None):
    """
    Define the bidirectional LSTM model.

    Each example is viewed as a sequence of 3 steps with 3 values per step
    and fed through a forward and a backward LSTM. The output of the last
    step is projected to the logits.

    Args:
      inputs: The op of the input features. It is reshaped to [-1, 3, 3],
        so 9 values per example are expected.
      input_units: Not used by this model.
      output_units: The number of columns of the returned logits.
      is_train: Not used by this model.
      FLAGS: Not used by this model.

    Return:
      The op of the logits, [BATCH_SIZE, output_units].
    """

    RNN_HIDDEN_UNITS = 128
    timesteps = 3
    number_input = 3

    # NOTE(review): these are created with tf.Variable, not tf.get_variable,
    # so presumably they are not shared under variable-scope reuse -- confirm
    # against the caller.
    weights = tf.Variable(tf.random_normal([RNN_HIDDEN_UNITS, output_units]))
    biases = tf.Variable(tf.random_normal([output_units]))

    # [BATCH_SIZE, 9] -> [BATCH_SIZE, 3, 3]
    x = tf.reshape(inputs, [-1, timesteps, number_input])

    # [BATCH_SIZE, 3, 3] -> list of 3 tensors of [BATCH_SIZE, 3]
    x = tf.unstack(x, timesteps, 1)

    # Each direction gets half of the hidden units so that the concatenated
    # forward/backward output matches the first dimension of `weights`.
    # Floor division keeps the unit count an int on Python 3 as well; `/`
    # would pass the float 64.0 to the cell.
    direction_units = RNN_HIDDEN_UNITS // 2
    fw_lstm_cell = tf.contrib.rnn.BasicLSTMCell(
        direction_units, forget_bias=1.0)
    bw_lstm_cell = tf.contrib.rnn.BasicLSTMCell(
        direction_units, forget_bias=1.0)

    # outputs is a list of 3 tensors holding the forward and backward
    # outputs of each step joined together
    outputs, _, _ = tf.contrib.rnn.static_bidirectional_rnn(
        fw_lstm_cell, bw_lstm_cell, x, dtype=tf.float32)

    # Last step: [BATCH_SIZE, 128] -> [BATCH_SIZE, output_units]
    layer = tf.matmul(outputs[-1], weights) + biases
    return layer
289+
290+
291+
def gru_inference(inputs, input_units, output_units, is_train=True,
                  FLAGS=None):
    """
    Define the GRU model.

    Each example is viewed as a sequence of 3 steps with 3 values per step,
    fed through a single GRU layer, and the output of the last step is
    projected to the logits.

    Args:
      inputs: The op of the input features. It is reshaped to [-1, 3, 3],
        so 9 values per example are expected.
      input_units: Not used by this model.
      output_units: The number of columns of the returned logits.
      is_train: Not used by this model.
      FLAGS: Not used by this model.

    Return:
      The op of the logits, [BATCH_SIZE, output_units].
    """

    hidden_size = 128
    num_steps = 3
    step_width = 3

    # Output projection applied to the last step of the RNN.
    # NOTE(review): these are created with tf.Variable, not tf.get_variable,
    # so presumably they are not shared under variable-scope reuse -- confirm
    # against the caller.
    projection = tf.Variable(tf.random_normal([hidden_size, output_units]))
    offset = tf.Variable(tf.random_normal([output_units]))

    # [BATCH_SIZE, 9] -> [BATCH_SIZE, 3, 3]
    sequence = tf.reshape(inputs, [-1, num_steps, step_width])

    # [BATCH_SIZE, 3, 3] -> list of 3 tensors of [BATCH_SIZE, 3]
    steps = tf.unstack(sequence, num=num_steps, axis=1)

    # The cell emits 128 values per step
    gru_cell = tf.contrib.rnn.GRUCell(hidden_size)

    # step_outputs is a list of 3 tensors of [BATCH_SIZE, 128]
    step_outputs, _ = tf.contrib.rnn.static_rnn(
        gru_cell, steps, dtype=tf.float32)

    # Last step: [BATCH_SIZE, 128] -> [BATCH_SIZE, output_units]
    last_output = step_outputs[-1]
    return tf.matmul(last_output, projection) + offset
316+
317+
190318
def compute_softmax_and_accuracy(logits, labels):
191319
"""
192320
Compute the softmax and accuracy of the logits and labels.
@@ -227,4 +355,4 @@ def compute_auc(softmax_op, label_op, label_size):
227355
new_batch_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
228356
_, auc_op = tf.contrib.metrics.streaming_auc(softmax_op, new_batch_labels)
229357

230-
return auc_op
358+
return auc_op

0 commit comments

Comments
 (0)