def encode(sess, memory, encoder, values, keys, full_batch_host, keys_host,
           batch_size, drop_remainder=True):
    """Run `encoder` over `full_batch_host` in mini-batches and sum the results.

    Each mini-batch feeds the rows ``[begin:end]`` of `keys_host` into the
    `keys` feed and of `full_batch_host` into the `values` feed, evaluates
    `encoder` via ``sess.run``, and adds the result into one accumulator.

    Args:
        sess: object exposing ``run(fetch, feed_dict=...)``; presumably a
            TensorFlow session -- confirm against the caller.
        memory: object with ``num_models`` and ``input_size`` attributes,
            used only to size the accumulator.
        encoder: the fetch passed to ``sess.run``. Its value must be
            addable to a ``(memory.num_models, memory.input_size)`` array.
        values: feed key that receives slices of `full_batch_host`.
        keys: feed key that receives slices of `keys_host`.
        full_batch_host: array-like whose first axis is the sample axis.
        keys_host: array-like sliced along its first axis in lockstep with
            `full_batch_host`.
        batch_size: number of rows per mini-batch; must be positive.
        drop_remainder: if True (the default, and the historical behavior),
            only whole mini-batches are encoded and the trailing
            ``full_batch_size % batch_size`` rows are skipped. If False, the
            final, smaller batch is also fed. Only use False if the graph
            accepts a batch dimension smaller than `batch_size`.

    Returns:
        numpy array of shape ``(memory.num_models, memory.input_size)``
        holding the sum of the encoder outputs over all mini-batches fed.

    Raises:
        ValueError: if `batch_size` is not positive, or if it exceeds the
            number of rows in `full_batch_host`.
    """
    full_batch_size = full_batch_host.shape[0]
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    if full_batch_size < batch_size:
        raise ValueError("full batch size needs to be >= mini-batch size")

    memories_host = np.zeros([memory.num_models, memory.input_size])
    # A single parenthesized string prints the same text under Python 2 and 3
    # (matching the old `print a, b, c, d` statement's spacing exactly).
    print('full_batch_size =  %s minibatch_size =  %s'
          % (full_batch_size, batch_size))

    if drop_remainder:
        # Stop at the last whole mini-batch so every fed batch has exactly
        # `batch_size` rows (e.g. for a fixed-shape placeholder).
        stop = full_batch_size - full_batch_size % batch_size
    else:
        stop = full_batch_size

    for begin in range(0, stop, batch_size):
        end = min(begin + batch_size, full_batch_size)
        feed_dict = {keys: keys_host[begin:end],
                     values: full_batch_host[begin:end]}
        # encode value with the keys and accumulate into the running sum
        memories_host += sess.run(encoder, feed_dict=feed_dict)
    return memories_host
run_mnist_example.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录