def create(predict_fn, word_representations,
           batch_size, window_size, vocabulary_size,
           result_callback):
    """Construct the batcher variant requested by ``result_callback``.

    When ``result_callback.should_average_input()`` is truthy an
    ``EmbeddingMapper`` is built from ``predict_fn``,
    ``word_representations`` and ``result_callback``; otherwise a
    ``WordBatcher`` is built from ``predict_fn``, ``batch_size``,
    ``window_size``, the computed instance dtype and ``result_callback``.

    Args:
        predict_fn: Forwarded unchanged to whichever batcher is built.
        word_representations: Forwarded to ``EmbeddingMapper`` only.
        batch_size: Forwarded to ``WordBatcher`` only.
        window_size: Forwarded to ``WordBatcher`` only.
        vocabulary_size: Used to choose the instance dtype; the dtype is
            the smallest NumPy scalar type that can hold
            ``vocabulary_size - 1``.
        result_callback: Must not be ``None``. Its
            ``should_average_input()`` method selects the batcher type,
            and it is forwarded to the batcher that is built.

    Returns:
        The newly constructed ``EmbeddingMapper`` or ``WordBatcher``.

    Raises:
        AssertionError: If ``result_callback`` is ``None`` (only while
            assertions are enabled).
    """
    assert result_callback is not None

    # The dtype is computed and logged up front, regardless of which
    # batcher ends up being built.
    # NOTE(review): vocabulary_size - 1 is presumably the largest word
    # index an instance can contain -- confirm against the callers.
    largest_value = vocabulary_size - 1
    instance_dtype = np.min_scalar_type(largest_value)
    logging.info('Instance elements will be stored using %s.', instance_dtype)

    if not result_callback.should_average_input():
        return WordBatcher(
            predict_fn,
            batch_size, window_size,
            instance_dtype,
            result_callback)

    return EmbeddingMapper(
        predict_fn,
        word_representations,
        result_callback)
评论列表
文章目录