Skip to content

Commit 5eb2183

Browse files
authored
Merge pull request #18 from sparklxb/master
simplify the creation of golang client predict request
2 parents 9aef111 + a63c8fe commit 5eb2183

2 files changed

Lines changed: 131 additions & 123 deletions

File tree

golang_predict_client/src/predict_client.go

Lines changed: 28 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
framework "tensorflow/core/framework"
99
pb "tensorflow_serving"
1010

11-
google_protobuf "github.com/golang/protobuf/ptypes/wrappers"
1211
"golang.org/x/net/context"
1312
"google.golang.org/grpc"
1413
"google.golang.org/grpc/credentials"
@@ -71,132 +70,38 @@ func main() {
7170
}
7271

7372
func newDensePredictRequest(modelName *string, modelVersion *int64) *pb.PredictRequest {
74-
return &pb.PredictRequest{
75-
ModelSpec: &pb.ModelSpec{
76-
Name: *modelName,
77-
Version: &google_protobuf.Int64Value{
78-
Value: *modelVersion,
79-
},
80-
},
81-
Inputs: map[string]*framework.TensorProto{
82-
"keys": &framework.TensorProto{
83-
Dtype: framework.DataType_DT_INT32,
84-
TensorShape: &framework.TensorShapeProto{
85-
Dim: []*framework.TensorShapeProto_Dim{
86-
&framework.TensorShapeProto_Dim{
87-
Size: 3,
88-
},
89-
},
90-
},
91-
IntVal: []int32{1, 2, 3},
92-
},
93-
"features": &framework.TensorProto{
94-
Dtype: framework.DataType_DT_FLOAT,
95-
TensorShape: &framework.TensorShapeProto{
96-
Dim: []*framework.TensorShapeProto_Dim{
97-
&framework.TensorShapeProto_Dim{
98-
Size: 3,
99-
},
100-
&framework.TensorShapeProto_Dim{
101-
Size: 9,
102-
},
103-
},
104-
},
105-
FloatVal: []float32{
106-
1, 2, 3, 4, 5, 6, 7, 8, 9,
107-
1, 2, 3, 4, 5, 6, 7, 8, 9,
108-
1, 2, 3, 4, 5, 6, 7, 8, 9,
109-
},
110-
},
111-
},
112-
}
73+
pr := newPredictRequest(*modelName, *modelVersion)
74+
addInput(pr, "keys", framework.DataType_DT_INT32, []int32{1, 2, 3}, nil, nil)
75+
addInput(pr, "features", framework.DataType_DT_FLOAT, []float32{
76+
1, 2, 3, 4, 5, 6, 7, 8, 9,
77+
1, 2, 3, 4, 5, 6, 7, 8, 9,
78+
1, 2, 3, 4, 5, 6, 7, 8, 9,
79+
}, []int64{3, 9}, nil)
80+
return pr
11381
}
11482

11583
// Example data:
11684
// 0 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 76:1 80:1 83:1
11785
// 1 5:1 7:1 17:1 22:1 36:1 40:1 51:1 63:1 67:1 73:1 74:1 76:1 81:1 83:1
11886
func newSparsePredictRequest(modelName *string, modelVersion *int64) *pb.PredictRequest {
119-
return &pb.PredictRequest{
120-
ModelSpec: &pb.ModelSpec{
121-
Name: *modelName,
122-
Version: &google_protobuf.Int64Value{
123-
Value: *modelVersion,
124-
},
125-
},
126-
Inputs: map[string]*framework.TensorProto{
127-
"keys": &framework.TensorProto{
128-
Dtype: framework.DataType_DT_INT32,
129-
TensorShape: &framework.TensorShapeProto{
130-
Dim: []*framework.TensorShapeProto_Dim{
131-
&framework.TensorShapeProto_Dim{
132-
Size: 2,
133-
},
134-
},
135-
},
136-
IntVal: []int32{1, 2},
137-
},
138-
"indexs": &framework.TensorProto{
139-
Dtype: framework.DataType_DT_INT64,
140-
TensorShape: &framework.TensorShapeProto{
141-
Dim: []*framework.TensorShapeProto_Dim{
142-
&framework.TensorShapeProto_Dim{
143-
Size: 28,
144-
},
145-
&framework.TensorShapeProto_Dim{
146-
Size: 2,
147-
},
148-
},
149-
},
150-
Int64Val: []int64{
151-
0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5,
152-
0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11,
153-
0, 12, 0, 13, 1, 0, 1, 1, 1, 2, 1, 3,
154-
1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9,
155-
1, 10, 1, 11, 1, 12, 1, 13,
156-
},
157-
},
158-
"ids": &framework.TensorProto{
159-
Dtype: framework.DataType_DT_INT64,
160-
TensorShape: &framework.TensorShapeProto{
161-
Dim: []*framework.TensorShapeProto_Dim{
162-
&framework.TensorShapeProto_Dim{
163-
Size: 28,
164-
},
165-
},
166-
},
167-
Int64Val: []int64{
168-
5, 6, 17, 21, 35, 40, 53, 63, 71, 73, 74, 76, 80, 83,
169-
5, 7, 17, 22, 36, 40, 51, 63, 67, 73, 74, 76, 81, 83,
170-
},
171-
},
172-
"values": &framework.TensorProto{
173-
Dtype: framework.DataType_DT_FLOAT,
174-
TensorShape: &framework.TensorShapeProto{
175-
Dim: []*framework.TensorShapeProto_Dim{
176-
&framework.TensorShapeProto_Dim{
177-
Size: 28,
178-
},
179-
},
180-
},
181-
FloatVal: []float32{
182-
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
183-
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
184-
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
185-
},
186-
},
187-
"shape": &framework.TensorProto{
188-
Dtype: framework.DataType_DT_INT64,
189-
TensorShape: &framework.TensorShapeProto{
190-
Dim: []*framework.TensorShapeProto_Dim{
191-
&framework.TensorShapeProto_Dim{
192-
Size: 2,
193-
},
194-
},
195-
},
196-
Int64Val: []int64{
197-
2, 124,
198-
},
199-
},
200-
},
201-
}
87+
pr := newPredictRequest(*modelName, *modelVersion)
88+
addInput(pr, "keys", framework.DataType_DT_INT32, []int32{1, 2}, nil, nil)
89+
addInput(pr, "indexs", framework.DataType_DT_INT64, []int64{
90+
0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5,
91+
0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11,
92+
0, 12, 0, 13, 1, 0, 1, 1, 1, 2, 1, 3,
93+
1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9,
94+
1, 10, 1, 11, 1, 12, 1, 13,
95+
}, []int64{28, 2}, nil)
96+
addInput(pr, "ids", framework.DataType_DT_INT64, []int64{
97+
5, 6, 17, 21, 35, 40, 53, 63, 71, 73, 74, 76, 80, 83,
98+
5, 7, 17, 22, 36, 40, 51, 63, 67, 73, 74, 76, 81, 83,
99+
}, nil, nil)
100+
addInput(pr, "values", framework.DataType_DT_FLOAT, []float32{
101+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
102+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
103+
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
104+
}, nil, nil)
105+
addInput(pr, "shape", framework.DataType_DT_INT64, []int64{2, 124}, nil, nil)
106+
return pr
202107
}

