networks.py source code

python

Project: kerlym    Author: osh
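import tensorflow as tf
from keras.layers import Input, Reshape, TimeDistributed, Convolution2D, Flatten, Dense
from keras.models import Model
from keras.optimizers import RMSprop

# Uses the Keras 1.x API (subsample, border_mode, dim_ordering, init).
def simple_cnn(agent, env, dropout=0, learning_rate=1e-3, **args):
  with tf.device("/cpu:0"):
    # Placeholder returned to the caller; the Keras model itself takes S as input.
    state = tf.placeholder('float', [None, agent.input_dim])
    S = Input(shape=[agent.input_dim])
    # Restore the flat input to its original shape before the convolutions.
    h = Reshape( agent.input_dim_orig )(S)
    # Two strided convolutions, each applied per time step via TimeDistributed.
    h = TimeDistributed( Convolution2D(16, 8, 8, subsample=(4, 4), border_mode='same', activation='relu', dim_ordering='tf'))(h)
    # h = Dropout(dropout)(h)  # dropout is currently disabled
    h = TimeDistributed( Convolution2D(32, 4, 4, subsample=(2, 2), border_mode='same', activation='relu', dim_ordering='tf'))(h)
    h = Flatten()(h)
    # h = Dropout(dropout)(h)
    h = Dense(256, activation='relu')(h)
    # h = Dropout(dropout)(h)
    h = Dense(128, activation='relu')(h)
    # One linear output per action, with weights initialized to zero.
    V = Dense(env.action_space.n, activation='linear',init='zero')(h)
    model = Model(S, V)
    model.compile(loss='mse', optimizer=RMSprop(lr=learning_rate) )
    return state, model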
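
To see how large the tensor entering the first Dense layer becomes, the shape arithmetic of the two convolutions can be traced by hand. The sketch below is only an illustration and does not come from kerlym: it assumes `agent.input_dim_orig` is laid out as (frames, height, width, channels), and the 4 x 84 x 84 x 1 input is a hypothetical example. It relies on the fact that with `border_mode='same'` a strided convolution produces ceil(size / stride) outputs per spatial dimension.

```python
def same_conv_output(size, stride):
    """Spatial output length of a 'same'-padded convolution: ceil(size / stride)."""
    return (size + stride - 1) // stride


def flattened_size(input_dim_orig):
    """Number of features reaching the first Dense layer of simple_cnn.

    input_dim_orig is assumed (not confirmed by the source) to be
    (frames, height, width, channels).
    """
    frames, height, width, _channels = input_dim_orig
    filters = None
    # (filters, stride) for the two convolutions in simple_cnn.
    for filters, stride in [(16, 4), (32, 2)]:
        height = same_conv_output(height, stride)
        width = same_conv_output(width, stride)
    # Flatten() merges the time, spatial, and filter axes into one.
    return frames * height * width * filters


# Hypothetical input: 4 stacked 84x84 grayscale frames.
print(flattened_size((4, 84, 84, 1)))
```

Under these assumptions the 84 x 84 frames shrink to 21 x 21 and then 11 x 11, so most of the model's parameters would sit in the first Dense(256) layer rather than in the convolutions.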