test_tf_qrnn_work.py source code

python

Project: tensorflow_qrnn  Author: icoxfog417
def baseline_forward(self, X, size, n_class):
        shape = X.get_shape()
        _X = tf.transpose(X, [1, 0, 2])  # batch_size x sentence_length x word_length -> sentence_length x batch_size x word_length
        _X = tf.reshape(_X, [-1, int(shape[2])])  # (sentence_length x batch_size) x word_length
        seq = tf.split(0, int(shape[1]), _X)  # sentence_length x (batch_size x word_length)

        with tf.name_scope("LSTM"):
            lstm_cell = rnn_cell.BasicLSTMCell(size, forget_bias=1.0)
            outputs, states = rnn.rnn(lstm_cell, seq, dtype=tf.float32)

        with tf.name_scope("LSTM-Classifier"):
            W = tf.Variable(tf.random_normal([size, n_class]), name="W")
            b = tf.Variable(tf.random_normal([n_class]), name="b")
            # Classify from the LSTM output at the last time step only
            output = tf.matmul(outputs[-1], W) + b

        return output
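
The transpose / reshape / split sequence at the top of `baseline_forward` turns a batch-major tensor into a list with one `batch_size x word_length` slice per time step, which is the input format the unrolled RNN call expects. The sketch below reproduces that data movement on plain nested lists so the indexing can be checked by hand. It is only an illustration: `to_time_major_steps` and the toy input are hypothetical names that do not appear in the original project, and nested lists stand in for tensors. In later TensorFlow releases, `tf.unstack(X, axis=1)` should yield an equivalent list, though this has not been checked against the project's own TensorFlow version.

```python
def to_time_major_steps(X):
    """Split a batch-major nested list into one batch slice per time step.

    X is indexed as X[batch][time][feature]; the result is indexed as
    steps[time][batch][feature], mirroring the transpose/reshape/split
    in baseline_forward. (Hypothetical helper, for illustration only.)
    """
    batch_size = len(X)
    sentence_length = len(X[0])
    # Swap the first two axes: (batch, time, feature) -> (time, batch, feature)
    transposed = [[X[b][t] for b in range(batch_size)]
                  for t in range(sentence_length)]
    # Flatten to (time * batch) rows of features, like the reshape to [-1, word_length]
    flat = [row for step in transposed for row in step]
    # Cut the rows into sentence_length equal chunks, one chunk per time step
    return [flat[t * batch_size:(t + 1) * batch_size]
            for t in range(sentence_length)]


# Toy input: batch_size=2, sentence_length=3, word_length=2
X = [
    [[1, 1], [2, 2], [3, 3]],  # sentence 0
    [[4, 4], [5, 5], [6, 6]],  # sentence 1
]
steps = to_time_major_steps(X)
```

Each element of `steps` holds the word vectors of every sentence in the batch at a single position, so `steps[-1]` corresponds to the input that produces `outputs[-1]` in the classifier above.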