def doc2vec_prediction_model(input_vectors, input_gene, input_variation, output_label, batch_size,
                             is_training, embedding_size, output_classes,
                             dropout_keep_prob=0.85):
    """Build a feed-forward classifier over three concatenated embeddings.

    The three inputs are each reshaped to ``[batch_size, embedding_size]``,
    concatenated along the feature axis, and passed through three ReLU
    fully-connected layers (widths ``2 * embedding_size``, ``embedding_size``
    and ``embedding_size // 4``) followed by a linear output layer.

    NOTE(review): ``tf`` and ``layers`` come from the enclosing module;
    ``layers`` is presumably ``tf.contrib.layers`` -- confirm against the
    file's imports.

    Args:
        input_vectors: Tensor reshapeable to ``[batch_size, embedding_size]``.
        input_gene: Tensor reshapeable to ``[batch_size, embedding_size]``.
        input_variation: Tensor reshapeable to ``[batch_size, embedding_size]``.
        output_label: Class-index tensor reshapeable to ``[batch_size, 1]``,
            or ``None`` when no labels are available.
        batch_size: Number of examples per batch, used for every reshape.
        is_training: Forwarded to the dropout layers.
        embedding_size: Width of each input embedding; also sets the hidden
            layer widths.
        output_classes: Number of classes (one-hot depth and logits width).
        dropout_keep_prob: ``keep_prob`` passed to both dropout layers.
            Defaults to 0.85, the value that used to be hard-coded.

    Returns:
        A ``(logits, targets)`` tuple. ``logits`` is the output of the final
        linear layer. ``targets`` is the one-hot encoding of ``output_label``
        with shape ``[batch_size, output_classes]``, or ``None`` when
        ``output_label`` is ``None``.
    """
    # inputs/outputs
    embedding_shape = [batch_size, embedding_size]
    input_vectors = tf.reshape(input_vectors, embedding_shape)
    input_gene = tf.reshape(input_gene, embedding_shape)
    input_variation = tf.reshape(input_variation, embedding_shape)

    targets = None
    if output_label is not None:
        output_label = tf.reshape(output_label, [batch_size, 1])
        # one_hot on a [batch_size, 1] tensor yields [batch_size, 1, output_classes];
        # the squeeze below drops the singleton axis again.
        targets = tf.one_hot(output_label, axis=-1, depth=output_classes, on_value=1.0,
                             off_value=0.0)
        targets = tf.squeeze(targets, axis=1)

    # Feature-wise concatenation: [batch_size, 3 * embedding_size].
    net = tf.concat([input_vectors, input_gene, input_variation], axis=1)

    # The layer call order is kept as-is so that automatically generated
    # variable scope names do not change.
    net = layers.fully_connected(net, embedding_size * 2, activation_fn=tf.nn.relu)
    net = layers.dropout(net, keep_prob=dropout_keep_prob, is_training=is_training)
    net = layers.fully_connected(net, embedding_size, activation_fn=tf.nn.relu)
    net = layers.dropout(net, keep_prob=dropout_keep_prob, is_training=is_training)
    # No dropout after the last hidden layer.
    net = layers.fully_connected(net, embedding_size // 4, activation_fn=tf.nn.relu)

    # Linear output layer: raw, un-normalised class scores.
    logits = layers.fully_connected(net, output_classes, activation_fn=None)
    return logits, targets
# Source file: doc2vec_train_doc_prediction.py
# Language: python
# Views: 28
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents