model.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:teras 作者: chantera 项目源码 文件源码
def __call__(self, chars):
        if not isinstance(chars, (tuple, list)):
            chars = [chars]
        char_ids, boundaries = self._create_sequence(chars)
        x = self.embed(self.xp.array(char_ids))
        x = F.dropout(x, self._dropout)
        length, dim = x.shape
        C = self.conv(F.reshape(x, (1, 1, length, dim)))
        # C.shape -> (1, out_size, length, 1)
        C = F.split_axis(F.transpose(F.reshape(C, (self.out_size, length))),
                         boundaries, axis=0)
        ys = F.max(F.pad_sequence(
            [matrix for i, matrix in enumerate(C) if i % 2 == 1],
            padding=-np.inf), axis=1)  # max over time pooling
        # assert len(chars) == ys.shape[0]
        return ys
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号