def _random_weights(self, size=50, num_shards=1):
assert size > 0
assert num_shards > 0
assert num_shards <= size
embedding_weights = partitioned_variables.create_partitioned_variables(
shape=[size],
slicing=[num_shards],
initializer=init_ops.truncated_normal_initializer(
mean=0.0, stddev=1.0, dtype=dtypes.float32))
for w in embedding_weights:
w.initializer.run()
return embedding_weights
embedding_ops_test.py 文件源码
python
阅读 21
收藏 0
点赞 0
评论 0
评论列表
文章目录