seq_batch.py source code

python

Project: lang2program  Author: kelvinguu
    def __init__(self, align='left', seq_length=None, dtype=tf.int32, name='FeedSequenceBatch'):
        """Create a Feedable SequenceBatch.

        Args:
            align (str): can be 'left' or 'right'. If 'left', values will be left-aligned, with padding on the right.
                If 'right', values will be right-aligned, with padding on the left. Default is 'left'.
            seq_length (int): the Tensor representing the SequenceBatch will have exactly this many columns. Default
                is None. If None, seq_length will be dynamically determined.
            dtype: data type of the SequenceBatch values array. Defaults to int32.
            name (str): namescope for the Tensors created inside this Model.
        """
        if align not in ('left', 'right'):
            raise ValueError("align must be either 'left' or 'right'.")
        self._align_right = (align == 'right')
        self._seq_length = seq_length

        with tf.name_scope(name):
            values = tf.placeholder(dtype, shape=[None, None], name='values')  # (batch_size, seq_length)
            mask = tf.placeholder(tf.float32, shape=[None, None], name='mask')  # (batch_size, seq_length)

        if self._seq_length is not None:
            # add static shape information
            batch_dim, _ = values.get_shape()
            new_shape = tf.TensorShape([batch_dim, tf.Dimension(seq_length)])
            values.set_shape(new_shape)
            mask.set_shape(new_shape)

        super(FeedSequenceBatch, self).__init__(values, mask)
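
The snippet above only creates the `values` and `mask` placeholders; the code that fills them is not shown. As a rough illustration of what the `align` and `seq_length` arguments imply for the fed arrays, here is a minimal pure-Python sketch. The helper name `pad_batch` and its `pad_value` parameter are hypothetical and do not come from lang2program; the sketch assumes a mask of 1.0 for real entries and 0.0 for padding.

```python
def pad_batch(sequences, align='left', seq_length=None, pad_value=0):
    """Hypothetical helper: pad variable-length sequences into (values, mask).

    Not part of lang2program; it only mirrors the alignment rules described
    in the docstring above.
    """
    if align not in ('left', 'right'):
        raise ValueError("align must be either 'left' or 'right'.")
    if seq_length is None:
        # No fixed width requested: use the longest sequence in the batch.
        seq_length = max((len(s) for s in sequences), default=0)

    values, mask = [], []
    for seq in sequences:
        seq = list(seq)
        if len(seq) > seq_length:
            raise ValueError("sequence is longer than seq_length.")
        pad = seq_length - len(seq)
        if align == 'right':
            # Right-aligned: padding goes on the left.
            values.append([pad_value] * pad + seq)
            mask.append([0.0] * pad + [1.0] * len(seq))
        else:
            # Left-aligned: padding goes on the right.
            values.append(seq + [pad_value] * pad)
            mask.append([1.0] * len(seq) + [0.0] * pad)
    return values, mask


left_values, left_mask = pad_batch([[1, 2, 3], [4]], align='left')
right_values, right_mask = pad_batch([[1, 2, 3], [4]], align='right')
fixed_values, fixed_mask = pad_batch([[7]], align='left', seq_length=3)
```

Both returned lists have shape (batch_size, seq_length), matching the two placeholders, so in principle they could be passed through a feed dict once converted to arrays.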