tf_model.py source code

python

Project: char-rnn-text-generation    Author: yxtay
# module-level imports this snippet relies on (TensorFlow 1.x);
# logger and VOCAB_SIZE are defined elsewhere in the project
import tensorflow as tf
from tensorflow.contrib import layers, rnn


def build_infer_graph(x, batch_size, vocab_size=VOCAB_SIZE, embedding_size=32,
                      rnn_size=128, num_layers=2, p_keep=1.0):
    """
    builds inference graph
    """
    infer_args = {"batch_size": batch_size, "vocab_size": vocab_size,
                  "embedding_size": embedding_size, "rnn_size": rnn_size,
                  "num_layers": num_layers, "p_keep": p_keep}
    logger.debug("building inference graph: %s.", infer_args)

    # other placeholders
    p_keep = tf.placeholder_with_default(p_keep, [], "p_keep")
    batch_size = tf.placeholder_with_default(batch_size, [], "batch_size")

    # embedding layer
    embed_seq = layers.embed_sequence(x, vocab_size, embedding_size)
    # shape: [batch_size, seq_len, embedding_size]
    embed_seq = tf.nn.dropout(embed_seq, keep_prob=p_keep)
    # shape: [batch_size, seq_len, embedding_size]

    # RNN layers
    cells = [rnn.LSTMCell(rnn_size) for _ in range(num_layers)]
    cells = [rnn.DropoutWrapper(cell, output_keep_prob=p_keep) for cell in cells]
    cells = rnn.MultiRNNCell(cells)
    input_state = cells.zero_state(batch_size, tf.float32)
    # shape: [num_layers, 2, batch_size, rnn_size]
    rnn_out, output_state = tf.nn.dynamic_rnn(cells, embed_seq, initial_state=input_state)
    # rnn_out shape: [batch_size, seq_len, rnn_size]
    # output_state shape: [num_layers, 2, batch_size, rnn_size]
    with tf.name_scope("lstm"):
        tf.summary.histogram("outputs", rnn_out)
        for c_state, h_state in output_state:
            tf.summary.histogram("c_state", c_state)
            tf.summary.histogram("h_state", h_state)

    # fully connected layer
    logits = layers.fully_connected(rnn_out, vocab_size, activation_fn=None)
    # shape: [batch_size, seq_len, vocab_size]

    # predictions
    with tf.name_scope("softmax"):
        probs = tf.nn.softmax(logits)
        # shape: [batch_size, seq_len, vocab_size]

    with tf.name_scope("sequence"):
        tf.summary.histogram("embeddings", embed_seq)
        tf.summary.histogram("logits", logits)

    model = {"logits": logits, "probs": probs,
             "input_state": input_state, "output_state": output_state,
             "p_keep": p_keep, "batch_size": batch_size, "infer_args": infer_args}
    return model
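
For reference, here is a minimal sketch of how this inference graph could be driven for sampling under TF 1.x. The placeholder shape and name, the seed id, and batch_size=1 are illustrative assumptions, not taken from the project's own training/sampling code.

# minimal usage sketch (assumptions: TF 1.x, an int32 id placeholder named "X",
# batch_size=1 for sampling; these names are illustrative, not from the project)
import tensorflow as tf

x = tf.placeholder(tf.int32, [None, None], name="X")  # [batch_size, seq_len] character ids
model = build_infer_graph(x, batch_size=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    state = sess.run(model["input_state"])             # zero initial LSTM state
    probs, state = sess.run(
        [model["probs"], model["output_state"]],
        feed_dict={x: [[0]],                           # a single (hypothetical) seed character id
                   model["input_state"]: state})       # nested state tuples can be fed directly
    # probs has shape [1, 1, vocab_size]; feeding `state` back in continues the sequence

Because p_keep and batch_size are placeholder_with_default tensors, the same graph can be reused with different dropout or batch sizes at run time by feeding those entries of the returned model dict.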