player_and_opponent_policy_nets.py source code

python

Project: WaNN  Author: TeoZosa
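# Assumed TensorFlow 1.x imports; the original snippet does not show them.
import tensorflow as tf
from tensorflow.contrib import rnn


def RNN(layer_in, num_hidden_layers, num_hidden_units, num_inputs_in=155):
    # Flatten each 8 x 8 input into a vector of 64 features.
    layer_in = tf.reshape(layer_in, [-1, 8 * 8])

    n_features = layer_in.get_shape().as_list()[1]
    # Size of the output layer. Note that num_inputs_in is not used below;
    # the original re-assigned it to 155 here, which shadowed the argument.
    num_classes = 155
    # reshape to [batch_size, n_features]
    X = tf.reshape(layer_in, [-1, n_features])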

    # Split into a list of n_features tensors of shape [batch_size, 1],
    # one per time step, which is the input format static_rnn expects
    X = tf.split(X, n_features, 1)

    # Stack num_hidden_layers LSTM layers with num_hidden_units units each.
    # Each layer gets its own cell object; [cell] * n would repeat one object.
    rnn_cell = rnn.MultiRNNCell(
        [rnn.BasicLSTMCell(num_hidden_units) for _ in range(num_hidden_layers)])

    # generate prediction
    outputs, states = rnn.static_rnn(rnn_cell, X, dtype=tf.float32)

    # static_rnn returns one output per time step (n_features in total),
    # but only the last one feeds the output layer
    weights = {
        'out': tf.Variable(tf.random_normal([num_hidden_units, num_classes]))
    }
    biases = {
        'out': tf.Variable(tf.random_normal([num_classes]))
    }
    return tf.matmul(outputs[-1], weights['out']) + biases['out']
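
A note on stacking layers for `rnn.MultiRNNCell`: each layer needs its own cell object. In Python, multiplying a one-element list repeats a reference to the same object rather than creating copies. Depending on the TensorFlow 1.x release, reusing one cell object for several layers can lead to unintended weight sharing or a variable-scope error. The sketch below shows the list behaviour in plain Python, without TensorFlow. `FakeCell` and both helper functions are hypothetical names introduced only for this illustration and are not part of WaNN or TensorFlow.

```python
class FakeCell:
    """Hypothetical stand-in for an RNN cell; it only owns a weight list."""

    def __init__(self, num_units):
        self.weights = [0.0] * num_units


def stack_by_multiplication(num_units, num_layers):
    # Builds ONE cell and repeats a reference to it num_layers times.
    return [FakeCell(num_units)] * num_layers


def stack_by_comprehension(num_units, num_layers):
    # Calls the constructor once per layer, so every layer is a distinct object.
    return [FakeCell(num_units) for _ in range(num_layers)]


shared = stack_by_multiplication(4, 3)
separate = stack_by_comprehension(4, 3)

# Mutate the first "layer" to see whether the other layers are aliases of it.
shared[0].weights[0] = 1.0
separate[0].weights[0] = 1.0

print(len({id(cell) for cell in shared}))    # 1 distinct object
print(len({id(cell) for cell in separate}))  # 3 distinct objects
print(shared[2].weights[0], separate[2].weights[0])  # 1.0 0.0
```

With list multiplication, a change to the first layer shows up in the last one because both entries are the same object. The comprehension form keeps the layers independent, which is usually what a multi-layer LSTM needs.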