|
1 | 1 | #!/usr/bin/env python |
2 | 2 |
|
3 | 3 | import numpy |
4 | | - |
5 | | -from grpc.beta import implementations |
6 | 4 | import tensorflow as tf |
7 | | - |
8 | | -import predict_pb2 |
9 | | -import prediction_service_pb2 |
| 5 | +from grpc.beta import implementations |
| 6 | +from tensorflow_serving.apis import predict_pb2 |
| 7 | +from tensorflow_serving.apis import prediction_service_pb2 |
10 | 8 |
|
11 | 9 | tf.app.flags.DEFINE_string("host", "127.0.0.1", "gRPC server host") |
12 | 10 | tf.app.flags.DEFINE_integer("port", 9000, "gRPC server port") |
13 | | -tf.app.flags.DEFINE_string("model_name", "cancer", "TensorFlow model name") |
| 11 | +tf.app.flags.DEFINE_string("model_name", "default", "TensorFlow model name") |
14 | 12 | tf.app.flags.DEFINE_integer("model_version", -1, "TensorFlow model version") |
| 13 | +tf.app.flags.DEFINE_string("signature_name", "", "The signature name") |
15 | 14 | tf.app.flags.DEFINE_float("request_timeout", 10.0, "Timeout of gRPC request") |
16 | 15 | FLAGS = tf.app.flags.FLAGS |
17 | 16 |
|
18 | 17 |
|
19 | 18 | def main(): |
20 | | - host = FLAGS.host |
21 | | - port = FLAGS.port |
22 | | - model_name = FLAGS.model_name |
23 | | - model_version = FLAGS.model_version |
24 | | - request_timeout = FLAGS.request_timeout |
25 | | - |
26 | 19 | # Generate inference data |
27 | 20 | keys = numpy.asarray([1, 2, 3]) |
28 | 21 | keys_tensor_proto = tf.contrib.util.make_tensor_proto(keys, dtype=tf.int32) |
29 | 22 | features = numpy.asarray( |
30 | 23 | [[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 1, 1, 1, 1, 1, 1, 1, 1], |
31 | 24 | [9, 8, 7, 6, 5, 4, 3, 2, 1], [9, 9, 9, 9, 9, 9, 9, 9, 9]]) |
32 | | - features_tensor_proto = tf.contrib.util.make_tensor_proto(features, |
33 | | - dtype=tf.float32) |
| 25 | + features_tensor_proto = tf.contrib.util.make_tensor_proto( |
| 26 | + features, dtype=tf.float32) |
34 | 27 |
|
35 | | - # Create gRPC client and request |
36 | | - channel = implementations.insecure_channel(host, port) |
| 28 | + # Create gRPC client |
| 29 | + channel = implementations.insecure_channel(FLAGS.host, FLAGS.port) |
37 | 30 | stub = prediction_service_pb2.beta_create_PredictionService_stub(channel) |
38 | 31 | request = predict_pb2.PredictRequest() |
39 | | - request.model_spec.name = model_name |
40 | | - if model_version > 0: |
41 | | - request.model_spec.version.value = model_version |
42 | | - request.inputs['keys'].CopyFrom(keys_tensor_proto) |
43 | | - request.inputs['features'].CopyFrom(features_tensor_proto) |
| 32 | + request.model_spec.name = FLAGS.model_name |
| 33 | + if FLAGS.model_version > 0: |
| 34 | + request.model_spec.version.value = FLAGS.model_version |
| 35 | + if FLAGS.signature_name != "": |
| 36 | + request.model_spec.signature_name = FLAGS.signature_name |
| 37 | + request.inputs["keys"].CopyFrom(keys_tensor_proto) |
| 38 | + request.inputs["features"].CopyFrom(features_tensor_proto) |
44 | 39 |
|
45 | 40 | # Send request |
46 | | - result = stub.Predict(request, request_timeout) |
| 41 | + result = stub.Predict(request, FLAGS.request_timeout) |
47 | 42 | print(result) |
48 | 43 |
|
49 | 44 |
|
50 | | -if __name__ == '__main__': |
| 45 | +if __name__ == "__main__": |
51 | 46 | main() |
0 commit comments