models.py source file

python

Project: feudal_networks  Author: dmakian
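import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.nn.dynamic_rnn)
from tensorflow.contrib import rnn


def build_lstm(x, size, name, step_size):
    """Single-layer LSTM whose (c, h) state is fed in and read out explicitly.

    x: float32 tensor of shape [1, time, features]; the batch size is fixed at 1.
    size: number of LSTM units.
    name: accepted but not used in this function.
    step_size: int32 tensor of shape [1], the sequence length of the one batch entry.
    Returns (lstm_outputs, state_init, state_in, state_out).
    """
    lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)

    # Zero state (NumPy arrays) to feed at the start of an episode.
    c_init = np.zeros((1, lstm.state_size.c), np.float32)
    h_init = np.zeros((1, lstm.state_size.h), np.float32)
    state_init = [c_init, h_init]

    # Placeholders for the previous chunk's final state (or for state_init).
    c_in = tf.placeholder(tf.float32, shape=[1, lstm.state_size.c],
                          name='c_in')
    h_in = tf.placeholder(tf.float32, shape=[1, lstm.state_size.h],
                          name='h_in')
    # LSTMStateTuple is a namedtuple, so callers can still unpack it as (c, h).
    state_in = rnn.LSTMStateTuple(c_in, h_in)

    lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
        lstm, x, initial_state=state_in, sequence_length=step_size,
        time_major=False)
    # Flatten [1, time, size] to [time, size].
    lstm_outputs = tf.reshape(lstm_outputs, [-1, size])

    # Final state of batch entry 0; feed it back into c_in / h_in on the next run.
    lstm_c, lstm_h = lstm_state
    state_out = [lstm_c[:1, :], lstm_h[:1, :]]
    return lstm_outputs, state_init, state_in, state_out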
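The point of returning state_init, state_in and state_out is that a long rollout can be processed in chunks: the final (c, h) of one chunk is fed back in as the initial state of the next, and the result matches a single pass over the whole sequence. The following is a minimal NumPy sketch of that idea, not code from the feudal_networks project. The helper names lstm_step and run_lstm are hypothetical; the gate layout (i, j, f, o) and forget_bias=1.0 are assumed to mirror the TensorFlow 1.x BasicLSTMCell defaults, and the weights are random stand-ins rather than trained values.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, c, h, kernel, bias, forget_bias=1.0):
    """One LSTM step for a batch of 1 (hypothetical helper).

    Assumes the gates are packed as (i, j, f, o) along the last axis,
    as in TF 1.x BasicLSTMCell.
    """
    z = np.concatenate([x_t, h], axis=1) @ kernel + bias
    i, j, f, o = np.split(z, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_c, new_h


def run_lstm(x, state, kernel, bias):
    """Unroll over x of shape [1, T, features], like dynamic_rnn with batch 1.

    Returns outputs flattened to [T, size] and the final [c, h] state.
    """
    c, h = state
    outputs = []
    for t in range(x.shape[1]):
        c, h = lstm_step(x[:, t, :], c, h, kernel, bias)
        outputs.append(h)
    return np.concatenate(outputs, axis=0), [c, h]


rng = np.random.default_rng(0)
features, size, T = 3, 4, 6
kernel = rng.normal(scale=0.1, size=(features + size, 4 * size))  # random stand-in weights
bias = np.zeros(4 * size)
x = rng.normal(size=(1, T, features))

# Zero state, analogous to state_init in build_lstm.
state_init = [np.zeros((1, size)), np.zeros((1, size))]

# One pass over the whole sequence...
full_out, full_state = run_lstm(x, state_init, kernel, bias)

# ...versus two chunks, feeding the first chunk's final state back in,
# the way state_out is fed to the c_in / h_in placeholders.
out_a, state_a = run_lstm(x[:, :2, :], state_init, kernel, bias)
out_b, state_b = run_lstm(x[:, 2:, :], state_a, kernel, bias)

print(np.allclose(full_out, np.concatenate([out_a, out_b])))  # True
```

Because each step performs the same arithmetic on the same carried state, splitting the sequence changes nothing; resetting the state to zeros between chunks instead would make the two results diverge.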