creatNet.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:DQN 作者: Ivehui 项目源码 文件源码
def overall_net(batch_size, channels, height, width, action_size, net_type):
    """Assemble the Q-value network and return its proto description.

    The network is an input blob of frames, three conv+ReLU stages, one
    fully-connected+ReLU stage and a final InnerProduct layer producing
    ``action_size`` outputs (``value_q``).

    Args:
        batch_size, channels, height, width: dimensions of the ``frames``
            input blob, in that order.
        action_size: number of outputs of ``value_q``; also the second
            dimension of the ``filter`` and ``target`` input blobs.
        net_type: ``'action'`` selects ``learned_param`` for every layer;
            any other value selects ``frozen_param``.  ``'test'``
            additionally stops after ``value_q`` (no filter/target/loss).

    Returns:
        The result of ``NetSpec.to_proto()`` for the assembled network.
    """
    net = caffe.NetSpec()

    # Input frames
    frames_shape = dict(dim=[batch_size, channels, height, width])
    net.frames = L.Input(shape=frames_shape)

    # NOTE(review): learned_param / frozen_param are defined outside this
    # block; presumably per-layer learning-rate settings -- confirm.
    param = learned_param if net_type == 'action' else frozen_param

    # Image feature extraction.
    # NOTE(review): the positional arguments of conv_relu look like
    # (kernel size, number of outputs) -- confirm against its definition.
    net.conv1, net.relu1 = conv_relu(net.frames, 8, 32, stride=4, param=param)
    net.conv2, net.relu2 = conv_relu(net.relu1, 4, 64, stride=2, param=param)
    net.conv3, net.relu3 = conv_relu(net.relu2, 3, 64, stride=1, param=param)
    net.fc4, net.relu4 = fc_relu(net.relu3, 512, param=param)

    # One output per action
    q_weight_filler = {'type': 'gaussian', 'std': 0.005}
    q_bias_filler = {'type': 'constant', 'value': 1}
    net.value_q = L.InnerProduct(net.relu4,
                                 num_output=action_size,
                                 param=param,
                                 weight_filler=q_weight_filler,
                                 bias_filler=q_bias_filler)

    # Every net type except 'test' also gets the filter, target and loss.
    if net_type != 'test':
        per_action_shape = dict(dim=[batch_size, action_size])

        net.filter = L.Input(shape=per_action_shape)
        # operation 0: PROD (element-wise product of value_q and filter)
        net.filtered_value_q = L.Eltwise(net.value_q, net.filter, operation=0)

        net.target = L.Input(shape=dict(dim=[batch_size, action_size]))

        net.loss = L.EuclideanLoss(net.filtered_value_q, net.target)

    return net.to_proto()

### define solver
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号