|
8 | 8 | framework "tensorflow/core/framework" |
9 | 9 | pb "tensorflow_serving" |
10 | 10 |
|
11 | | - google_protobuf "github.com/golang/protobuf/ptypes/wrappers" |
12 | 11 | "golang.org/x/net/context" |
13 | 12 | "google.golang.org/grpc" |
14 | 13 | "google.golang.org/grpc/credentials" |
@@ -71,132 +70,38 @@ func main() { |
71 | 70 | } |
72 | 71 |
|
73 | 72 | func newDensePredictRequest(modelName *string, modelVersion *int64) *pb.PredictRequest { |
74 | | - return &pb.PredictRequest{ |
75 | | - ModelSpec: &pb.ModelSpec{ |
76 | | - Name: *modelName, |
77 | | - Version: &google_protobuf.Int64Value{ |
78 | | - Value: *modelVersion, |
79 | | - }, |
80 | | - }, |
81 | | - Inputs: map[string]*framework.TensorProto{ |
82 | | - "keys": &framework.TensorProto{ |
83 | | - Dtype: framework.DataType_DT_INT32, |
84 | | - TensorShape: &framework.TensorShapeProto{ |
85 | | - Dim: []*framework.TensorShapeProto_Dim{ |
86 | | - &framework.TensorShapeProto_Dim{ |
87 | | - Size: 3, |
88 | | - }, |
89 | | - }, |
90 | | - }, |
91 | | - IntVal: []int32{1, 2, 3}, |
92 | | - }, |
93 | | - "features": &framework.TensorProto{ |
94 | | - Dtype: framework.DataType_DT_FLOAT, |
95 | | - TensorShape: &framework.TensorShapeProto{ |
96 | | - Dim: []*framework.TensorShapeProto_Dim{ |
97 | | - &framework.TensorShapeProto_Dim{ |
98 | | - Size: 3, |
99 | | - }, |
100 | | - &framework.TensorShapeProto_Dim{ |
101 | | - Size: 9, |
102 | | - }, |
103 | | - }, |
104 | | - }, |
105 | | - FloatVal: []float32{ |
106 | | - 1, 2, 3, 4, 5, 6, 7, 8, 9, |
107 | | - 1, 2, 3, 4, 5, 6, 7, 8, 9, |
108 | | - 1, 2, 3, 4, 5, 6, 7, 8, 9, |
109 | | - }, |
110 | | - }, |
111 | | - }, |
112 | | - } |
| 73 | + pr := newPredictRequest(*modelName, *modelVersion) |
| 74 | + addInput(pr, "keys", framework.DataType_DT_INT32, []int32{1, 2, 3}, nil, nil) |
| 75 | + addInput(pr, "features", framework.DataType_DT_FLOAT, []float32{ |
| 76 | + 1, 2, 3, 4, 5, 6, 7, 8, 9, |
| 77 | + 1, 2, 3, 4, 5, 6, 7, 8, 9, |
| 78 | + 1, 2, 3, 4, 5, 6, 7, 8, 9, |
| 79 | + }, []int64{3, 9}, nil) |
| 80 | + return pr |
113 | 81 | } |
114 | 82 |
|
115 | 83 | // Example data: |
116 | 84 | // 0 5:1 6:1 17:1 21:1 35:1 40:1 53:1 63:1 71:1 73:1 74:1 76:1 80:1 83:1 |
117 | 85 | // 1 5:1 7:1 17:1 22:1 36:1 40:1 51:1 63:1 67:1 73:1 74:1 76:1 81:1 83:1 |
118 | 86 | func newSparsePredictRequest(modelName *string, modelVersion *int64) *pb.PredictRequest { |
119 | | - return &pb.PredictRequest{ |
120 | | - ModelSpec: &pb.ModelSpec{ |
121 | | - Name: *modelName, |
122 | | - Version: &google_protobuf.Int64Value{ |
123 | | - Value: *modelVersion, |
124 | | - }, |
125 | | - }, |
126 | | - Inputs: map[string]*framework.TensorProto{ |
127 | | - "keys": &framework.TensorProto{ |
128 | | - Dtype: framework.DataType_DT_INT32, |
129 | | - TensorShape: &framework.TensorShapeProto{ |
130 | | - Dim: []*framework.TensorShapeProto_Dim{ |
131 | | - &framework.TensorShapeProto_Dim{ |
132 | | - Size: 2, |
133 | | - }, |
134 | | - }, |
135 | | - }, |
136 | | - IntVal: []int32{1, 2}, |
137 | | - }, |
138 | | - "indexs": &framework.TensorProto{ |
139 | | - Dtype: framework.DataType_DT_INT64, |
140 | | - TensorShape: &framework.TensorShapeProto{ |
141 | | - Dim: []*framework.TensorShapeProto_Dim{ |
142 | | - &framework.TensorShapeProto_Dim{ |
143 | | - Size: 28, |
144 | | - }, |
145 | | - &framework.TensorShapeProto_Dim{ |
146 | | - Size: 2, |
147 | | - }, |
148 | | - }, |
149 | | - }, |
150 | | - Int64Val: []int64{ |
151 | | - 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, |
152 | | - 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, |
153 | | - 0, 12, 0, 13, 1, 0, 1, 1, 1, 2, 1, 3, |
154 | | - 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9, |
155 | | - 1, 10, 1, 11, 1, 12, 1, 13, |
156 | | - }, |
157 | | - }, |
158 | | - "ids": &framework.TensorProto{ |
159 | | - Dtype: framework.DataType_DT_INT64, |
160 | | - TensorShape: &framework.TensorShapeProto{ |
161 | | - Dim: []*framework.TensorShapeProto_Dim{ |
162 | | - &framework.TensorShapeProto_Dim{ |
163 | | - Size: 28, |
164 | | - }, |
165 | | - }, |
166 | | - }, |
167 | | - Int64Val: []int64{ |
168 | | - 5, 6, 17, 21, 35, 40, 53, 63, 71, 73, 74, 76, 80, 83, |
169 | | - 5, 7, 17, 22, 36, 40, 51, 63, 67, 73, 74, 76, 81, 83, |
170 | | - }, |
171 | | - }, |
172 | | - "values": &framework.TensorProto{ |
173 | | - Dtype: framework.DataType_DT_FLOAT, |
174 | | - TensorShape: &framework.TensorShapeProto{ |
175 | | - Dim: []*framework.TensorShapeProto_Dim{ |
176 | | - &framework.TensorShapeProto_Dim{ |
177 | | - Size: 28, |
178 | | - }, |
179 | | - }, |
180 | | - }, |
181 | | - FloatVal: []float32{ |
182 | | - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, |
183 | | - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, |
184 | | - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, |
185 | | - }, |
186 | | - }, |
187 | | - "shape": &framework.TensorProto{ |
188 | | - Dtype: framework.DataType_DT_INT64, |
189 | | - TensorShape: &framework.TensorShapeProto{ |
190 | | - Dim: []*framework.TensorShapeProto_Dim{ |
191 | | - &framework.TensorShapeProto_Dim{ |
192 | | - Size: 2, |
193 | | - }, |
194 | | - }, |
195 | | - }, |
196 | | - Int64Val: []int64{ |
197 | | - 2, 124, |
198 | | - }, |
199 | | - }, |
200 | | - }, |
201 | | - } |
| 87 | + pr := newPredictRequest(*modelName, *modelVersion) |
| 88 | + addInput(pr, "keys", framework.DataType_DT_INT32, []int32{1, 2}, nil, nil) |
| 89 | + addInput(pr, "indexs", framework.DataType_DT_INT64, []int64{ |
| 90 | + 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, |
| 91 | + 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, |
| 92 | + 0, 12, 0, 13, 1, 0, 1, 1, 1, 2, 1, 3, |
| 93 | + 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9, |
| 94 | + 1, 10, 1, 11, 1, 12, 1, 13, |
| 95 | + }, []int64{28, 2}, nil) |
| 96 | + addInput(pr, "ids", framework.DataType_DT_INT64, []int64{ |
| 97 | + 5, 6, 17, 21, 35, 40, 53, 63, 71, 73, 74, 76, 80, 83, |
| 98 | + 5, 7, 17, 22, 36, 40, 51, 63, 67, 73, 74, 76, 81, 83, |
| 99 | + }, nil, nil) |
| 100 | + addInput(pr, "values", framework.DataType_DT_FLOAT, []float32{ |
| 101 | + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, |
| 102 | + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, |
| 103 | + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, |
| 104 | + }, nil, nil) |
| 105 | + addInput(pr, "shape", framework.DataType_DT_INT64, []int64{2, 124}, nil, nil) |
| 106 | + return pr |
202 | 107 | } |
0 commit comments