lstm.py source code

python

Project: tensorflow-litterbox  Author: rwightman
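import tensorflow as tf  # TensorFlow 1.x API (tf.contrib was removed in 2.x)

# Module-level default defined elsewhere in the original lstm.py; its value is
# not part of this excerpt, so an empty placeholder is used here.
_default_initializer_params = {}


def lstm(
        inputs,
        num_units,
        num_layers=1,
        initializer_fn=tf.truncated_normal,
        initializer_params=_default_initializer_params,
        dtype=tf.float32,
        scope=None
):
    """Run a (stacked) LSTM over `inputs` of shape [batch, time, features].

    The initial state is sampled with `initializer_fn` instead of zeros.
    Returns a tensor of shape [batch, time, num_units].
    """
    print('input shape', inputs.get_shape())
    # Dynamic batch size, so inputs with an unknown (None) batch dim also work.
    batch_size = tf.shape(inputs)[0]
    # static_rnn expects a Python list of [batch, features] tensors, one per step.
    inputs_unpacked = tf.unstack(inputs, axis=1)

    def make_cell():
        return tf.contrib.rnn.LSTMBlockCell(num_units=num_units)

    if num_layers > 1:
        # One new cell per layer; [cell] * num_layers would reuse a single object.
        cell = tf.nn.rnn_cell.MultiRNNCell([make_cell() for _ in range(num_layers)])
    else:
        cell = make_cell()
    print('cell state size', cell.state_size)

    # Copy, so the shared module-level default dict is never mutated.
    initializer_params = dict(initializer_params or {})
    initializer_params['dtype'] = dtype

    def make_state(size):
        # state_size is an int, an LSTMStateTuple(c, h), or a tuple of those
        # (one per layer); build sampled tensors with the same nesting.
        if isinstance(size, tuple):
            parts = [make_state(s) for s in size]
            return type(size)(*parts) if hasattr(size, '_fields') else tuple(parts)
        return initializer_fn(shape=[batch_size, size], **initializer_params)

    initial_state = make_state(cell.state_size)

    # static_rnn (the successor of tf.nn.rnn) returns (outputs, final_state).
    outputs, _ = tf.nn.static_rnn(
        cell,
        inputs_unpacked,
        initial_state=initial_state,
        dtype=dtype,
        scope=scope)

    # Re-stack the per-step [batch, num_units] outputs to [batch, time, num_units].
    outputs = tf.stack(outputs, axis=1)
    print('output shape', outputs.get_shape())

    return outputs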
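For readers without a TensorFlow 1.x install, the computation this graph builds can be sketched in plain Python. This is a minimal reference sketch under stated assumptions, not the author's code. `lstm_reference`, `lstm_step`, `make_layer`, `truncated_normal` and `LSTMState` are hypothetical names. The weights are random placeholders (uniform in ±0.1) rather than trained variables. The gate equations are the standard LSTM without peepholes, with a forget-gate bias of 1.0 that is assumed to mirror `LSTMBlockCell`'s default. It follows the same steps: split each sequence along the time axis, feed every step through each layer in turn starting from a truncated-normal initial `(c, h)` state, then collect the top layer's outputs as `[batch][time][num_units]`.

```python
import math
import random
from collections import namedtuple

# Hypothetical stand-in for the (c, h) pair LSTMBlockCell uses as its state.
LSTMState = namedtuple("LSTMState", ["c", "h"])


def truncated_normal(rng, mean=0.0, stddev=1.0):
    """Rejection sampling: redraw values more than 2 stddev from the mean."""
    while True:
        x = rng.gauss(mean, stddev)
        if abs(x - mean) <= 2.0 * stddev:
            return x


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def make_layer(rng, input_size, num_units):
    """Hypothetical random weights: 4 * num_units rows (gates i, j, f, o) over [x, h]."""
    fan_in = input_size + num_units
    w = [[rng.uniform(-0.1, 0.1) for _ in range(fan_in)] for _ in range(4 * num_units)]
    b = [0.0] * (4 * num_units)
    return w, b


def lstm_step(layer, x, state, forget_bias=1.0):
    """One standard LSTM step (no peepholes) for a single example."""
    w, b = layer
    n = len(state.c)
    xh = list(x) + list(state.h)
    z = [sum(wi * v for wi, v in zip(row, xh)) + bi for row, bi in zip(w, b)]
    i, j, f, o = (z[k * n:(k + 1) * n] for k in range(4))
    c = [sigmoid(f[k] + forget_bias) * state.c[k] + sigmoid(i[k]) * math.tanh(j[k])
         for k in range(n)]
    h = [sigmoid(o[k]) * math.tanh(c[k]) for k in range(n)]
    return h, LSTMState(c, h)


def lstm_reference(inputs, num_units, num_layers=1, stddev=1.0, seed=0):
    """inputs: [batch][time][features] nested lists -> [batch][time][num_units]."""
    rng = random.Random(seed)
    input_size = len(inputs[0][0])
    layers = [make_layer(rng, input_size if idx == 0 else num_units, num_units)
              for idx in range(num_layers)]
    outputs = []
    for sequence in inputs:  # each batch example is processed independently
        # Noisy initial state: one sampled (c, h) pair per layer.
        states = [LSTMState([truncated_normal(rng, 0.0, stddev) for _ in range(num_units)],
                            [truncated_normal(rng, 0.0, stddev) for _ in range(num_units)])
                  for _ in range(num_layers)]
        seq_out = []
        for x in sequence:  # like tf.unstack(inputs, axis=1): one step at a time
            for idx, layer in enumerate(layers):
                x, states[idx] = lstm_step(layer, x, states[idx])
            seq_out.append(x)  # top layer's output for this step
        outputs.append(seq_out)  # like tf.stack(outputs, axis=1)
    return outputs
```

For example, `lstm_reference([[[0.1, 0.2, 0.3]] * 5], num_units=4, num_layers=2)` returns one sequence of five 4-element output vectors. Drawing the initial state from a truncated normal, which re-picks any value more than two standard deviations from the mean as `tf.truncated_normal` does, keeps the injected state noise bounded.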