text_classification_model_simple_bidirectional.py source code (Python)

Project: kaggle_redefining_cancer_treatment, author: jorgemf

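import tensorflow as tf  # TensorFlow 1.x API (tf.nn.rnn_cell, bidirectional_dynamic_rnn)

# Method of the project's model class. TC_MODEL_HIDDEN and TC_MODEL_LAYERS are
# configuration constants defined elsewhere in the project.
def rnn(self, sequence, sequence_length, max_length, dropout, batch_size, training,
        num_hidden=TC_MODEL_HIDDEN, num_layers=TC_MODEL_LAYERS):
    # Note: dropout, training and num_layers are accepted but not used here
    # (one GRU layer per direction, no dropout).

    # Recurrent network: one GRU cell per direction.
    cell_fw = tf.nn.rnn_cell.GRUCell(num_hidden)
    cell_bw = tf.nn.rnn_cell.GRUCell(num_hidden)
    dtype = sequence.dtype  # renamed from `type` to avoid shadowing the built-in
    (fw_outputs, bw_outputs), _ = \
        tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw,
                                        cell_bw=cell_bw,
                                        initial_state_fw=cell_fw.zero_state(batch_size, dtype),
                                        initial_state_bw=cell_bw.zero_state(batch_size, dtype),
                                        inputs=sequence,
                                        dtype=dtype,  # same dtype as the initial states
                                        swap_memory=True,
                                        sequence_length=sequence_length)
    # Concatenate both directions: [batch_size, max_length, 2 * num_hidden].
    sequence_output = tf.concat((fw_outputs, bw_outputs), 2)
    # Get the output at the last valid step of each sequence: flatten to
    # [batch_size * max_length, 2 * num_hidden] and gather row
    # b * max_length + (sequence_length[b] - 1) for each batch element b.
    # For the backward direction this is the first step it processes, so that
    # half of the result has only seen the last valid token.
    sequence_output = tf.reshape(sequence_output, [batch_size * max_length, num_hidden * 2])
    indexes = tf.range(batch_size) * max_length + (sequence_length - 1)
    output = tf.gather(sequence_output, indexes)
    return output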
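The flatten-and-gather step, and how it interacts with the backward direction, can be illustrated without TensorFlow. The following is a toy, pure-Python sketch under stated assumptions: a running-sum recurrence stands in for the GRU cells, steps past a sequence's length output 0 (as `dynamic_rnn` zero-fills outputs beyond `sequence_length`), and the backward pass reverses only the valid prefix of each row, mirroring how `tf.nn.bidirectional_dynamic_rnn` uses `tf.reverse_sequence`. The helpers `toy_rnn`, `bidirectional` and `gather_last` are hypothetical names used only for illustration.

```python
from typing import List, Sequence, Tuple


def toy_rnn(tokens: Sequence[int]) -> List[int]:
    """Toy recurrence h_t = h_{t-1} + x_t; returns the state after every step."""
    outputs, h = [], 0
    for x in tokens:
        h += x
        outputs.append(h)
    return outputs


def bidirectional(batch: List[List[int]], lengths: List[int], max_length: int
                  ) -> Tuple[List[List[int]], List[List[int]]]:
    """Run toy_rnn in both directions over right-padded rows.

    The backward pass reverses only the first `length` steps of each row
    (like tf.reverse_sequence), runs the recurrence, then reverses the outputs
    back into input order. Steps past `length` get output 0.
    """
    fw, bw = [], []
    for row, n in zip(batch, lengths):
        valid = row[:n]
        pad = [0] * (max_length - n)
        fw.append(toy_rnn(valid) + pad)
        bw.append(toy_rnn(valid[::-1])[::-1] + pad)
    return fw, bw


def gather_last(outputs: List[list], lengths: List[int], max_length: int) -> list:
    """Flatten [batch, max_length] into one list and pick index
    b * max_length + (lengths[b] - 1) for each row b, like the snippet's
    tf.reshape followed by tf.gather."""
    flat = [step for row in outputs for step in row]
    indexes = [b * max_length + (n - 1) for b, n in enumerate(lengths)]
    return [flat[i] for i in indexes]


batch = [[1, 2, 3, 0], [4, 5, 0, 0]]  # two right-padded sequences
lengths = [3, 2]
fw, bw = bidirectional(batch, lengths, max_length=4)
# Pair the two directions per step, the analogue of tf.concat(..., 2).
both = [list(zip(f, b)) for f, b in zip(fw, bw)]

print(gather_last(both, lengths, 4))  # [(6, 3), (9, 5)]
# Full-sequence summaries: forward at the last valid step, backward at step 0.
print([(f[n - 1], b[0]) for f, b, n in zip(fw, bw, lengths)])  # [(6, 6), (9, 9)]
```

In this toy model the gathered forward half equals the forward summary of the whole sequence, while the gathered backward half (3 and 5) reflects only the last valid token; the complete backward summary sits at position 0. If that matters, one option is to concatenate the final states returned as the second element of `bidirectional_dynamic_rnn`'s result instead of gathering, since a `GRUCell`'s state is the same as its last output.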