seq_batch.py source code

python

Project: lang2program  Author: kelvinguu
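```python
import tensorflow as tf  # TensorFlow 1.x API (tf.assert_equal accepts `data=`)


class SequenceBatch(object):
    """A batch of sequences: `values` of shape (batch_size, seq_length, ...)
    together with a `mask` of shape (batch_size, seq_length)."""

    def __init__(self, values, mask, name='SequenceBatch'):
        with tf.name_scope(name):
            # Check that the first two dimensions of `values` match `mask`.
            values_shape = tf.shape(values)
            mask_shape = tf.shape(mask)
            values_shape_prefix = tf.slice(values_shape, [0], [2])
            max_rank = max(values.get_shape().ndims, mask.get_shape().ndims)
            # Runtime assertion; on failure it prints both shape vectors.
            assert_op = tf.assert_equal(values_shape_prefix, mask_shape,
                                        data=[values_shape_prefix, mask_shape],
                                        summarize=max_rank,
                                        name='assert_shape_prefix')

            # Pass values and mask through tf.identity inside the control
            # dependency so the assertion runs whenever either is evaluated.
            with tf.control_dependencies([assert_op]):
                self._values = tf.identity(values, name='values')
                self._mask = tf.identity(mask, name='mask')
```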
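The constructor above only adds a runtime assertion to the TensorFlow graph: the first two dimensions of `values` must equal the shape of `mask`. As a minimal sketch of that invariant (not part of lang2program), the NumPy code below mirrors the check in `check_sequence_batch` and uses a hypothetical helper, `pad_sequences`, to build a matching padded `values` array and 0/1 `mask` from variable-length sequences. It assumes every sequence is non-empty and all feature vectors have the same length.

```python
import numpy as np


def check_sequence_batch(values, mask):
    """NumPy analogue of the shape check in SequenceBatch.__init__.

    Raises ValueError if values.shape[:2] differs from mask.shape.
    """
    values = np.asarray(values)
    mask = np.asarray(mask)
    prefix = values.shape[:2]  # plays the role of tf.slice(tf.shape(values), [0], [2])
    if prefix != mask.shape:
        raise ValueError("values shape prefix %s does not match mask shape %s"
                         % (prefix, mask.shape))
    return values, mask


def pad_sequences(seqs, pad_value=0.0):
    """Hypothetical helper: pack variable-length sequences of feature vectors
    into a (batch_size, max_len, dim) array plus a (batch_size, max_len) mask
    holding 1.0 at real positions and 0.0 at padding.

    Assumes every sequence is non-empty and all vectors share one dimension.
    """
    batch_size = len(seqs)
    max_len = max(len(s) for s in seqs)
    dim = len(seqs[0][0])
    values = np.full((batch_size, max_len, dim), pad_value, dtype=np.float32)
    mask = np.zeros((batch_size, max_len), dtype=np.float32)
    for i, seq in enumerate(seqs):
        values[i, :len(seq)] = seq  # copy the real time steps
        mask[i, :len(seq)] = 1.0    # mark them as valid
    return values, mask


# Two sequences of 2-D vectors, of lengths 2 and 1.
values, mask = pad_sequences([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]])
check_sequence_batch(values, mask)  # passes: prefix (2, 2) equals mask shape (2, 2)
# mask is [[1, 1], [1, 0]]; values[1, 1] is padding [0, 0]
```

Passing a mask with the wrong sequence length, such as `mask[:, :1]`, makes `check_sequence_batch` raise `ValueError`. This corresponds to the failure that `assert_shape_prefix` reports at graph run time.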