Skip to content

Commit b12dae3

Browse files
committed
Print all flags with pprint before training
1 parent de2976b commit b12dae3

3 files changed

Lines changed: 15 additions & 4 deletions

File tree

README.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ Following are the supported features.
2525
- [x] Validate auc
2626
- [x] Inference online
2727
- [x] Network Model
28-
- [x] LR
29-
- [x] DNN
30-
- [x] Wide and deep
31-
- [x] Customized
28+
- [x] Logistic regression
29+
- [x] Deep neural network
30+
- [x] Convolution neural network
31+
- [x] Wide and deep model
32+
- [x] Customized models
3233
- [x] Others
3334
- [x] Checkpoint
3435
- [x] TensorBoard
@@ -82,6 +83,12 @@ If you use other dataset like [iris](./data/iris/), no need to modify the code.
8283
./dense_classifier.py --train_tfrecords_file ./data/iris/iris_train.csv.tfrecords --validate_tfrecords_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3
8384
```
8485

86+
If you want to use CNN model, try this command.
87+
88+
```
89+
./dense_classifier.py --train_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_tfrecords_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
90+
```
91+
8592
### Export The Model
8693

8794
After training, it will export the model automatically. Or you can export manually.

dense_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import math
66
import numpy as np
77
import os
8+
import pprint
89
from sklearn import metrics
910
import tensorflow as tf
1011
from tensorflow.contrib.session_bundle import exporter
@@ -74,6 +75,7 @@ def main():
7475
OUTPUT_PATH = FLAGS.output_path
7576
if not OUTPUT_PATH.startswith("fds://") and not os.path.exists(OUTPUT_PATH):
7677
os.makedirs(OUTPUT_PATH)
78+
pprint.PrettyPrinter().pprint(FLAGS.__flags)
7779

7880
# Process TFRecoreds files
7981
def read_and_decode(filename_queue):

sparse_classifier.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import math
66
import numpy as np
77
import os
8+
import pprint
89
from sklearn import metrics
910
import tensorflow as tf
1011
from tensorflow.contrib.session_bundle import exporter
@@ -75,6 +76,7 @@ def main():
7576
OUTPUT_PATH = FLAGS.output_path
7677
if not OUTPUT_PATH.startswith("fds://") and not os.path.exists(OUTPUT_PATH):
7778
os.makedirs(OUTPUT_PATH)
79+
pprint.PrettyPrinter().pprint(FLAGS.__flags)
7880

7981
# Read TFRecords files for training
8082
def read_and_decode(filename_queue):

0 commit comments

Comments
 (0)