def baseline_forward(self, X, size, n_class):
    """Build an unrolled LSTM over X and project its last output to class scores.

    The graph ops are created in this order: transpose/reshape/split of the
    input, a single ``BasicLSTMCell`` unrolled with ``rnn.rnn``, then a
    linear layer (``W``, ``b``) applied to the output of the final step.
    ``self`` is not read by this method.

    Args:
        X: Rank-3 tensor. The original author's annotation describes it as
            batch_size x sentence_length x word_length. Dimensions 1 and 2
            are converted with ``int(...)`` while building the graph.
        size: Number of units passed to ``BasicLSTMCell``; also the input
            width of the classifier weight matrix.
        n_class: Number of output columns of the classifier.

    Returns:
        ``matmul(last_step_output, W) + b``. No softmax is applied here.
    """
    in_shape = X.get_shape()

    # Swap the first two axes:
    # batch_size x sentence_length x word_length
    #   -> sentence_length x batch_size x word_length
    time_major = tf.transpose(X, [1, 0, 2])

    # Collapse the two leading axes:
    # (sentence_length * batch_size) x word_length
    word_len = int(in_shape[2])
    flat = tf.reshape(time_major, [-1, word_len])

    # Cut along axis 0 into sentence_length pieces, one per time step, each
    # batch_size x word_length. The (axis, num_splits, value) argument order
    # is the legacy tf.split signature -- presumably pre-1.0 TensorFlow.
    n_steps = int(in_shape[1])
    steps = tf.split(0, n_steps, flat)

    # NOTE(review): the two name scopes are written as siblings here; confirm
    # that "LSTM-Classifier" was not meant to be nested inside "LSTM".
    with tf.name_scope("LSTM"):
        cell = rnn_cell.BasicLSTMCell(size, forget_bias=1.0)
        # The second return value of rnn.rnn is not used by the classifier.
        step_outputs, _unused_state = rnn.rnn(cell, steps, dtype=tf.float32)

    with tf.name_scope("LSTM-Classifier"):
        w_init = tf.random_normal([size, n_class])
        weights = tf.Variable(w_init, name="W")
        b_init = tf.random_normal([n_class])
        bias = tf.Variable(b_init, name="b")

        # Only the output of the last step feeds the classifier.
        last_step = step_outputs[-1]
        projected = tf.matmul(last_step, weights)
        logits = projected + bias

    return logits
评论列表
文章目录