golang_predict_client/src/util.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package main
2+
3+
import (
4+
"errors"
5+
"reflect"
6+
framework "tensorflow/core/framework"
7+
pb "tensorflow_serving"
8+
9+
google_protobuf "github.com/golang/protobuf/ptypes/wrappers"
10+
)
11+
12+
func newPredictRequest(modelName string, modelVersion int64) (pr *pb.PredictRequest) {
13+
return &pb.PredictRequest{
14+
ModelSpec: &pb.ModelSpec{
15+
Name: modelName,
16+
Version: &google_protobuf.Int64Value{
17+
Value: modelVersion,
18+
},
19+
},
20+
Inputs: make(map[string]*framework.TensorProto),
21+
}
22+
}
23+
24+
// if tensor is one dim, shapeSize is nil
25+
func addInput(pr *pb.PredictRequest, tensorName string, dataType framework.DataType, tensor interface{},
26+
shapeSize []int64, shapeName []string) (err error) {
27+
v := reflect.ValueOf(tensor)
28+
if v.Kind() != reflect.Slice {
29+
return errors.New("tensor must be slice")
30+
}
31+
size := v.Len()
32+
tp := &framework.TensorProto{
33+
Dtype: dataType,
34+
}
35+
36+
var ok bool
37+
switch dataType {
38+
case framework.DataType_DT_HALF:
39+
tp.HalfVal, ok = tensor.([]int32)
40+
case framework.DataType_DT_FLOAT:
41+
tp.FloatVal, ok = tensor.([]float32)
42+
case framework.DataType_DT_DOUBLE:
43+
tp.DoubleVal, ok = tensor.([]float64)
44+
case framework.DataType_DT_INT16, framework.DataType_DT_INT32,
45+
framework.DataType_DT_INT8, framework.DataType_DT_UINT8:
46+
tp.IntVal, ok = tensor.([]int32)
47+
case framework.DataType_DT_STRING:
48+
tp.StringVal, ok = tensor.([][]byte)
49+
case framework.DataType_DT_COMPLEX64:
50+
tp.ScomplexVal, ok = tensor.([]float32)
51+
case framework.DataType_DT_INT64:
52+
tp.Int64Val, ok = tensor.([]int64)
53+
case framework.DataType_DT_BOOL:
54+
tp.BoolVal, ok = tensor.([]bool)
55+
case framework.DataType_DT_COMPLEX128:
56+
tp.DcomplexVal, ok = tensor.([]float64)
57+
case framework.DataType_DT_RESOURCE:
58+
tp.ResourceHandleVal, ok = tensor.([]*framework.ResourceHandle)
59+
default:
60+
err = errors.New("Unknown data type")
61+
}
62+
63+
if !ok {
64+
if err != nil {
65+
err = errors.New("Type switch failed")
66+
}
67+
return
68+
}
69+
70+
if shapeSize == nil {
71+
name := ""
72+
if len(shapeName) != 0 {
73+
name = shapeName[0]
74+
}
75+
tp.TensorShape = &framework.TensorShapeProto{
76+
Dim: []*framework.TensorShapeProto_Dim{
77+
&framework.TensorShapeProto_Dim{
78+
Size: int64(size),
79+
Name: name,
80+
},
81+
},
82+
}
83+
} else {
84+
if shapeName != nil && len(shapeName) != len(shapeSize) {
85+
return errors.New("shapeName and shapeSize have different size")
86+
}
87+
tp.TensorShape = &framework.TensorShapeProto{
88+
Dim: []*framework.TensorShapeProto_Dim{},
89+
}
90+
for i, size := range shapeSize {
91+
name := ""
92+
if shapeName != nil {
93+
name = shapeName[i]
94+
}
95+
tp.TensorShape.Dim = append(tp.TensorShape.Dim, &framework.TensorShapeProto_Dim{
96+
Size: size,
97+
Name: name,
98+
})
99+
}
100+
}
101+
pr.Inputs[tensorName] = tp
102+
return
103+
}

0 commit comments

Comments
 (0)