train.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:EKLAVYA 作者: shensq04 项目源码 文件源码
def probability(self):
        """Build the stacked GRU encoder and a linear output layer.

        Stacks ``self.num_layers`` GRU cells of ``self.emb_dim`` units each
        (each optionally wrapped in an output ``DropoutWrapper``), runs them over
        ``self._data`` with ``tf.nn.dynamic_rnn`` using ``self._length`` as the
        sequence lengths, and stores the RNN outputs on ``self.output``. It then
        projects ``self.last_relevant(output, self._length)`` through a new
        ``[emb_dim, num_classes]`` weight and a ``[num_classes]`` bias.

        Returns:
            The affine projection ``last_relevant(...) @ weight + bias``.
            Despite the method name, no softmax is applied here, so these are
            unnormalized scores (logits). Presumably the caller applies softmax
            or a softmax cross-entropy loss -- TODO confirm.

        Side effects:
            Adds the RNN ops and two new ``tf.Variable`` objects (weight, bias)
            to the graph, and sets ``self.output``.
        """
        # inspect.getargspec is deprecated (removed in Python 3.11) and raises
        # ValueError for signatures with keyword-only parameters;
        # inspect.signature handles both. The answer cannot change between
        # layers, so compute it once instead of once per cell.
        gru_accepts_reuse = 'reuse' in inspect.signature(
            tf.contrib.rnn.GRUCell.__init__).parameters

        def make_gru_cell():
            # One GRU layer of width emb_dim. The check above guards TF versions
            # whose GRUCell.__init__ has no `reuse` parameter; when it exists,
            # pass the enclosing variable scope's reuse flag, read at the time
            # the cell is constructed.
            if gru_accepts_reuse:
                return tf.contrib.rnn.GRUCell(
                    self.emb_dim, reuse=tf.get_variable_scope().reuse)
            return tf.contrib.rnn.GRUCell(self.emb_dim)

        make_layer_cell = make_gru_cell
        if self.dropout < 1:
            # NOTE(review): the gate reads self.dropout while the wrapper uses
            # self._keep_prob. Presumably self.dropout is the configured keep
            # probability and self._keep_prob a feedable tensor (so dropout can
            # be turned off at eval time) -- confirm against __init__.
            def make_layer_cell():
                return tf.contrib.rnn.DropoutWrapper(
                    make_gru_cell(), output_keep_prob=self._keep_prob)

        # The comprehension calls the factory once per layer, so every layer
        # gets its own cell object.
        stacked_cell = tf.contrib.rnn.MultiRNNCell(
            [make_layer_cell() for _ in range(self.num_layers)],
            state_is_tuple=True)

        # The final RNN state is not used by this method.
        output, _ = tf.nn.dynamic_rnn(stacked_cell, self._data, dtype=tf.float32,
                                      sequence_length=self._length)

        # Keep weight created before bias: these tf.Variables are unnamed, so
        # their auto-generated names (and presumably any checkpoint keys) depend
        # on creation order.
        weight = tf.Variable(
            tf.truncated_normal([self.emb_dim, self.num_classes], stddev=0.01))
        bias = tf.Variable(tf.constant(0.1, shape=[self.num_classes]))

        self.output = output
        # last_relevant is defined elsewhere; presumably it selects each
        # sequence's output at its last valid step, giving (batch, emb_dim) --
        # TODO confirm.
        logits = tf.matmul(self.last_relevant(output, self._length), weight) + bias
        return logits
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号