gated_rnn_test.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:sonnet 作者: deepmind 项目源码 文件源码
def testPeephole(self):
    """Checks one step of a peephole `snt.LSTM` against a NumPy re-computation.

    The test builds an LSTM with `use_peepholes=True`, verifies that it
    exposes five variables with the expected shapes, then feeds random data
    through the graph and compares the resulting hidden and cell states with
    the same step computed in NumPy from the fetched parameter values.
    """
    batch_size = 5
    hidden_size = 20
    state_shape = [batch_size, hidden_size]

    # Build the graph for a single LSTM step and count its variables.
    inputs = tf.placeholder(tf.float32, shape=state_shape)
    prev_cell = tf.placeholder(tf.float32, shape=state_shape)
    prev_hidden = tf.placeholder(tf.float32, shape=state_shape)
    lstm = snt.LSTM(hidden_size, use_peepholes=True)
    unused_output, next_state = lstm(inputs, (prev_hidden, prev_cell))
    next_hidden, next_cell = next_state
    variables = lstm.get_variables()
    self.assertEqual(len(variables), 5, "LSTM should have 5 variables")

    # Index the variables by their bare name: drop the scope prefix before the
    # last "/" and the ":<n>" output suffix.
    by_name = {}
    for variable in variables:
        unscoped = variable.name.rsplit("/", 1)[-1]
        by_name[unscoped.partition(":")[0]] = variable

    # Each variable's initial value must have the expected shape.
    expected_shapes = (
        (snt.LSTM.B_GATES, 4 * hidden_size),
        (snt.LSTM.W_GATES, (2 * hidden_size, 4 * hidden_size)),
        (snt.LSTM.W_F_DIAG, hidden_size),
        (snt.LSTM.W_I_DIAG, hidden_size),
        (snt.LSTM.W_O_DIAG, hidden_size),
    )
    for key, shape in expected_shapes:
        self.assertShapeEqual(np.ndarray(shape), by_name[key].initial_value)

    # Random inputs used for both the TF run and the NumPy reference.
    input_data = np.random.randn(batch_size, hidden_size)
    prev_hidden_data = np.random.randn(batch_size, hidden_size)
    prev_cell_data = np.random.randn(batch_size, hidden_size)

    # Fetch the next state together with the parameter values that produced it.
    fetched_params = (
        snt.LSTM.W_GATES,
        snt.LSTM.B_GATES,
        snt.LSTM.W_F_DIAG,
        snt.LSTM.W_I_DIAG,
        snt.LSTM.W_O_DIAG,
    )
    fetches = [(next_hidden, next_cell)]
    fetches.extend(by_name[key] for key in fetched_params)
    feed = {
        inputs: input_data,
        prev_cell: prev_cell_data,
        prev_hidden: prev_hidden_data,
    }
    with self.test_session() as session:
        tf.global_variables_initializer().run()
        output = session.run(fetches, feed)

    state_tf, weights, biases, w_f_diag, w_i_diag, w_o_diag = output

    # NumPy reference: one affine map over [input, prev_hidden], split into
    # the four gate pre-activations.
    stacked = np.concatenate((input_data, prev_hidden_data), axis=1)
    pre_activations = np.dot(stacked, weights) + biases
    # i = input_gate, j = next_input, f = forget_gate, o = output_gate
    i, j, f, o = np.hsplit(pre_activations, 4)

    # The forget and input gates peek at the previous cell state.
    forget_logits = f + lstm._forget_bias + w_f_diag * prev_cell_data
    input_logits = i + w_i_diag * prev_cell_data
    kept_cell = prev_cell_data / (1 + np.exp(-forget_logits))
    input_gate = 1 / (1 + np.exp(-input_logits))
    real_cell = kept_cell + input_gate * np.tanh(j)

    # The output peephole is applied to the new cell state before the tanh.
    peeked_cell = real_cell + w_o_diag * real_cell
    real_hidden = np.tanh(peeked_cell) * 1 / (1 + np.exp(-o))

    self.assertAllClose(real_hidden, state_tf[0])
    self.assertAllClose(real_cell, state_tf[1])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号