doc2vec_train_doc_prediction.py source code

Language: Python

Project: kaggle_redefining_cancer_treatment  Author: jorgemf
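import tensorflow as tf
# Assumption: `layers` is tf.contrib.layers (TensorFlow 1.x); the original file's imports are not shown.
from tensorflow.contrib import layers


def doc2vec_prediction_model(input_vectors, input_gene, input_variation, output_label, batch_size,
                             is_training, embedding_size, output_classes):
    # inputs/outputs: text, gene and variation embeddings, each [batch_size, embedding_size]
    input_vectors = tf.reshape(input_vectors, [batch_size, embedding_size])
    input_gene = tf.reshape(input_gene, [batch_size, embedding_size])
    input_variation = tf.reshape(input_variation, [batch_size, embedding_size])
    targets = None
    if output_label is not None:
        # integer class ids -> one-hot [batch_size, 1, output_classes] -> squeeze to [batch_size, output_classes]
        output_label = tf.reshape(output_label, [batch_size, 1])
        targets = tf.one_hot(output_label, axis=-1, depth=output_classes, on_value=1.0,
                             off_value=0.0)
        targets = tf.squeeze(targets, axis=1)

    # concatenate the three embeddings: [batch_size, 3 * embedding_size]
    net = tf.concat([input_vectors, input_gene, input_variation], axis=1)
    # MLP head with 2E -> E -> E/4 ReLU units; dropout (keep_prob=0.85) is active only when is_training
    net = layers.fully_connected(net, embedding_size * 2, activation_fn=tf.nn.relu)
    net = layers.dropout(net, keep_prob=0.85, is_training=is_training)
    net = layers.fully_connected(net, embedding_size, activation_fn=tf.nn.relu)
    net = layers.dropout(net, keep_prob=0.85, is_training=is_training)
    net = layers.fully_connected(net, embedding_size // 4, activation_fn=tf.nn.relu)
    # unnormalized class scores (no activation); softmax is expected to be applied in the loss
    logits = layers.fully_connected(net, output_classes, activation_fn=None)

    return logits, targets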
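For readers without a TensorFlow 1.x environment, the following is a minimal, framework-free sketch of the same forward computation in plain Python: one-hot targets, concatenation of the three embeddings, three ReLU layers of width 2E, E and E/4 with inverted dropout after the first two, and a linear logits layer. It is not the author's implementation. The helper names (`one_hot`, `dense`, `init_layer`, `dropout`, `prediction_model`) and the uniform weight initialisation are hypothetical; weights are sampled fresh on every call from `seed`, so the sketch only demonstrates data flow and shapes, not training. The batch size is implicit in the list lengths instead of being passed as `batch_size`.

```python
import random
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def one_hot(labels: List[int], depth: int) -> Matrix:
    """Equivalent of tf.one_hot on [batch, 1] labels followed by squeeze(axis=1)."""
    return [[1.0 if c == label else 0.0 for c in range(depth)] for label in labels]


def dense(x: Matrix, w: Matrix, b: List[float], relu: bool) -> Matrix:
    """Fully connected layer: x @ w + b, optionally followed by ReLU."""
    out = []
    for row in x:
        y = [sum(xi * w[i][j] for i, xi in enumerate(row)) + b[j] for j in range(len(b))]
        out.append([max(0.0, v) for v in y] if relu else y)
    return out


def init_layer(n_in: int, n_out: int, rng: random.Random) -> Tuple[Matrix, List[float]]:
    # Hypothetical initialisation: small uniform weights, zero biases.
    w = [[rng.uniform(-0.1, 0.1) for _ in range(n_out)] for _ in range(n_in)]
    return w, [0.0] * n_out


def dropout(x: Matrix, keep_prob: float, is_training: bool, rng: random.Random) -> Matrix:
    # Inverted dropout: keep each unit with probability keep_prob and scale it by
    # 1 / keep_prob during training; identity at inference time.
    if not is_training:
        return x
    return [[v / keep_prob if rng.random() < keep_prob else 0.0 for v in row] for row in x]


def prediction_model(vectors: Matrix, gene: Matrix, variation: Matrix,
                     labels: Optional[List[int]], embedding_size: int,
                     output_classes: int, is_training: bool,
                     seed: int = 0) -> Tuple[Matrix, Optional[Matrix]]:
    rng = random.Random(seed)
    targets = one_hot(labels, output_classes) if labels is not None else None
    # Concatenate along the feature axis: each row has 3 * embedding_size values.
    net = [v + g + r for v, g, r in zip(vectors, gene, variation)]
    widths = [3 * embedding_size, 2 * embedding_size, embedding_size, embedding_size // 4]
    for i, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
        w, b = init_layer(n_in, n_out, rng)
        net = dense(net, w, b, relu=True)
        if i < 2:  # dropout follows only the first two hidden layers
            net = dropout(net, 0.85, is_training, rng)
    w, b = init_layer(widths[-1], output_classes, rng)
    logits = dense(net, w, b, relu=False)  # no activation on the logits
    return logits, targets


if __name__ == "__main__":
    batch, emb, classes = 2, 8, 3
    data = [[[0.5] * emb for _ in range(batch)] for _ in range(3)]
    logits, targets = prediction_model(data[0], data[1], data[2], [0, 2], emb, classes,
                                       is_training=False)
    print(len(logits), len(logits[0]))  # 2 3
    print(targets)  # [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
```

With `embedding_size=8` the hidden widths are 24 -> 16 -> 8 -> 2, mirroring `embedding_size * 2`, `embedding_size` and `embedding_size // 4` in the TensorFlow version.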