cnn.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:deep-crf 作者: aonotas 项目源码 文件源码
def compute_vecs(self, word_ids, word_boundaries, phrase_num,
                     char_vecs=None):
        """Build one max-pooled convolutional vector per phrase.

        The word ids are embedded with ``self.emb``, optionally concatenated
        with ``char_vecs`` along the feature axis, passed through
        ``self.conv`` and then max-pooled over selected segments of the
        sequence.

        Args:
            word_ids: Word ids fed to ``self.emb`` after being wrapped by
                ``my_variable``.
            word_boundaries: Split points handed to ``F.split_axis`` along
                the sequence axis. Only the odd-numbered segments
                (1, 3, 5, ...) of the resulting split are pooled.
                NOTE(review): presumably the even-numbered segments are the
                gaps between phrases -- confirm against the caller.
            phrase_num: Number of rows of the returned matrix; it must equal
                the number of pooled segments for the final reshape to
                succeed.
            char_vecs: Optional extra per-word features. They are used only
                when ``self.word_level_flag`` is truthy, and are assumed to
                be ``total_len x self.add_dim`` -- TODO confirm.

        Returns:
            A variable reshaped to ``(phrase_num, self.hidden_dim)``.
        """
        word_ids = my_variable(word_ids, volatile=not self.train)
        word_embs = self.emb(word_ids)  # total_len x emb_dim

        # Decide the feature width first so the input is reshaped exactly
        # once, instead of reshaping and then discarding the result when
        # character vectors are appended.
        dim = self.emb_dim
        if self.word_level_flag and char_vecs is not None:
            word_embs = F.concat([word_embs, char_vecs], axis=1)
            dim += self.add_dim

        # 1 x 1 x total_len x dim
        conv_input = F.reshape(word_embs, (1, 1, -1, dim))

        # convolution
        conv_out = self.conv(conv_input)
        # Flatten to hidden_dim x (remaining length); assumes the convolution
        # emits hidden_dim channels with a trailing width of 1 -- TODO confirm.
        conv_out = F.reshape(conv_out, (self.hidden_dim, -1))

        # Cut the sequence axis at the given boundaries.
        segments = F.split_axis(conv_out, word_boundaries, axis=1)

        # Max over the sequence axis for every odd-numbered segment; each
        # pooled result is a vector of length hidden_dim.
        pooled = [F.max(segment, axis=1)
                  for i, segment in enumerate(segments) if i % 2 == 1]

        # Concatenate the pooled vectors end to end, then restore one row
        # per phrase.
        pooled = F.concat(pooled, axis=0)
        return F.reshape(pooled, (phrase_num, self.hidden_dim))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号