lstm_util.py source code

python

Project: tensorlm    Author: batzner
import tensorflow as tf


def get_state_variables_for_batch(state_variables, batch_size):
    """Returns a subset of the state variables.

    This function takes the state variables returned by get_state_variables() and returns a subset
    for an actual forward-propagation run. Specifically, it clips each of the state variables to
    the given batch size.

    Each variable's first dimension has length max_batch_size, but when the input has a smaller
    batch size, the LSTM should only read and update the state variables for the batches that
    are actually used.

    See get_state_variables() for more info.

    Args:
        state_variables (tuple[tf.contrib.rnn.LSTMStateTuple]): The LSTM's state variables.
        batch_size (tf.Tensor): A 0-dimensional tensor containing the batch size in the
            computational graph.

    Returns:
        tuple[tf.contrib.rnn.LSTMStateTuple]: A new tuple of state variables clipped to the given
            batch size.
    """

    # Return a tuple of LSTMStateTuples but with only the first batch_size rows for each variable
    # in the tuples.
    result = []
    for state_c, state_h in state_variables:
        result.append(tf.contrib.rnn.LSTMStateTuple(state_c[:batch_size], state_h[:batch_size]))
    return tuple(result)
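
The function above is typically paired with get_state_variables(), which this page references but does not show. Below is a minimal usage sketch (TensorFlow 1.x): the cell setup, placeholder shapes, and zero-initialized state variables are illustrative assumptions standing in for the project's own helper, not code from tensorlm.

import tensorflow as tf

max_batch_size, num_units, input_size, num_layers = 64, 128, 32, 1
cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(num_units) for _ in range(num_layers)])

# Hypothetical stand-in for get_state_variables(): one non-trainable (c, h)
# variable pair per layer, each with first dimension max_batch_size.
state_variables = tuple(
    tf.contrib.rnn.LSTMStateTuple(
        tf.Variable(tf.zeros([max_batch_size, num_units]), trainable=False),
        tf.Variable(tf.zeros([max_batch_size, num_units]), trainable=False))
    for _ in range(num_layers))

inputs = tf.placeholder(tf.float32, [None, None, input_size])  # [batch, time, features]
batch_size = tf.shape(inputs)[0]  # actual batch size of this run, <= max_batch_size

# Clip the persistent state to the current batch size and feed it as the initial state.
initial_state = get_state_variables_for_batch(state_variables, batch_size)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)

Because the state variables keep their full max_batch_size first dimension, the same variables can be reused across runs with varying batch sizes; only the first batch_size rows are read for any given forward pass.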