def run(server='localhost:50051', num_steps=10):
    """Query a served PTB language model over gRPC, one word per request.

    Loads the PTB data from ``FLAGS.data_path`` and sends the first
    ``num_steps`` entries of the test split to the model, one request each.
    Every response output other than ``"logits"`` is treated as recurrent
    state and fed back as an input on the next request. The logits from the
    final response are printed. Presumably these are the model's scores for
    the word that follows the fed sequence; confirm against the exported
    signature.

    Args:
        server: ``host:port`` of the gRPC serving endpoint.
        num_steps: Number of test-set entries to feed, one per request.

    Raises:
        IndexError: If the test split has fewer than ``num_steps`` entries.
    """
    client = serving_grpc_client.GRPCClient(server)
    # Load the PTB dataset; only the test split is used below.
    print("Loading ptb data...")
    _, _, test_data, _ = reader.ptb_raw_data(FLAGS.data_path)

    # Feed the first `num_steps` test entries one request at a time, carrying
    # the model state returned by each request into the next one.
    state = {}
    logits = None
    for i in range(num_steps):
        # A single entry sent as a 1x1 tensor (presumably a word id; TODO
        # confirm against the reader's output).
        inputs = {
            'input': tf.contrib.util.make_tensor_proto(test_data[i], shape=[1, 1])
        }
        # Feed back the state from the previous response. On the first
        # request `state` is still empty, so only 'input' is sent.
        for key in state:
            inputs[key] = tf.contrib.util.make_tensor_proto(state[key])
        outputs = client.call_predict(inputs)
        # Keep the logits; every other output becomes state for the next step.
        for key in outputs:
            if key == "logits":
                logits = tf.contrib.util.make_ndarray(outputs[key])
            else:
                state[key] = tf.contrib.util.make_ndarray(outputs[key])
    print('logits: {0}'.format(logits))
# Web-page residue from the original article (not code):
# 评论列表 (Comment list)
# 文章目录 (Table of contents)