pointer_net.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:ReLiefParser 作者: XuezheMax 项目源码 文件源码
def __call__(self, enc_input, dec_input_indices, valid_indices, left_indices, right_indices, values, valid_masks=None):
        """Build the forward graph and one training op per bucket.

        Runs ``self.encoder`` on ``enc_input``. Its memory and final forward
        state are fed into ``self.decoder``. A per-step loss
        ``-mean(log_prob * value)`` is then built for every decoding step.

        Args:
            enc_input: input tensor, passed unchanged to ``self.encoder``.
            dec_input_indices, valid_indices, left_indices, right_indices,
            valid_masks: forwarded unchanged to ``self.decoder``.
            values: sequence of per-step tensors. Each is multiplied
                elementwise with the matching decoder action log-probability.
                Presumably rewards/returns; confirm against callers.

        Returns:
            A tuple ``(dec_hiddens, dec_actions, train_ops)``.

            * ``dec_hiddens`` and ``dec_actions`` come straight from
              ``self.decoder``.
            * ``train_ops[i]`` is a scalar constant with control dependencies.
              Fetching it applies one optimizer step on the summed loss of the
              first ``self.buckets[i]`` decoding steps. It also runs the
              baseline updates for those same steps.

        Side effects:
            Sets ``self.params`` to the trainable variables under
            ``self.scope``. Prints one debug line per bucket.
        """
        # Forward computation graph. It is built inside self.scope so the
        # trainable variables created by encoder/decoder can be collected by
        # scope below.
        with tf.variable_scope(self.scope):
            enc_memory, enc_final_state_fw, _ = self.encoder(enc_input)

            # The encoder's final forward state initialises the decoder.
            dec_hiddens, dec_actions, dec_act_logps = self.decoder(
                enc_memory, dec_input_indices,
                valid_indices, left_indices, right_indices,
                valid_masks, init_state=enc_final_state_fw)

            # Per-step loss plus baseline update. zip() stops at the shortest
            # of the three sequences. Each baseline must be assignable via
            # tf.assign (e.g. a tf.Variable).
            costs = []
            update_ops = []
            for act_logp, value, baseline in zip(dec_act_logps, values, self.baselines):
                # Exponential moving average of the batch-mean value, with
                # self.bl_ratio weighting the previous baseline.
                new_baseline = (self.bl_ratio * baseline
                                + (1 - self.bl_ratio) * tf.reduce_mean(value))
                # NOTE(review): the baseline is tracked but NOT subtracted
                # from `value` here. The original author had the
                # baseline-subtracted form `act_logp * (value - baseline)`
                # commented out. Confirm that dropping it is intentional.
                costs.append(-tf.reduce_mean(act_logp * value))
                update_ops.append(tf.assign(baseline, new_baseline))

        # Gradient computation graph. Each bucket limit selects how many
        # leading decoding steps contribute to that bucket's loss.
        self.params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope)
        train_ops = []
        for limit in self.buckets:
            # Debug trace of the step range covered by this bucket. A single
            # parenthesised argument prints the same text under the Python 2
            # print statement and the Python 3 print function.
            print('0 ~ %d' % (limit - 1))
            # NOTE(review): tf.pack was renamed tf.stack in TensorFlow 1.0;
            # kept as-is to match the TF version this file targets.
            grad_params = tf.gradients(tf.reduce_sum(tf.pack(costs[:limit])), self.params)
            if self.max_grad_norm is not None:
                clipped_gradients, _ = tf.clip_by_global_norm(grad_params, self.max_grad_norm)
            else:
                clipped_gradients = grad_params
            train_op = self.optimizer.apply_gradients(
                list(zip(clipped_gradients, self.params)))
            # Only the baselines of the steps inside this bucket are updated
            # together with the optimizer step.
            with tf.control_dependencies([train_op] + update_ops[:limit]):
                # The constant is a fetch handle. Because it is created under
                # control_dependencies, evaluating it first runs train_op and
                # the selected baseline updates.
                train_ops.append(tf.constant(1.))

        return dec_hiddens, dec_actions, train_ops

#### test script
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号