def rnn_model(self):
    """Build the stacked-LSTM graph and return the output logits.

    Reads ``self.n_units``, ``self.n_layers``, ``self.vocab_size`` and
    ``self.inputs``. The token ids in ``self.inputs`` are embedded, run
    through a multi-layer LSTM whose per-step output is projected to a single
    value, and the resulting per-step values are mapped to vocabulary logits
    by one dense layer.

    Returns:
        A 2-D tensor of logits whose last dimension is ``self.vocab_size``
        and whose first dimension is the leading (batch) dimension of the
        RNN outputs.
    """
    # One distinct cell object per layer: repeating a single instance with
    # ``[cell] * n`` makes every layer share the same cell, which newer
    # TensorFlow 1.x releases reject when the cell is reused across scopes.
    layer_cells = [
        rnn.BasicLSTMCell(num_units=self.n_units)
        for _ in range(self.n_layers)
    ]
    multi_cell = rnn.MultiRNNCell(layer_cells)
    # Project each time step's output down to a single value.
    cell_wrapped = rnn.OutputProjectionWrapper(multi_cell, output_size=1)

    # Embedding table: one ``n_units``-wide row per vocabulary entry.
    embedding = tf.Variable(
        initial_value=tf.random_uniform(
            [self.vocab_size, self.n_units], -1.0, 1.0))
    # NOTE(review): assumes self.inputs is (batch, seq_len) integer ids, so
    # ``inputs`` is (batch, seq_len, n_units) -- confirm against the caller.
    inputs = tf.nn.embedding_lookup(embedding, self.inputs)

    outputs, states = tf.nn.dynamic_rnn(
        cell_wrapped, inputs=inputs, dtype=tf.float32)

    # The second input dimension must be statically known because it sizes
    # the weight matrix below.
    seq_len = int(inputs.get_shape()[1])
    # Drop the trailing size-1 projection axis. ``-1`` lets the leading
    # dimension be inferred, so a placeholder with an unknown batch size no
    # longer fails on ``int(Dimension(None))``.
    outputs = tf.reshape(outputs, [-1, seq_len])

    w = tf.Variable(tf.truncated_normal([seq_len, self.vocab_size]))
    b = tf.Variable(tf.zeros([self.vocab_size]))
    logits = tf.nn.bias_add(tf.matmul(outputs, w), b)
    return logits
rnn_model_no_state.py 文件源码
python
阅读 25
收藏 0
点赞 0
评论 0
评论列表
文章目录