def make_model(state_shape, n_actions):
    """Build the policy network plus a REINFORCE-style loss head.

    The network consumes a window of HISTORY_STEPS stacked observations,
    flattens it, and maps it through two ReLU layers to a softmax over
    actions.  A Lambda head combines the action probabilities with the
    taken action and its advantage into the per-sample policy-gradient
    loss  -advantage * log(pi(action | state)).

    Args:
        state_shape: shape tuple of one observation (stacked along a
            leading HISTORY_STEPS axis).
        n_actions: number of discrete actions.

    Returns:
        A Keras Model with inputs [state, action, advantage] and outputs
        [policy probabilities, per-sample loss].
    """
    # --- inputs ---
    obs_input = Input(shape=(HISTORY_STEPS,) + state_shape, name='input')
    act_input = Input(shape=(1,), dtype='int32', name='action')
    adv_input = Input(shape=(1,), name='advantage')

    # --- forward pass: observations -> action probabilities ---
    hidden = Flatten(name='flat')(obs_input)
    hidden = Dense(SIMPLE_L1_SIZE, activation='relu', name='l1')(hidden)
    hidden = Dense(SIMPLE_L2_SIZE, activation='relu', name='l2')(hidden)
    probs_out = Dense(n_actions, name='policy', activation='softmax')(hidden)

    def policy_loss(tensors):
        # -advantage * log pi(a|s); 1e-6 guards against log(0).
        probs_t, action_t, advantage_t = tensors
        # one_hot on a (batch, 1) int tensor yields (batch, 1, n_actions);
        # squeeze the middle axis back out.
        onehot_t = K.squeeze(K.one_hot(action_t, n_actions), 1)
        log_prob_t = K.log(1e-6 + K.sum(onehot_t * probs_t, axis=-1, keepdims=True))
        return -advantage_t * log_prob_t

    loss_out = Lambda(policy_loss, output_shape=(1,), name='loss')(
        [probs_out, act_input, adv_input])

    # NOTE(review): `input=`/`output=` are the Keras 1.x kwargs (Keras 2 uses
    # `inputs=`/`outputs=`); kept as-is to match the API level this file targets.
    return Model(input=[obs_input, act_input, adv_input], output=[probs_out, loss_out])
# NOTE(review): the two trailing lines were scraped blog-page navigation text
# ("评论列表" = comment list, "文章目录" = article contents), not code; as bare
# identifiers they would raise NameError at import time, so they are commented out.