rnn_model_no_state.py source code

python

Project: tensorflow_novelist-master  Author: charlesXu86
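# Requires TensorFlow 1.x: tf.contrib (and therefore tf.contrib.rnn) was removed in TF 2.x.
import tensorflow as tf
from tensorflow.contrib import rnn


def rnn_model(self):
    # Build a separate LSTM cell for each layer. Reusing one cell object
    # ([cell] * self.n_layers) can make the layers share weights or raise a
    # variable-reuse error, depending on the TF 1.x version.
    cells = [rnn.BasicLSTMCell(num_units=self.n_units) for _ in range(self.n_layers)]
    multi_cell = rnn.MultiRNNCell(cells)
    # Project the top layer's output at every time step down to a single value;
    # the dense layer below maps these per-step values to next-word logits.
    cell_wrapped = rnn.OutputProjectionWrapper(multi_cell, output_size=1)

    # Embed the input token ids: self.inputs is [batch_size, seq_len],
    # so inputs is [batch_size, seq_len, n_units].
    embedding = tf.Variable(initial_value=tf.random_uniform([self.vocab_size, self.n_units], -1.0, 1.0))
    inputs = tf.nn.embedding_lookup(embedding, self.inputs)
    seq_len = int(inputs.get_shape()[1])

    # outputs: [batch_size, seq_len, 1]; the final state is not used.
    outputs, _ = tf.nn.dynamic_rnn(cell_wrapped, inputs=inputs, dtype=tf.float32)
    # Drop the trailing size-1 dimension -> [batch_size, seq_len].
    # Using -1 keeps the batch dimension dynamic.
    outputs = tf.reshape(outputs, [-1, seq_len])

    # Dense layer over the time axis: [batch_size, seq_len] x [seq_len, vocab_size].
    w = tf.Variable(tf.truncated_normal([seq_len, self.vocab_size]))
    b = tf.Variable(tf.zeros([self.vocab_size]))

    # logits: [batch_size, vocab_size]
    logits = tf.nn.bias_add(tf.matmul(outputs, w), b)
    return logits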
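As a rough, framework-free illustration of the tensor shapes flowing through `rnn_model`, here is a minimal NumPy sketch of the same forward pass. It is a stand-in built on assumptions, not the project's code: the sizes, the random weights, the `params` layout, and the helper names `lstm_layer` and `rnn_model_forward` are all hypothetical. The LSTM update follows the i, j, f, o gate formulation of `BasicLSTMCell` with its default `forget_bias=1.0` and a zero initial state, and a plain normal draw stands in for `tf.truncated_normal`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, for illustration only.
batch_size, seq_len, n_units, vocab_size, n_layers = 4, 10, 16, 50, 2


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_layer(x, kernel, bias, forget_bias=1.0):
    """Run one LSTM layer over x: [batch, time, in_dim] -> [batch, time, units].

    Uses a BasicLSTMCell-style update (gates split as i, j, f, o)
    starting from a zero state.
    """
    batch, steps, _ = x.shape
    units = kernel.shape[1] // 4
    h = np.zeros((batch, units))
    c = np.zeros((batch, units))
    outputs = []
    for t in range(steps):
        gates = np.concatenate([x[:, t, :], h], axis=1) @ kernel + bias
        i, j, f, o = np.split(gates, 4, axis=1)
        c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
        h = np.tanh(c) * sigmoid(o)
        outputs.append(h)
    return np.stack(outputs, axis=1)


def rnn_model_forward(token_ids, params):
    # Embedding lookup: [batch, seq_len] -> [batch, seq_len, n_units]
    x = params["embedding"][token_ids]
    # Stacked LSTM: each layer has its own (kernel, bias) pair.
    for kernel, bias in params["layers"]:
        x = lstm_layer(x, kernel, bias)
    # Per-step projection to one value: [batch, seq_len, 1]
    x = x @ params["proj_w"] + params["proj_b"]
    # Drop the size-1 axis: [batch, seq_len]
    x = x.reshape(-1, token_ids.shape[1])
    # Dense layer over the time axis: [batch, vocab_size]
    return x @ params["w"] + params["b"]


params = {
    "embedding": rng.uniform(-1.0, 1.0, size=(vocab_size, n_units)),
    # Kernel input is [x_t, h_{t-1}]; both have n_units features here.
    "layers": [
        (rng.normal(0.0, 0.1, size=(2 * n_units, 4 * n_units)), np.zeros(4 * n_units))
        for _ in range(n_layers)
    ],
    "proj_w": rng.normal(0.0, 0.1, size=(n_units, 1)),
    "proj_b": np.zeros(1),
    # A plain normal draw stands in for tf.truncated_normal.
    "w": rng.normal(0.0, 1.0, size=(seq_len, vocab_size)),
    "b": np.zeros(vocab_size),
}

token_ids = rng.integers(0, vocab_size, size=(batch_size, seq_len))
logits = rnn_model_forward(token_ids, params)
next_word = logits.argmax(axis=1)  # one predicted token id per sequence
print(logits.shape)  # (4, 50)
```

Taking `argmax` over the vocabulary axis gives one predicted next-word id per sequence. The sketch also makes one design consequence visible: every time step is reduced to a single scalar before the dense layer, so `w` has shape `[seq_len, vocab_size]` and the model only accepts inputs of that fixed sequence length.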