def __init__(self, ob_space, ac_space, meta_ac_space):
    """Build the TensorFlow graph for a conv + LSTM actor-critic policy.

    The observation is passed through two ReLU convolutions and a
    256-unit dense layer. The result is concatenated with the previous
    action, the previous reward and the meta action, run through an
    LSTM, and projected to action logits and a scalar value estimate.

    NOTE(review): this method's indentation had been lost; the block
    structure below is a reconstruction. In particular, where the
    'conv' and 'lstm' variable scopes end determines the full names of
    the "hidden", "action" and "value" variables -- confirm against any
    existing checkpoints before relying on it.

    Args:
        ob_space: Observation shape without the batch dimension; it is
            converted with ``list()`` to size the input placeholder.
        ac_space: Width of the previous-action input and of the logits
            output.
        meta_ac_space: Width of the meta-action input.

    Attributes set:
        x, prev_action, prev_reward, meta_action: input placeholders.
        conv_feature: spatial mean (axes 1 and 2) of the second conv
            layer's activations.
        state_size, state_init, state_in, state_out: LSTM state
            descriptors, zero initial state, state placeholders and the
            resulting state tensors.
        logits, vf, sample: policy logits, value estimate and a sampled
            action.
        var_list: trainable variables under the scope that is current
            when this constructor runs.
    """
    with tf.variable_scope('conv'):
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
        x = tf.nn.relu(conv2d(x, 16, "l1", [8, 8], [4, 4]))
        x = tf.nn.relu(conv2d(x, 32, "l2", [4, 4], [2, 2]))
        # x is [?, 11, 11, 32] (original author's note; depends on ob_space)
        self.conv_feature = tf.reduce_mean(x, axis=[1, 2])
        x = tf.nn.relu(linear(flatten(x), 256, "hidden",
                              normalized_columns_initializer(1.0)))

        self.prev_action = prev_action = tf.placeholder(
            tf.float32, [None, ac_space], "prev_a")
        self.prev_reward = prev_reward = tf.placeholder(
            tf.float32, [None, 1], "prev_r")
        # Append the previous action and reward to the dense features.
        x = tf.concat([x, prev_action], axis=1)
        x = tf.concat([x, prev_reward], axis=1)

        self.meta_action = meta_action = tf.placeholder(
            tf.float32, [None, meta_ac_space], "meta_action")
        # Append the meta action as well.
        x = tf.concat([x, meta_action], axis=1)

        # Introduce a "fake" batch dimension of 1 so that the rows of x
        # are treated as time steps by the LSTM below.
        x = tf.expand_dims(x, [0])

    with tf.variable_scope('lstm'):
        size = 256
        lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        # One-element vector holding the number of rows fed in self.x;
        # used as the sequence length of the single fake-batch entry.
        step_size = tf.shape(self.x)[:1]

        # Zero state to feed at the start of a rollout.
        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]

        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]

        # LSTMStateTuple is reached through a different path depending on
        # the module-level use_tf100_api flag.
        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)

        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
        lstm_c, lstm_h = lstm_state

        # Drop the fake batch dimension again: one row per time step.
        x = tf.reshape(lstm_outputs, [-1, size])
        self.logits = linear(x, ac_space, "action",
                             normalized_columns_initializer(0.01))
        self.vf = tf.reshape(
            linear(x, 1, "value", normalized_columns_initializer(1.0)), [-1])
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        # Keep only the first row of whatever categorical_sample returns.
        self.sample = categorical_sample(self.logits, ac_space)[0, :]

    # Note: this must run at the scope of the class (outside the 'conv'
    # and 'lstm' blocks) so that the scope name covers every variable
    # created above.
    self.var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
# [web-page residue, not code] Comment list
# [web-page residue, not code] Table of contents