Skip to content

Commit 1bd3899

Browse files
committed
Fix number of items in predict client
1 parent b168e27 commit 1bd3899

1 file changed

Lines changed: 8 additions & 9 deletions

File tree

python_predict_client/predict_client.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,20 @@
33
import numpy
44
import tensorflow as tf
55
from grpc.beta import implementations
6-
from tensorflow_serving.apis import predict_pb2
7-
from tensorflow_serving.apis import prediction_service_pb2
8-
9-
tf.app.flags.DEFINE_string("host", "127.0.0.1", "gRPC server host")
10-
tf.app.flags.DEFINE_integer("port", 9000, "gRPC server port")
11-
tf.app.flags.DEFINE_string("model_name", "default", "TensorFlow model name")
12-
tf.app.flags.DEFINE_integer("model_version", -1, "TensorFlow model version")
13-
tf.app.flags.DEFINE_string("signature_name", "", "The signature name")
6+
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2
7+
8+
tf.app.flags.DEFINE_string("host", "0.0.0.0", "TensorFlow Serving server ip")
9+
tf.app.flags.DEFINE_integer("port", 8500, "TensorFlow Serving server port")
10+
tf.app.flags.DEFINE_string("model_name", "default", "The model name")
11+
tf.app.flags.DEFINE_integer("model_version", -1, "The model version")
12+
tf.app.flags.DEFINE_string("signature_name", "", "The model signature name")
1413
tf.app.flags.DEFINE_float("request_timeout", 10.0, "Timeout of gRPC request")
1514
FLAGS = tf.app.flags.FLAGS
1615

1716

1817
def main():
1918
# Generate inference data
20-
keys = numpy.asarray([1, 2, 3])
19+
keys = numpy.asarray([1, 2, 3, 4])
2120
keys_tensor_proto = tf.contrib.util.make_tensor_proto(keys, dtype=tf.int32)
2221
features = numpy.asarray(
2322
[[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 1, 1, 1, 1, 1, 1, 1, 1],

0 commit comments

Comments
 (0)