reinforcement.py source code

python

Project: detection-2016-nipsws  Author: imatge-upc
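# Keras 1.x API; Keras 2 renamed `init=` to `kernel_initializer=`
# and the `keras.initializations` module to `keras.initializers`.
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import Adam
from keras.initializations import normal

def small_normal(shape, name=None):
    # Zero-mean Gaussian weights with standard deviation 0.01.
    return normal(shape, scale=0.01, name=name)

def get_q_network(weights_path):
    model = Sequential()
    # Two hidden layers of 1024 ReLU units, each followed by 20% dropout.
    model.add(Dense(1024, init=small_normal, input_shape=(25112,)))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Dense(1024, init=small_normal))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    # Linear output layer: one unbounded Q-value for each of the 6 actions.
    model.add(Dense(6, init=small_normal))
    model.add(Activation('linear'))
    model.compile(loss='mse', optimizer=Adam(lr=1e-6))
    # The string "0" is a sentinel for "no saved weights; keep random init".
    if weights_path != "0":
        model.load_weights(weights_path)
    return model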
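The layer sizes in `get_q_network` fix how many trainable parameters the network has. As a quick check, the pure-Python sketch below counts weights plus biases for the 25112-1024-1024-6 stack of `Dense` layers. `dense_params` is a hypothetical helper, and the `Activation` and `Dropout` layers are left out because they carry no trainable weights.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


# Layer widths taken from get_q_network: input, two hidden layers, output.
layer_sizes = [25112, 1024, 1024, 6]

# Pair each layer's input width with its output width.
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)

print(per_layer)  # [25715712, 1049600, 6150]
print(total)      # 26771462
```

Roughly 96% of the approximately 26.8 million parameters sit in the first layer, a direct consequence of the 25112-dimensional input.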
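The `linear` output layer combined with the `mse` loss means the network is trained by regression onto one value per action. The sketch below shows one common way to build such a regression target for a single transition in standard Q-learning; it is not necessarily what this project does. It copies the current prediction and overwrites only the taken action's entry with r + gamma * max Q(s', a'), or with r alone at a terminal step. The function name `q_learning_target`, the discount `gamma=0.9`, and the sample numbers are hypothetical.

```python
from typing import List


def q_learning_target(q_current: List[float], action: int, reward: float,
                      q_next: List[float], terminal: bool,
                      gamma: float = 0.9) -> List[float]:
    """Build the regression target for one transition (hypothetical helper).

    Every entry except the taken action keeps the network's own prediction,
    so the MSE loss is driven by the taken action's output.
    """
    target = list(q_current)  # copy so the caller's prediction is not modified
    if terminal:
        # No future return after a terminal step.
        target[action] = reward
    else:
        # Bootstrapped estimate: reward plus discounted best next-state value.
        target[action] = reward + gamma * max(q_next)
    return target


# Hypothetical predictions from a 6-output network for states s and s'.
q_s = [0.1, 0.4, -0.2, 0.0, 0.3, 0.5]
q_s_next = [0.2, 0.6, 0.1, -0.1, 0.0, 0.3]
y = q_learning_target(q_s, action=1, reward=1.0, q_next=q_s_next, terminal=False)
print(y)
```

With Keras, a target built this way would be passed as the `y` argument of `model.fit` alongside the matching state vector.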