# imports used by this snippet (TF 1.x contrib API)
import logging

import tensorflow as tf
from tensorflow.contrib import layers, rnn

logger = logging.getLogger(__name__)

# VOCAB_SIZE is assumed to be defined at module level (vocabulary size of the dataset)


def build_infer_graph(x, batch_size, vocab_size=VOCAB_SIZE, embedding_size=32,
                      rnn_size=128, num_layers=2, p_keep=1.0):
    """
    Builds the inference graph: embedding -> multi-layer LSTM -> softmax.
    Returns a dict of the tensors needed to run and sample the model.
    """
    infer_args = {"batch_size": batch_size, "vocab_size": vocab_size,
                  "embedding_size": embedding_size, "rnn_size": rnn_size,
                  "num_layers": num_layers, "p_keep": p_keep}
    logger.debug("building inference graph: %s.", infer_args)
    # other placeholders (x is supplied by the caller); the defaults
    # can be overridden at sampling time
    p_keep = tf.placeholder_with_default(p_keep, [], "p_keep")
    batch_size = tf.placeholder_with_default(batch_size, [], "batch_size")
    # embedding layer
    embed_seq = layers.embed_sequence(x, vocab_size, embedding_size)
    # shape: [batch_size, seq_len, embedding_size]
    embed_seq = tf.nn.dropout(embed_seq, keep_prob=p_keep)
    # shape: [batch_size, seq_len, embedding_size]
    # RNN layers
    cells = [rnn.LSTMCell(rnn_size) for _ in range(num_layers)]
    cells = [rnn.DropoutWrapper(cell, output_keep_prob=p_keep) for cell in cells]
    cells = rnn.MultiRNNCell(cells)
    input_state = cells.zero_state(batch_size, tf.float32)
    # a tuple of num_layers LSTMStateTuples, each with c and h of
    # shape [batch_size, rnn_size]
    rnn_out, output_state = tf.nn.dynamic_rnn(cells, embed_seq,
                                              initial_state=input_state)
    # rnn_out shape: [batch_size, seq_len, rnn_size]
    # output_state: same nested structure as input_state
    with tf.name_scope("lstm"):
        tf.summary.histogram("outputs", rnn_out)
        for c_state, h_state in output_state:
            tf.summary.histogram("c_state", c_state)
            tf.summary.histogram("h_state", h_state)
    # fully connected layer projecting onto the vocabulary
    logits = layers.fully_connected(rnn_out, vocab_size, activation_fn=None)
    # shape: [batch_size, seq_len, vocab_size]
    # predictions
    with tf.name_scope("softmax"):
        probs = tf.nn.softmax(logits)
        # shape: [batch_size, seq_len, vocab_size]
    with tf.name_scope("sequence"):
        tf.summary.histogram("embeddings", embed_seq)
        tf.summary.histogram("logits", logits)
    model = {"logits": logits, "probs": probs,
             "input_state": input_state, "output_state": output_state,
             "p_keep": p_keep, "batch_size": batch_size, "infer_args": infer_args}
    return model
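
# --- Usage sketch (illustrative; not part of the original source) ---
# A minimal sampling loop under stated assumptions: build the graph with
# batch_size=1, restore weights from a checkpoint, then feed one id per step
# while threading the LSTM state through successive session.run() calls.
# `checkpoint_path` and `seed_ids` are hypothetical caller-supplied inputs;
# ids are integer indices into the vocabulary.
import numpy as np


def sample_ids(checkpoint_path, seed_ids, length=100):
    """Extends a non-empty seed_ids with `length` ids sampled from the graph."""
    tf.reset_default_graph()
    x = tf.placeholder(tf.int32, [None, None], name="x")
    model = build_infer_graph(x, batch_size=1)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, checkpoint_path)
        # start from the zero state; TF 1.x session.run accepts (and returns)
        # the nested LSTM state tuple directly in feed_dict/fetches
        state = sess.run(model["input_state"])
        out = list(seed_ids)
        ids = out  # the first step consumes the whole seed at once
        for _ in range(length):
            probs, state = sess.run(
                [model["probs"], model["output_state"]],
                feed_dict={x: [ids], model["input_state"]: state})
            # renormalize the last-step distribution to guard against
            # floating-point drift before sampling from it
            p = probs[0, -1] / probs[0, -1].sum()
            next_id = int(np.random.choice(len(p), p=p))
            out.append(next_id)
            ids = [next_id]  # subsequent steps feed a single id
    return out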