# TF 0.x-style imports for the legacy rnn / rnn_cell API used below
import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

def create_rnn(config, x, scope='rnn'):
    with tf.variable_scope(scope):
        memory = config['rnn_size']
        cell = rnn_cell.BasicLSTMCell(memory)
        state = cell.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
        # rnn.rnn takes a list of per-timestep tensors; here we feed a single step
        x, state = rnn.rnn(cell, [tf.cast(x, tf.float32)], initial_state=state, dtype=tf.float32)
        x = x[-1]
        #w = tf.get_variable('w', [config['rnn_size'], 4])
        #b = tf.get_variable('b', [4])
        #x = tf.nn.xw_plus_b(x, w, b)
        # binarize the LSTM output with the sign of each activation
        x = tf.sign(x)
        return x, state
# At each step of the graph we have:
# x is [BATCH_SIZE, 4], where each row is a one-hot binary vector of the form:
# [start_token end_token a b]
#
# y is [BATCH_SIZE, 4], a binary vector giving the chance that each character is correct
#
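# A minimal usage sketch, assuming the TF 0.x API above; the concrete config
# values, the placeholder name x_in, and the feed shape are illustrative
# assumptions rather than part of the original code:
example_config = {'rnn_size': 128, 'batch_size': 32}  # illustrative values (assumption)
x_in = tf.placeholder(tf.float32, [example_config['batch_size'], 4])  # [start_token end_token a b]
output, final_state = create_rnn(example_config, x_in)
# With the projection commented out, output has shape [batch_size, rnn_size] after tf.sign.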