def overall_net(batch_size, channels, height, width, action_size, net_type):
    """Build the Q-network definition and return it as a Caffe NetParameter.

    The network feeds an image input through three convolution+ReLU stages
    and one fully connected+ReLU stage, then a linear ``value_q`` layer with
    ``action_size`` outputs. Unless ``net_type`` is ``'test'``, a training
    branch is appended: ``value_q`` is multiplied elementwise by a ``filter``
    input and regressed onto a ``target`` input with a Euclidean loss.

    Args:
        batch_size: First dimension of every input blob.
        channels: Channel dimension of the ``frames`` input.
        height: Height of the ``frames`` input.
        width: Width of the ``frames`` input.
        action_size: Number of ``value_q`` outputs; also the second dimension
            of the ``filter`` and ``target`` inputs.
        net_type: ``'action'`` uses ``learned_param`` for every layer; any
            other value uses ``frozen_param``. ``'test'`` additionally stops
            after ``value_q`` and omits the filter/target inputs and the loss.

    Returns:
        The protobuf produced by ``NetSpec.to_proto()``.
    """
    # Attribute names on ``n`` become the layer/blob names in the generated
    # prototxt, so callers that feed blobs by name depend on them.
    n = caffe.NetSpec()

    # Image input: (batch_size, channels, height, width).
    n.frames = L.Input(shape=dict(dim=[batch_size, channels, height, width]))

    # Only the 'action' net uses the learnable parameter spec; every other
    # variant shares the frozen spec across all layers.
    param = learned_param if net_type == 'action' else frozen_param

    # Image feature extractor.
    n.conv1, n.relu1 = conv_relu(n.frames, 8, 32, stride=4, param=param)
    n.conv2, n.relu2 = conv_relu(n.relu1, 4, 64, stride=2, param=param)
    n.conv3, n.relu3 = conv_relu(n.relu2, 3, 64, stride=1, param=param)
    n.fc4, n.relu4 = fc_relu(n.relu3, 512, param=param)

    # One linear output per action.
    n.value_q = L.InnerProduct(n.relu4, num_output=action_size, param=param,
                               weight_filler=dict(type='gaussian', std=0.005),
                               bias_filler=dict(type='constant', value=1))

    # The test net only produces Q-values: no training inputs, no loss.
    if net_type == 'test':
        return n.to_proto()

    # Training branch. ``value_q`` is masked elementwise (PROD) by ``filter``
    # before being compared with ``target``. Presumably ``filter`` selects the
    # action(s) actually taken so that only those Q-values contribute to the
    # loss; confirm against the code that fills this blob.
    n.filter = L.Input(shape=dict(dim=[batch_size, action_size]))
    n.filtered_value_q = L.Eltwise(n.value_q, n.filter,
                                   operation=caffe.params.Eltwise.PROD)
    n.target = L.Input(shape=dict(dim=[batch_size, action_size]))
    n.loss = L.EuclideanLoss(n.filtered_value_q, n.target)
    return n.to_proto()
### define solver
# Comment list
# Table of contents