def get_q_network(weights_path):
    """Build and compile the fully connected network used as the Q-network.

    Architecture, as constructed below:
        input (25112,) -> Dense 1024 + ReLU + Dropout 0.2
                       -> Dense 1024 + ReLU + Dropout 0.2
                       -> Dense 6 + linear activation

    Every Dense layer is initialised with ``normal(shape, scale=0.01)``.
    The model is compiled with mean-squared-error loss and Adam (lr=1e-6).

    Args:
        weights_path: Path of a saved weight file to load into the model.
            The string ``"0"`` (the existing sentinel) or ``None`` means
            "keep the freshly initialised weights".

    Returns:
        The compiled Keras ``Sequential`` model.
    """
    # NOTE(review): the 25112-dim input and 6 outputs are hard-coded here;
    # presumably the input is a precomputed feature vector and each output is
    # the Q-value of one action -- confirm against the caller.

    def _small_normal(shape, name=None):
        # One shared initializer instead of three identical lambdas, so the
        # 0.01 scale is defined in a single place. It keeps the same
        # (shape, name) call signature as the original lambdas.
        return normal(shape, scale=0.01, name=name)

    model = Sequential()
    model.add(Dense(1024, init=_small_normal, input_shape=(25112,)))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Dense(1024, init=_small_normal))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Dense(6, init=_small_normal))
    model.add(Activation('linear'))

    adam = Adam(lr=1e-6)
    model.compile(loss='mse', optimizer=adam)

    # "0" is the caller's sentinel for "no pretrained weights". None is accepted
    # too, so programmatic callers need not pass a magic string (previously
    # None would have been handed straight to load_weights and failed).
    if weights_path is not None and weights_path != "0":
        model.load_weights(weights_path)
    return model
评论列表
文章目录