rnn_segment_completed.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:tf-text-workshop 作者: tf-dl-workshop 项目源码 文件源码
def rnn_segment(features, targets, mode, params):
    """Model function for an LSTM per-time-step tagger used for segmentation.

    Builds: embedding lookup -> per-step concatenation of the looked-up
    embeddings -> LSTM (dropout-wrapped in training) -> projection to
    ``params['num_class']`` logits per time step, and wraps the result in a
    ``learn.ModelFnOps`` for ``mode``.

    Args:
        features: dict with ``'seq_feature'`` (integer ids looked up in the
            ``char_emb`` table; presumably shaped (batch, time, k + 1), since
            the reshape below needs (k + 1) ids per time step -- TODO confirm)
            and ``'seq_length'`` (per-example lengths, used by ``dynamic_rnn``
            and for the padding mask).
        targets: per-step integer class labels, or ``None`` (e.g. during
            inference). When ``None``, no loss, train op or metrics are built.
        mode: a ``ModeKeys`` value. Dropout and the train op are only added
            for ``ModeKeys.TRAIN``.
        params: hyperparameter dict. Reads ``num_char``, ``emb_size``, ``k``,
            ``rnn_units``, ``num_class``, ``learning_rate``, ``grad_clip`` and,
            in training, ``input_keep_prob`` / ``output_keep_prob``.

    Returns:
        ``learn.ModelFnOps`` whose ``predictions`` holds ``"classes"``, the
        argmax class id per time step. ``loss`` and ``eval_metric_ops`` are
        ``None`` when ``targets`` is ``None``. ``train_op`` is ``None`` outside
        training.
    """
    seq_feature = features['seq_feature']
    seq_length = features['seq_length']

    with tf.variable_scope("emb"):
        embeddings = tf.get_variable("char_emb", shape=[params['num_char'], params['emb_size']])
    seq_emb = tf.nn.embedding_lookup(embeddings, seq_feature)

    # Merge the (k + 1) embeddings of each time step into one input vector.
    # Batch and time sizes are dynamic. The last dim is static, so the LSTM
    # can infer its input size.
    batch_size = tf.shape(seq_feature)[0]
    time_step = tf.shape(seq_feature)[1]
    flat_seq_emb = tf.reshape(seq_emb, shape=[batch_size, time_step, (params['k'] + 1) * params['emb_size']])

    cell = rnn.LSTMCell(params['rnn_units'])
    if mode == ModeKeys.TRAIN:
        # Dropout is applied only while training; eval/inference use the raw cell.
        cell = rnn.DropoutWrapper(cell, params['input_keep_prob'], params['output_keep_prob'])
    projection_cell = rnn.OutputProjectionWrapper(cell, params['num_class'])
    logits, _ = tf.nn.dynamic_rnn(projection_cell, flat_seq_emb, sequence_length=seq_length, dtype=tf.float32)

    pred_classes = tf.to_int32(tf.argmax(input=logits, axis=2))
    predictions = {
        "classes": pred_classes
    }

    loss = None
    train_op = None
    eval_metric_ops = None
    # Without labels (inference), loss and metrics cannot be computed.
    if targets is not None:
        # Padding steps get zero weight. maxlen=time_step keeps the mask's
        # time dim equal to the logits' time dim even when the input is
        # padded beyond the longest sequence in the batch.
        weight_mask = tf.to_float(tf.sequence_mask(seq_length, maxlen=time_step))
        loss = seq2seq.sequence_loss(logits, targets, weights=weight_mask)

        if mode == ModeKeys.TRAIN:
            train_op = layers.optimize_loss(
                loss=loss,
                global_step=tf.contrib.framework.get_global_step(),
                learning_rate=params["learning_rate"],
                optimizer=tf.train.AdamOptimizer,
                clip_gradients=params['grad_clip'],
                summaries=[
                    "learning_rate",
                    "loss",
                    "gradients",
                    "gradient_norm",
                ])

        # Classes 0 and 3 count as the positive label for precision/recall.
        # NOTE(review): presumably the word-boundary tags of the label
        # scheme; confirm against the data preparation code.
        pred_words = tf.logical_or(tf.equal(pred_classes, 0), tf.equal(pred_classes, 3))
        target_words = tf.logical_or(tf.equal(targets, 0), tf.equal(targets, 3))
        precision = metrics.streaming_precision(pred_words, target_words, weights=weight_mask)
        recall = metrics.streaming_recall(pred_words, target_words, weights=weight_mask)
        eval_metric_ops = {
            "precision": precision,
            "recall": recall
        }

    return learn.ModelFnOps(mode, predictions, loss, train_op, eval_metric_ops=eval_metric_ops)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号