text_classification_model_simple.py source code

python

Project: kaggle_redefining_cancer_treatment  Author: jorgemf
def rnn(self, sequence, sequence_length, max_length, dropout, batch_size, training,
        num_hidden=TC_MODEL_HIDDEN, num_layers=TC_MODEL_LAYERS):
    # Recurrent network: a stack of GRU cells (TensorFlow 1.x API).
    cells = []
    for _ in range(num_layers):
        cell = tf.nn.rnn_cell.GRUCell(num_hidden)
        if training:
            # `dropout` is used as the keep probability of each cell's output.
            cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=dropout)
        cells.append(cell)
    network = tf.nn.rnn_cell.MultiRNNCell(cells)
    # Use the input's dtype for both the RNN and its initial state so they agree.
    dtype = sequence.dtype

    sequence_output, _ = tf.nn.dynamic_rnn(network, sequence, dtype=dtype,
                                           sequence_length=sequence_length,
                                           initial_state=network.zero_state(batch_size, dtype))
    # Get the last valid output of each sequence from dynamic_rnn:
    # flatten [batch_size, max_length, num_hidden] to [batch_size * max_length, num_hidden],
    # then pick row (i * max_length + sequence_length[i] - 1) for each batch entry i.
    sequence_output = tf.reshape(sequence_output, [batch_size * max_length, num_hidden])
    indexes = tf.range(batch_size) * max_length + (sequence_length - 1)
    output = tf.gather(sequence_output, indexes)
    return output
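
The "last output" step at the end of rnn() can be hard to follow from the index arithmetic alone. Below is a minimal sketch of the same flatten-and-gather idea in plain NumPy rather than TensorFlow, so it can be run without a session. The function name last_relevant_output and the toy data are assumptions made for illustration; they are not part of the original project.

```python
import numpy as np


def last_relevant_output(outputs, sequence_length):
    """Return the output at the last valid time step of each sequence.

    outputs:         array of shape [batch_size, max_length, num_hidden]
    sequence_length: int array of shape [batch_size], values in 1..max_length
    (Hypothetical helper, mirroring the reshape/range/gather steps in rnn().)
    """
    batch_size, max_length, num_hidden = outputs.shape
    # Flatten the batch and time axes into a single row axis.
    flat = outputs.reshape(batch_size * max_length, num_hidden)
    # Row i * max_length is the first step of sequence i;
    # adding (length - 1) moves to its last valid step.
    indexes = np.arange(batch_size) * max_length + (sequence_length - 1)
    return flat[indexes]


# Toy batch: 2 sequences, padded to 3 steps, 2 hidden units each.
outputs = np.arange(2 * 3 * 2, dtype=np.float32).reshape(2, 3, 2)
lengths = np.array([2, 3])
last = last_relevant_output(outputs, lengths)
print(last)
```

With this toy input the selected rows are step 1 of the first sequence and step 2 of the second. Note that the arithmetic assumes every sequence has a length of at least 1; a length of 0 would produce an index pointing outside that sequence's rows.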