def getModel(input_shape):
    """Build the Q-value network and a wrapper model that picks one action's value.

    Args:
        input_shape: Shape of a single screen input, without the batch
            dimension. Every convolution uses ``data_format="channels_first"``,
            so the first entry is treated as the channel axis.

    Returns:
        A ``(Qfunc, model)`` tuple:

        * ``Qfunc`` -- the ``Sequential`` network mapping a screen to
          ``Player.NUM_ACTIONS`` outputs (one value per action).
        * ``model`` -- a two-input model ``[screen, action]`` whose output is
          the dot product of ``Qfunc(screen)`` with ``action`` along axis 1.
          Both models share the same layers, hence the same weights.
    """
    Qfunc = Sequential()
    # NOTE(review): `action` is presumably a one-hot (or masking) vector so
    # the dot product below selects a single Q-value -- confirm with callers.
    action = Input(shape=(Player.NUM_ACTIONS,))
    screen = Input(shape=input_shape)

    # Convolutional feature extractor; all layers operate channels-first.
    Qfunc.add(Conv2D(32, (8, 8), strides=(4, 4),
                     activation='relu', padding="same",
                     input_shape=input_shape,
                     data_format="channels_first"))
    Qfunc.add(Conv2D(64, (4, 4), strides=(2, 2),
                     activation='relu', padding="same",
                     data_format="channels_first"))
    Qfunc.add(Conv2D(64, (4, 4), activation='relu', padding="same",
                     data_format="channels_first"))
    # The pooling layer must use the same data format as the convolutions.
    # Without an explicit data_format it follows the global Keras setting,
    # which may be channels_last and would then pool across the channel axis
    # instead of the two spatial axes.
    Qfunc.add(MaxPooling2D(pool_size=(2, 2), data_format="channels_first"))

    # Fully connected head producing one output per action.
    Qfunc.add(Flatten())
    Qfunc.add(Dense(128, activation='relu'))
    Qfunc.add(Dense(128, activation='relu'))
    Qfunc.add(Dense(Player.NUM_ACTIONS))

    # Reduce the per-action outputs to a single value weighted by `action`.
    reward = Qfunc(screen)
    model = Model(inputs=[screen, action],
                  outputs=dot([reward, action], axes=[1, 1]))
    return Qfunc, model
# Web-page navigation residue (not part of the code): "Comment list", "Table of contents".