def char_rnn_model(features, target):
    """Build a character-level GRU classifier and its training op.

    The input ids are one-hot encoded with depth 256, split into one
    tensor per position along axis 1, and run through a single GRU cell
    of size ``HIDDEN_SIZE``. The final RNN state is projected by a linear
    (no activation) fully connected layer to 15 logits, which are trained
    with softmax cross-entropy using Adam at a learning rate of 0.01.

    Args:
        features: Integer tensor of ids in ``[0, 256)``. Presumably shaped
            ``[batch, sequence_length]`` with a static sequence length,
            since it is unstacked along axis 1 -- TODO confirm with caller.
        target: Integer tensor of class ids in ``[0, 15)``; presumably
            shaped ``[batch]`` -- TODO confirm with caller.

    Returns:
        A ``(predictions, loss, train_op)`` tuple, where ``predictions``
        is a dict with ``'class'`` (argmax of the logits over axis 1) and
        ``'prob'`` (softmax of the logits).
    """
    n_classes = 15       # width of the label one-hot and of the logits
    n_byte_values = 256  # one-hot depth used for the input ids

    target = tf.one_hot(target, n_classes, 1, 0)

    # tf.one_hot takes its dtype from on_value/off_value. Float literals
    # are required here so the RNN inputs are float32, matching the
    # float32 state created by static_rnn(dtype=tf.float32); integer
    # literals would yield int32 inputs that cannot be mixed with it.
    byte_list = tf.one_hot(features, n_byte_values, 1.0, 0.0)
    byte_list = tf.unstack(byte_list, axis=1)

    cell = tf.contrib.rnn.GRUCell(HIDDEN_SIZE)
    # Only the final state is used; the per-step outputs are discarded.
    _, encoding = tf.contrib.rnn.static_rnn(
        cell, byte_list, dtype=tf.float32)

    logits = tf.contrib.layers.fully_connected(
        encoding, n_classes, activation_fn=None)
    loss = tf.contrib.losses.softmax_cross_entropy(logits, target)

    train_op = tf.contrib.layers.optimize_loss(
        loss,
        tf.contrib.framework.get_global_step(),
        optimizer='Adam',
        learning_rate=0.01)

    predictions = {
        'class': tf.argmax(logits, 1),
        'prob': tf.nn.softmax(logits),
    }
    return (predictions, loss, train_op)
# Comment list
# Article table of contents