def contextLSTM(self, c, l, c_real_len, reuse=False, scope="ContextLSTM"):
    def sentenceLSTM(s,
                     s_real_len,
                     reuse=reuse,
                     scope="sentenceLSTM"):
"""
embedding sentence
Arguments
s: sentence (word index list), shape = [batch_size*20, 12]
s_real_len: length of the sentence before zero padding, int32
Returns
embedded_s: embedded sentence, shape = [batch_size*20, 32]
"""
        # look up word embeddings: [batch_size*20, 12, c_word_embed]
        embedded_sentence_word = tf.nn.embedding_lookup(self.c_word_embed_matrix, s)
        # split along the time axis for static_rnn: list of 12 tensors of shape [batch_size*20, c_word_embed]
        s_input = tf.unstack(embedded_sentence_word, num=self.s_max_len, axis=1)
        lstm_cell = rnn.BasicLSTMCell(self.s_hidden, reuse=reuse)
        outputs, _ = rnn.static_rnn(lstm_cell, s_input, dtype=tf.float32, scope=scope)
        # [12, batch_size*20, s_hidden] -> [batch_size*20, 12, s_hidden]
        outputs = tf.stack(outputs)
        outputs = tf.transpose(outputs, [1, 0, 2])
        # take the LSTM output at the last non-padded time step of each sentence
        index = tf.range(0, self.batch_size * self.c_max_len) * self.s_max_len + (s_real_len - 1)
        outputs = tf.gather(tf.reshape(outputs, [-1, self.s_hidden]), index)
        return outputs
"""
Args
c: list of sentences, shape = [batch_size, 20, 12]
l: list of labels, shape = [batch_size, 20, 20]
c_real_len: list of real length, shape = [batch_size, 20]
Returns
tagged_c_objects: list of embedded sentence + label, shape = [batch_size, 52] 20?
len(tagged_c_objects) = 20
"""
    # flatten the context so every sentence is processed in one batch
    sentences = tf.reshape(c, shape=[-1, self.s_max_len])      # [batch_size*20, 12]
    real_lens = tf.reshape(c_real_len, shape=[-1])              # [batch_size*20]
    labels = tf.reshape(l, shape=[-1, self.c_max_len])          # [batch_size*20, 20]
    s_embedded = sentenceLSTM(sentences, real_lens, reuse=reuse)
    # append each sentence's label to its embedding
    c_embedded = tf.concat([s_embedded, labels], axis=1)        # [batch_size*20, 52]
    c_embedded = tf.reshape(c_embedded, shape=[self.batch_size, self.c_max_len, self.c_max_len + self.c_word_embed])
    # one tensor of shape [batch_size, 52] per sentence in the context
    tagged_c_objects = tf.unstack(c_embedded, axis=1)
    return tagged_c_objects
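
The index arithmetic in sentenceLSTM is the usual trick for grabbing the LSTM output at the last non-padded time step: flatten the [batch, time, hidden] outputs to [batch*time, hidden] and gather row i*max_len + (real_len[i] - 1) for each sequence i. A minimal TF 1.x sketch of the same idea with dummy values (fake_outputs and the sizes below are illustrative, not from the original model):

import numpy as np
import tensorflow as tf

# Toy setup: 3 sequences, max length 4, hidden size 2 (all values are dummies).
batch, max_len, hidden = 3, 4, 2
fake_outputs = tf.constant(np.arange(batch * max_len * hidden, dtype=np.float32)
                           .reshape(batch, max_len, hidden))  # [batch, max_len, hidden]
real_len = tf.constant([2, 4, 1])                             # true lengths before padding

# After flattening to [batch*max_len, hidden], row i*max_len + t holds sequence i at time step t,
# so the last real output of sequence i sits at row i*max_len + (real_len[i] - 1).
index = tf.range(0, batch) * max_len + (real_len - 1)
last_outputs = tf.gather(tf.reshape(fake_outputs, [-1, hidden]), index)  # [batch, hidden]

with tf.Session() as sess:
    print(sess.run(last_outputs))  # rows taken at time steps 1, 3, 0 respectively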
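
As a sanity check on the shape bookkeeping in the second half of the method, here is a tiny stand-alone sketch with dummy zero tensors, using the sizes from the docstring and assuming s_hidden equals c_word_embed (32):

import tensorflow as tf

batch_size, c_max_len, s_hidden = 2, 20, 32  # dummy sizes; s_hidden == c_word_embed assumed

s_embedded = tf.zeros([batch_size * c_max_len, s_hidden])  # stand-in for the sentenceLSTM output
labels = tf.zeros([batch_size * c_max_len, c_max_len])     # stand-in for the flattened labels

c_embedded = tf.concat([s_embedded, labels], axis=1)                               # [40, 52]
c_embedded = tf.reshape(c_embedded, [batch_size, c_max_len, c_max_len + s_hidden]) # [2, 20, 52]
tagged_c_objects = tf.unstack(c_embedded, axis=1)

print(len(tagged_c_objects), tagged_c_objects[0].shape)  # 20 (2, 52)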