model.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:location_tracking_ml 作者: cybercom-finland 项目源码 文件源码
def create_generative(parameters):
    """Build the single-step generative graph and return its handles.

    Resets the default TensorFlow graph, creates the input/target/state
    placeholders, the input and output projection variables and a stacked
    LSTM cell, then wires them together by calling ``RNN_generative`` inside
    the "RNN" variable scope.

    Args:
        parameters: Mapping of hyperparameters. Keys read here:
            'n_input'      -- width of the input placeholder.
            'n_output'     -- width of the expected-output placeholder.
            'n_peers'      -- number of peers; used to size the input weights.
            'input_layer'  -- width of the input projection, or None.
            'lstm_layers'  -- sequence of LSTM layer sizes.
            'lstm_clip'    -- passed to each LSTM cell as ``cell_clip``.
            'n_mixtures'   -- number of mixture components in the output.

    Returns:
        dict with the keys 'input_weights', 'input_bias', 'output_weights',
        'output_bias', 'rnn_cell', 'lr', 'x', 'y', 'keep_prob', 'istate',
        plus 'pred' and 'last_state', which are the first and second elements
        of the value returned by ``RNN_generative``.
    """
    print('Creating the neural network model.')

    tf.reset_default_graph()

    # Graph inputs. Each placeholder is wrapped in a finiteness check and a
    # debug print op; the wrapped tensors are what ends up in the model dict.
    x = tf.placeholder(tf.float32, shape=(1, parameters['n_input']), name='input')
    x = tf.verify_tensor_all_finite(x, "X not finite!")
    y = tf.placeholder(tf.float32, shape=(1, parameters['n_output']), name='expected_output')
    y = tf.verify_tensor_all_finite(y, "Y not finite!")
    x = tf.Print(x, [x], "X: ")
    y = tf.Print(y, [y], "Y: ")

    layer_sizes = parameters['lstm_layers']
    lstm_state_size = np.sum(layer_sizes) * 2
    # Note: Batch size is the first dimension in istate.
    istate = tf.placeholder(tf.float32, shape=(None, lstm_state_size), name='internal_state')
    lr = tf.placeholder(tf.float32, name='learning_rate')

    # The tracked target itself plus its peers, each contributing an x and a y.
    input_size = (parameters['n_peers'] + 1) * 2

    # Without an input projection layer the first LSTM layer is fed the raw
    # input width instead.
    inputToRnn = parameters['input_layer']
    if inputToRnn is None:
        inputToRnn = parameters['n_input']

    # One LSTM cell per configured layer; each layer's input width is the
    # previous layer's size (or inputToRnn for the first layer).
    cells = [
        rnn_cell.LSTMCell(size,
                          layer_sizes[i - 1] if i > 0 else inputToRnn,
                          num_proj=layer_sizes[i],
                          cell_clip=parameters['lstm_clip'],
                          use_peepholes=True)
        for i, size in enumerate(layer_sizes)
    ]
    # TODO: GRUCell support here.
    # cells = [rnn_cell.GRUCell(l, parameters['lstm_layers'][i-1] if (i > 0) else inputToRnn) for i,l in enumerate(parameters['lstm_layers'])]

    # NOTE(review): the input weights/bias below are sized with
    # parameters['input_layer'] directly, even though None is handled as a
    # valid value above -- confirm whether the None case is actually supported.
    model = {
        'input_weights': tf.Variable(
            tf.random_normal([input_size, parameters['input_layer']]),
            name='input_weights'),
        'input_bias': tf.Variable(
            tf.random_normal([parameters['input_layer']]),
            name='input_bias'),
        # 6 outputs per mixture = 2 sigma, 2 mean, weight, rho
        'output_weights': tf.Variable(
            tf.random_normal([layer_sizes[-1], parameters['n_mixtures'] * 6]),
            name='output_weights'),
        # We need to put at least the standard deviation output biases to
        # about 5 to prevent zeros and infinities.
        # , mean = 5.0, stddev = 3.0
        'output_bias': tf.Variable(
            tf.random_normal([parameters['n_mixtures'] * 6]),
            name='output_bias'),
        'rnn_cell': rnn_cell.MultiRNNCell(cells),
        'lr': lr,
        'x': x,
        'y': y,
        'keep_prob': tf.placeholder(tf.float32),
        'istate': istate,
    }

    # The next variables need to be remapped, because we don't have RNN context anymore:
    # RNN/MultiRNNCell/Cell0/LSTMCell/ -> MultiRNNCell/Cell0/LSTMCell/
    # B, W_F_diag, W_O_diag, W_I_diag, W_0
    with tf.variable_scope("RNN"):
        pred = RNN_generative(parameters, x, model, istate)

    model['pred'] = pred[0]
    model['last_state'] = pred[1]

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号