test_cosine.py source code

python

Project: tensorflow-phased-lstm  Author: philipperemy
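import tensorflow as tf  # TensorFlow 1.x API (tf.nn.dynamic_rnn)
from tensorflow.contrib.rnn import GRUCell, LSTMCell

# Assumed to be defined elsewhere in the original file: PhasedLSTMCell, and
# FLAGS (tf.app.flags) providing FLAGS.unit and FLAGS.n_hidden.

def RNN(_X, _weights, _biases, lens):
    """Run the selected recurrent cell over _X and project each sequence's
    last valid output through the 'out' weights and biases."""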
    if FLAGS.unit == 'PLSTM':
        cell = PhasedLSTMCell(FLAGS.n_hidden, use_peepholes=True)
    elif FLAGS.unit == 'GRU':
        cell = GRUCell(FLAGS.n_hidden)
    elif FLAGS.unit == 'LSTM':
        cell = LSTMCell(FLAGS.n_hidden, use_peepholes=True)
    else:
        raise ValueError('Unit {} not implemented.'.format(FLAGS.unit))

    outputs, states = tf.nn.dynamic_rnn(cell, _X, dtype=tf.float32, sequence_length=lens)

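    # TODO better (?) in lack of smart indexing.
    # Pick outputs[b, lens[b] - 1, :] for each row b: flatten to
    # [batch_size * max_len, out_size] and gather row b * max_len + lens[b] - 1.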
    batch_size = tf.shape(outputs)[0]
    max_len = tf.shape(outputs)[1]
    out_size = int(outputs.get_shape()[2])
    index = tf.range(0, batch_size) * max_len + (lens - 1)
    flat = tf.reshape(outputs, [-1, out_size])
    relevant = tf.gather(flat, index)

    return tf.nn.bias_add(tf.matmul(relevant, _weights['out']), _biases['out'])
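The flatten-then-gather step at the end of `RNN()` can be checked outside TensorFlow. The sketch below is a NumPy re-implementation for illustration only; `last_relevant` and the toy arrays are hypothetical names, not part of the project. It assumes every length lies in `[1, max_len]`: a zero length would make `lens - 1` negative, so the flat index would point into the previous sequence (or be -1 for the first one), a case the original code does not guard against.

```python
import numpy as np


def last_relevant(outputs, lens):
    """Return outputs[b, lens[b] - 1, :] for every batch row b.

    Mirrors the trick in RNN(): step t of sequence b in a
    [batch, max_len, out_size] array is row b * max_len + t of the
    flattened [batch * max_len, out_size] array.
    """
    batch_size, max_len, out_size = outputs.shape
    lens = np.asarray(lens)
    # Hypothetical guard, not present in the original TensorFlow code.
    if np.any(lens < 1) or np.any(lens > max_len):
        raise ValueError("every sequence length must be in [1, max_len]")
    index = np.arange(batch_size) * max_len + (lens - 1)
    flat = outputs.reshape(-1, out_size)
    return flat[index]


# Toy batch: 2 sequences, max_len 3, 2 features; each value encodes b * 10 + t.
outputs = np.array(
    [[[b * 10 + t, -(b * 10 + t)] for t in range(3)] for b in range(2)]
)
lens = [2, 3]
res = last_relevant(outputs, lens)
```

With `lens = [2, 3]` the flat indices are `[1, 5]`, i.e. step 1 of sequence 0 and step 2 of sequence 1. NumPy can make the same selection directly with `outputs[np.arange(batch_size), lens - 1]`, which is presumably the kind of "smart indexing" the TODO refers to; in TensorFlow a comparable route is `tf.gather_nd` with stacked `(batch, step)` index pairs, though whether that suited the TensorFlow version this project targeted is not stated here.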