# The imports below are an assumption about the surrounding module: this code uses the
# legacy TensorFlow graph API (tf.placeholder, tf.reset_default_graph, tf.Print), and
# `rnn_cell` is assumed to refer to the old tf.nn.rnn_cell module, whose LSTMCell still
# accepted an explicit input_size as its second positional argument.
import numpy as np
import tensorflow as tf

rnn_cell = tf.nn.rnn_cell


def create_generative(parameters):
print('Creating the neural network model.')
tf.reset_default_graph()
# tf Graph input
x = tf.placeholder(tf.float32, shape=(1, parameters['n_input']), name='input')
x = tf.verify_tensor_all_finite(x, "X not finite!")
y = tf.placeholder(tf.float32, shape=(1, parameters['n_output']), name='expected_output')
y = tf.verify_tensor_all_finite(y, "Y not finite!")
x = tf.Print(x, [x], "X: ")
y = tf.Print(y, [y], "Y: ")
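    # verify_tensor_all_finite and tf.Print are pass-through debug ops: each returns its
    # input tensor unchanged and only adds a finiteness check / console logging as a
    # side effect when the graph is executed.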
lstm_state_size = np.sum(parameters['lstm_layers']) * 2
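    # Each LSTM layer carries a cell state c and a projected output h of the same size
    # (num_proj == num_units below), so the flat state vector holds two values per unit.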
# Note: Batch size is the first dimension in istate.
istate = tf.placeholder(tf.float32, shape=(None, lstm_state_size), name='internal_state')
lr = tf.placeholder(tf.float32, name='learning_rate')
    # The target to track itself and its peers, each with an x and a y coordinate
    # (velocity x and y are not included here, hence two values per entity).
    input_size = (parameters['n_peers'] + 1) * 2
    inputToRnn = parameters['input_layer']
    if parameters['input_layer'] is None:
        inputToRnn = parameters['n_input']
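    # Stack the LSTM layers: layer 0 reads the (optional) input layer, and every later
    # layer reads the projected output of the previous one, so its input size is
    # lstm_layers[i-1].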
cells = [rnn_cell.LSTMCell(l, parameters['lstm_layers'][i-1] if (i > 0) else inputToRnn,
num_proj=parameters['lstm_layers'][i],
cell_clip=parameters['lstm_clip'],
use_peepholes=True) for i,l in enumerate(parameters['lstm_layers'])]
# TODO: GRUCell support here.
# cells = [rnn_cell.GRUCell(l, parameters['lstm_layers'][i-1] if (i > 0) else inputToRnn) for i,l in enumerate(parameters['lstm_layers'])]
model = {
'input_weights': tf.Variable(tf.random_normal(
[input_size, parameters['input_layer']]), name='input_weights'),
'input_bias': tf.Variable(tf.random_normal([parameters['input_layer']]), name='input_bias'),
'output_weights': tf.Variable(tf.random_normal([parameters['lstm_layers'][-1],
# 6 = 2 sigma, 2 mean, weight, rho
parameters['n_mixtures'] * 6]),
name='output_weights'),
        # At minimum, the standard-deviation output biases should be set to about 5
        # to prevent zeros and infinities.
        # , mean = 5.0, stddev = 3.0
'output_bias': tf.Variable(tf.random_normal([parameters['n_mixtures'] * 6]),
name='output_bias'),
'rnn_cell': rnn_cell.MultiRNNCell(cells),
'lr': lr,
'x': x,
'y': y,
'keep_prob': tf.placeholder(tf.float32),
'istate': istate
}
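    # The dict above bundles everything a caller needs: the graph inputs to feed
    # (x, y, istate, lr, keep_prob), the trainable input/output projection weights,
    # and the stacked LSTM cell itself.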
# The next variables need to be remapped, because we don't have RNN context anymore:
# RNN/MultiRNNCell/Cell0/LSTMCell/ -> MultiRNNCell/Cell0/LSTMCell/
# B, W_F_diag, W_O_diag, W_I_diag, W_0
with tf.variable_scope("RNN"):
pred = RNN_generative(parameters, x, model, istate)
model['pred'] = pred[0]
model['last_state'] = pred[1]
return model
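

# Illustrative only: a hypothetical `parameters` dict sketching the keys this function
# reads. The concrete values are made up for this example and are not taken from the
# original project; RNN_generative must also be defined before the call can run.
example_parameters = {
    'n_peers': 4,                # peers tracked alongside the target itself
    'n_input': (4 + 1) * 2,      # one x, y pair per tracked entity
    'n_output': 2,               # predicted x, y of the target
    'input_layer': 64,           # width of the dense layer in front of the LSTM stack
    'lstm_layers': [128, 128],   # num_units (== num_proj) of each stacked LSTMCell
    'lstm_clip': 10.0,           # cell_clip applied inside every LSTMCell
    'n_mixtures': 8,             # Gaussians in the output mixture, 6 parameters each
}
# model = create_generative(example_parameters)
# model['pred'] would then presumably hold the n_mixtures * 6 mixture parameters
# (2 means, 2 sigmas, weight, rho) for one step, and model['last_state'] the LSTM
# state to feed back in as istate for the next step.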