def get_mask(self, max_frames, num_frames):
    """Return a per-frame 0/1 mask for each frame count in ``num_frames``.

    Creates (or, under a reusing variable scope, retrieves) a non-trainable
    variable ``mask_emb`` of shape ``[max_frames + 1, max_frames]`` whose row
    ``i`` holds 1.0 in its first ``i`` columns and 0.0 elsewhere, then gathers
    one row per entry of ``num_frames``.

    Args:
      max_frames: Python int, number of mask columns (the padded length).
      num_frames: integer tensor of frame counts. Each value is used as a row
        index into ``mask_emb``, so it should lie in ``[0, max_frames]``.

    Returns:
      A float32 tensor of shape ``num_frames.shape + [max_frames]``.
    """
    # Entry (i, j) is 1.0 iff j < i, i.e. a strictly-lower-triangular pattern
    # on a (max_frames + 1) x max_frames grid. Same values as the former
    # nested loop, but vectorized and free of the Python-2-only ``xrange``.
    mask_array = np.tril(
        np.ones((max_frames + 1, max_frames), dtype=np.float32), k=-1)
    mask_init = tf.constant_initializer(mask_array)
    # NOTE: tf.get_variable raises if "mask_emb" already exists in the current
    # variable scope and reuse is not enabled, so a second call in the same
    # scope needs reuse=True (or a distinct scope).
    mask_emb = tf.get_variable("mask_emb", shape=[max_frames + 1, max_frames],
                               dtype=tf.float32, trainable=False,
                               initializer=mask_init)
    mask = tf.nn.embedding_lookup(mask_emb, num_frames)
    return mask
distillchain_cnn_deep_combine_chain_model.py — source file
python
Views: 26
Favorites: 0
Likes: 0
Comments: 0
Comment list
Table of contents