pointer_net.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:ReLiefParser 作者: XuezheMax 项目源码 文件源码
def __init__(self, vsize, esize, hsize, asize, buckets, **kwargs):
        """Build the pointer network: encoder, decoder, optimizer and baselines.

        Args:
            vsize: passed to ``Encoder`` as its first argument (stored as
                ``enc_vsize``); presumably the input vocabulary size.
            esize: passed to ``Encoder`` (stored as ``enc_esize``); presumably
                the embedding size.
            hsize: hidden size shared by the encoder and the decoder.
            asize: passed to the decoder (stored as ``dec_asize``); presumably
                the attention size.
            buckets: sequence of length buckets; its LAST element becomes
                ``max_len`` (assumes buckets are sorted ascending -- TODO
                confirm). Must be non-empty, otherwise ``buckets[-1]`` raises.
            **kwargs: optional settings read by this constructor:
                name (str): defaults to the class name.
                scope (str): defaults to ``name``.
                max_grad_norm: defaults to 100; presumably the gradient
                    clipping threshold (used outside this method).
                learning_rate (float): Adam learning rate, defaults to 1e-3.
                num_layer (int): number of RNN layers passed to the encoder
                    and decoder, defaults to 1.
                rnn_class: RNN cell class passed to the encoder and decoder,
                    defaults to ``tf.nn.rnn_cell.BasicLSTMCell``.
                tree_decoder (bool): build a ``TreeDecoder`` instead of a
                    ``Decoder``; defaults to False.
                bl_ratio (float): defaults to 0.95; presumably the decay
                    factor for updating ``self.baselines`` (verify against
                    the training code).
        """
        super(PointerNet, self).__init__()

        self.name = kwargs.get('name', self.__class__.__name__)
        self.scope = kwargs.get('scope', self.name)

        self.enc_vsize = vsize
        self.enc_esize = esize
        self.enc_hsize = hsize

        # Decoder memory and input sizes are 2 * hsize: the encoder output is
        # treated as the concatenation of bidirectional RNN states
        # (forward + backward, hsize each).
        self.dec_msize = self.enc_hsize * 2
        self.dec_isize = self.enc_hsize * 2
        self.dec_hsize = hsize
        self.dec_asize = asize

        self.buckets = buckets
        self.max_len = self.buckets[-1]

        self.max_grad_norm = kwargs.get('max_grad_norm', 100)
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=kwargs.get('learning_rate', 1e-3))

        self.num_layer = kwargs.get('num_layer', 1)
        self.rnn_class = kwargs.get('rnn_class', tf.nn.rnn_cell.BasicLSTMCell)

        self.encoder = Encoder(self.enc_vsize, self.enc_esize, self.enc_hsize,
                               rnn_class=self.rnn_class, num_layer=self.num_layer)

        # The conditional expression only evaluates the chosen name, so
        # TreeDecoder is referenced only when it was requested.
        decoder_class = TreeDecoder if kwargs.get('tree_decoder', False) else Decoder
        self.decoder = decoder_class(self.dec_isize, self.dec_hsize, self.dec_msize,
                                     self.dec_asize, self.max_len,
                                     rnn_class=self.rnn_class,
                                     num_layer=self.num_layer, epsilon=1.0)

        # One non-trainable scalar variable per output position, initialised
        # to 0.0. The optimizer never updates these; presumably they are
        # baselines maintained manually using bl_ratio (TODO confirm).
        self.bl_ratio = kwargs.get('bl_ratio', 0.95)
        self.baselines = [tf.Variable(0.0, trainable=False)
                          for _ in range(self.max_len)]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号