Skip to content

Commit b168e27

Browse files
committed
Update python predict client with official python package
1 parent a17a972 commit b168e27

6 files changed

Lines changed: 20 additions & 482 deletions

File tree

python_predict_client/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ TensorFlow serving is the gRPC service for general TensorFlow models. We can imp
77
## Usage
88

99
```
10-
./predict_client.py --host 127.0.0.1 --port 9000 --model_name cancer --model_version 1
10+
./predict_client.py --host 127.0.0.1 --port 8500 --model_name default --model_version 1
1111
```
1212

1313
For sparse data, you can run with this command.

python_predict_client/model_pb2.py

Lines changed: 0 additions & 93 deletions
This file was deleted.
Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,46 @@
11
#!/usr/bin/env python
22

33
import numpy
4-
5-
from grpc.beta import implementations
64
import tensorflow as tf
7-
8-
import predict_pb2
9-
import prediction_service_pb2
5+
from grpc.beta import implementations
6+
from tensorflow_serving.apis import predict_pb2
7+
from tensorflow_serving.apis import prediction_service_pb2
108

119
tf.app.flags.DEFINE_string("host", "127.0.0.1", "gRPC server host")
1210
tf.app.flags.DEFINE_integer("port", 9000, "gRPC server port")
13-
tf.app.flags.DEFINE_string("model_name", "cancer", "TensorFlow model name")
11+
tf.app.flags.DEFINE_string("model_name", "default", "TensorFlow model name")
1412
tf.app.flags.DEFINE_integer("model_version", -1, "TensorFlow model version")
13+
tf.app.flags.DEFINE_string("signature_name", "", "The signature name")
1514
tf.app.flags.DEFINE_float("request_timeout", 10.0, "Timeout of gRPC request")
1615
FLAGS = tf.app.flags.FLAGS
1716

1817

1918
def main():
20-
host = FLAGS.host
21-
port = FLAGS.port
22-
model_name = FLAGS.model_name
23-
model_version = FLAGS.model_version
24-
request_timeout = FLAGS.request_timeout
25-
2619
# Generate inference data
2720
keys = numpy.asarray([1, 2, 3])
2821
keys_tensor_proto = tf.contrib.util.make_tensor_proto(keys, dtype=tf.int32)
2922
features = numpy.asarray(
3023
[[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 1, 1, 1, 1, 1, 1, 1, 1],
3124
[9, 8, 7, 6, 5, 4, 3, 2, 1], [9, 9, 9, 9, 9, 9, 9, 9, 9]])
32-
features_tensor_proto = tf.contrib.util.make_tensor_proto(features,
33-
dtype=tf.float32)
25+
features_tensor_proto = tf.contrib.util.make_tensor_proto(
26+
features, dtype=tf.float32)
3427

35-
# Create gRPC client and request
36-
channel = implementations.insecure_channel(host, port)
28+
# Create gRPC client
29+
channel = implementations.insecure_channel(FLAGS.host, FLAGS.port)
3730
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
3831
request = predict_pb2.PredictRequest()
39-
request.model_spec.name = model_name
40-
if model_version > 0:
41-
request.model_spec.version.value = model_version
42-
request.inputs['keys'].CopyFrom(keys_tensor_proto)
43-
request.inputs['features'].CopyFrom(features_tensor_proto)
32+
request.model_spec.name = FLAGS.model_name
33+
if FLAGS.model_version > 0:
34+
request.model_spec.version.value = FLAGS.model_version
35+
if FLAGS.signature_name != "":
36+
request.model_spec.signature_name = FLAGS.signature_name
37+
request.inputs["keys"].CopyFrom(keys_tensor_proto)
38+
request.inputs["features"].CopyFrom(features_tensor_proto)
4439

4540
# Send request
46-
result = stub.Predict(request, request_timeout)
41+
result = stub.Predict(request, FLAGS.request_timeout)
4742
print(result)
4843

4944

50-
if __name__ == '__main__':
45+
if __name__ == "__main__":
5146
main()

0 commit comments

Comments
 (0)