def createLayers():
    """Build the input tensor and Q-value output tensor of a dueling network.

    Reads the module-level ``env`` (``observation_space.shape``,
    ``action_space.n``) and ``args`` (``batch_norm``, ``layers``,
    ``hidden_size``, ``activation``, ``advantage``) objects.

    The final Dense layer has ``env.action_space.n + 1`` units. Column 0 is
    used as the state value V(s), and columns 1..n as the per-action
    advantages A(s, a). They are combined into Q-values according to
    ``args.advantage``:

    * ``'avg'``   -- Q = V + A - mean over actions of A
    * ``'max'``   -- Q = V + A - max over actions of A
    * ``'naive'`` -- Q = V + A

    Returns:
        tuple: ``(x, z)``, the Input tensor and the Q-value tensor, whose
        Lambda layer declares ``output_shape=(env.action_space.n,)``.

    Raises:
        ValueError: if ``args.advantage`` is not 'avg', 'max' or 'naive'.
    """
    n_actions = env.action_space.n

    x = Input(shape=env.observation_space.shape)
    h = BatchNormalization()(x) if args.batch_norm else x
    for i in range(args.layers):
        h = Dense(args.hidden_size, activation=args.activation)(h)
        # Batch-normalize between hidden layers, but not after the last one.
        if args.batch_norm and i != args.layers - 1:
            h = BatchNormalization()(h)
    y = Dense(n_actions + 1)(h)

    # K.expand_dims defaults to the last axis in both Keras 1 (``dim``) and
    # Keras 2 (``axis``). Relying on the default avoids the renamed keyword.
    # The advantage baseline is reduced over the action axis (axis=1) for
    # each sample. Reducing over all axes would mix statistics across the
    # whole batch.
    def _avg_q(a):
        return (K.expand_dims(a[:, 0]) + a[:, 1:]
                - K.mean(a[:, 1:], axis=1, keepdims=True))

    def _max_q(a):
        return (K.expand_dims(a[:, 0]) + a[:, 1:]
                - K.max(a[:, 1:], axis=1, keepdims=True))

    def _naive_q(a):
        return K.expand_dims(a[:, 0]) + a[:, 1:]

    combiners = {'avg': _avg_q, 'max': _max_q, 'naive': _naive_q}
    if args.advantage not in combiners:
        # Explicit error instead of ``assert False``, which is stripped
        # under ``python -O`` and would leave ``z`` unbound.
        raise ValueError("unknown advantage type %r; expected one of %s"
                         % (args.advantage, sorted(combiners)))

    z = Lambda(combiners[args.advantage], output_shape=(n_actions,))(y)
    return x, z
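# Comment list  (web-page navigation text left over from a scrape; not code)
# Table of contents  (web-page navigation text left over from a scrape; not code)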