model_batcha3c.py source code

python

Project: batchA3C  Author: ssamot
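import tensorflow as tf
from keras.layers import Input, Convolution2D, Flatten, Dense
from keras.layers.advanced_activations import PReLU
from keras.models import Model


def build_network(num_actions, agent_history_length, resized_width, resized_height):
    # Placeholder for a batch of stacked frames. It is returned to the caller
    # and is not wired into the Keras model inside this function.
    state = tf.placeholder("float", [None, agent_history_length, resized_width, resized_height])

    # The layers below use the Keras 1.x argument names
    # (nb_filter, nb_row/nb_col, subsample, border_mode, output_dim).
    inputs_v = Input(shape=(agent_history_length, resized_width, resized_height,))
    #model_v  = Permute((2, 3, 1))(inputs_v)

    # Shared trunk: two strided convolutions, then a 512-unit dense layer with PReLU.
    model_v = Convolution2D(nb_filter=16, nb_row=8, nb_col=8, subsample=(4,4), activation='relu', border_mode='same')(inputs_v)
    model_v = Convolution2D(nb_filter=32, nb_row=4, nb_col=4, subsample=(2,2), activation='relu', border_mode='same')(model_v)
    model_v = Flatten()(model_v)
    model_v = Dense(output_dim=512)(model_v)
    model_v = PReLU()(model_v)

    # Policy head: softmax distribution over the available actions.
    action_probs = Dense(name="p", output_dim=num_actions, activation='softmax')(model_v)

    # Value head: a single linear unit estimating the state value.
    state_value = Dense(name="v", output_dim=1, activation='linear')(model_v)

    # One model with two outputs, ordered [value, policy].
    value_network = Model(input=inputs_v, output=[state_value, action_probs])

    return state, value_network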
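
To make the layer sizes above concrete, here is a small standalone sketch that works out how many features reach the 512-unit dense layer. It rests on assumptions that are not stated in the source: 84x84 input frames, and a channels-first image ordering so that the last two input axes are the spatial ones. The helper names `same_conv_output` and `flattened_size` are hypothetical and exist only for this illustration. With `border_mode='same'`, a strided convolution produces ceil(size / stride) outputs along each spatial axis.

```python
def same_conv_output(size, stride):
    """Spatial output size of a 'same'-padded convolution: ceil(size / stride)."""
    return -(-size // stride)  # integer ceiling division


def flattened_size(width, height, layers):
    """Number of features after the conv stack is flattened.

    layers is a list of (num_filters, stride) pairs, one per convolution.
    """
    filters = 0
    for filters, stride in layers:
        width = same_conv_output(width, stride)
        height = same_conv_output(height, stride)
    return width * height * filters


# (filters, stride) for the two Convolution2D layers in build_network.
conv_layers = [(16, 4), (32, 2)]

# Assumed 84x84 frames: 84 -> 21 -> 11 per axis, with 32 filters at the end.
print(flattened_size(84, 84, conv_layers))  # 3872
```

Under these assumptions the flattened vector has 11 * 11 * 32 = 3872 entries. A different frame size or image ordering would change that number, so treat it as an illustration of the arithmetic, not as a property of the original project.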