def create_critic_network(self, state_size, action_dim):
    """Build and compile the critic model.

    Args:
        state_size: shape handed to the state ``Input`` layer.
        action_dim: width of the action ``Input`` layer; also the number
            of units in the output layer.

    Returns:
        A ``(model, action_input, state_input)`` tuple: the compiled
        model followed by its action and state input tensors.
    """
    print("[MESSAGE] Build critic network.")
    state_input = Input(shape=state_size)
    action_input = Input(shape=(action_dim,))

    # State branch: two identical 3x3 convolutions, 2x2 average pooling,
    # then a dense projection to 600 units.
    state_features = state_input
    for _ in range(2):
        state_features = Conv2D(
            32, (3, 3),
            padding="same",
            kernel_regularizer=l2(0.0001),
            activation="relu",
        )(state_features)
    state_features = AveragePooling2D(pool_size=2, strides=2)(state_features)
    state_features = Flatten()(state_features)
    state_features = Dense(600, activation="relu")(state_features)

    # Action branch: linear projection to the same width so the two
    # branches can be summed element-wise.
    action_features = Dense(600, activation="linear")(action_input)

    merged = add([state_features, action_features])
    hidden = Dense(600, activation="relu")(merged)

    # NOTE(review): a softmax head trained with categorical cross-entropy
    # is unusual for a critic -- confirm this is intended.
    output = Dense(action_dim, activation="softmax")(hidden)

    critic = Model(inputs=[state_input, action_input], outputs=output)
    critic.compile(loss='categorical_crossentropy', optimizer="adam")
    return critic, action_input, state_input
评论列表
文章目录