lstm_model.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:dstc6_dialogue_breakdown_task 作者: JudeLee19 项目源码 文件源码
def add_logits_op(self, num_steps=10):
        """Build the LSTM encoder plus output projection and store ``self.logits``.

        Pipeline (TF1 graph-mode ops):
          1. Transpose ``self.input_features`` with perm [1, 0, 2] and flatten it
             to ``(-1, self.input_size)``.
          2. Project every row to ``self.num_hidden`` units with ``W_i``/``b_i``.
          3. Split the projected rows into ``num_steps`` equal chunks along
             axis 0 and feed that list, one chunk per step, to ``static_rnn``
             over a single ``LSTMCell``.
          4. Pack the per-step outputs, undo the [1, 0, 2] permutation, flatten
             to ``(-1, self.num_hidden)`` and project to ``self.num_classes``
             with ``Wo``/``bo``.
          5. Prepend a leading axis of size 1 to the logits.

        Args:
            num_steps: Number of chunks the projected input is split into, i.e.
                the number of RNN steps. ``static_rnn`` requires a Python list of
                inputs, so this must be a static Python int. Defaults to 10, the
                value that was previously hard-coded.
                NOTE(review): the split lines up with time steps only if axis 1
                of ``self.input_features`` has exactly ``num_steps`` entries.
                This presumes the input is laid out as (batch, time, features);
                confirm against the placeholder/feed definition.

        Raises:
            ValueError: if ``num_steps`` is smaller than 1.

        Side effects:
            Creates ``lstm/W_i``, ``lstm/b_i``, the LSTM cell's own variables
            (under the ``lstm`` scope), ``output_projection/Wo`` and
            ``output_projection/bo``, and assigns ``self.logits``.
        """
        if num_steps < 1:
            raise ValueError('num_steps must be a positive int, got %r' % (num_steps,))

        with tf.variable_scope('lstm'):
            W_i = tf.get_variable('W_i', [self.input_size, self.num_hidden], initializer=xav())
            b_i = tf.get_variable('b_i', [self.num_hidden], initializer=tf.constant_initializer(0.))

            # Move axis 1 to the front so that, after flattening, rows that share
            # the same index along that axis are contiguous; tf.split below then
            # carves them out as one chunk per step.
            reshaped_features = tf.transpose(self.input_features, [1, 0, 2])
            reshaped_features = tf.reshape(reshaped_features, [-1, self.input_size])

            # One shared affine projection for all rows, done as a single matmul.
            proj_input_features = tf.matmul(reshaped_features, W_i) + b_i

            # static_rnn consumes a Python list with one tensor per step.
            step_inputs = tf.split(proj_input_features, num_steps, 0)

            # define lstm cell
            lstm_fw = tf.contrib.rnn.LSTMCell(self.num_hidden, state_is_tuple=True)

            # The final state is not used by this op.
            outputs, _ = tf.contrib.rnn.static_rnn(lstm_fw, inputs=step_inputs, dtype=tf.float32)

            # tf.transpose packs the list of per-step outputs into a single tensor
            # and undoes the earlier [1, 0, 2] permutation before flattening.
            outputs = tf.transpose(outputs, [1, 0, 2])
            outputs = tf.reshape(outputs, [-1, self.num_hidden])

        with tf.variable_scope('output_projection'):
            W_o = tf.get_variable('Wo', [self.num_hidden, self.num_classes],
                                  initializer=xav())
            b_o = tf.get_variable('bo', [self.num_classes],
                                  initializer=tf.constant_initializer(0.))

            self.logits = tf.matmul(outputs, W_o) + b_o
            # Add a leading axis of size 1 in front of the flattened logits.
            self.logits = tf.expand_dims(self.logits, 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号