dern.py source code


Project: der-network  Author: soskek
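```python
# Method of the model class in dern.py (excerpt). Written against the
# Chainer v1 API: Variable(..., volatile=...) and F.dropout(..., train=...)
# were removed in Chainer v2. Module-level imports assumed:
#   import chainer
#   import chainer.functions as F
#   import numpy as np
def encode_tokens(self, x_datas, i2sD, train=True):
    # Embed the token ids, apply dropout, then split the (n_tokens, dim)
    # matrix into one (1, dim) Variable per token (batch size = 1).
    h0L = list(F.split_axis(
        F.dropout(
            self.embed(chainer.Variable(self.xp.array(x_datas, dtype=np.int32), volatile=not train)),
            ratio=self.dropout_ratio, train=train), len(x_datas), axis=0))

    # Replace the embedding at each entity position with the
    # dynamic entity representation, projected by W_dx.
    for i, entity_state in i2sD.items():
        h0L[i] = self.W_dx(entity_state)

    # Forward LSTM: read the tokens left to right.
    forward_outL = []
    self.f_LSTM.reset_state()
    for h0 in h0L:
        state = self.f_LSTM(h0)
        forward_outL.append(state)

    # Backward LSTM: read the tokens right to left. backward_outL is in
    # reverse token order (index 0 corresponds to the last token).
    backward_outL = []
    self.b_LSTM.reset_state()
    for h0 in reversed(h0L):
        state = self.b_LSTM(h0)
        backward_outL.append(state)

    return forward_outL, backward_outL
```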
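The two LSTMs run independently, so `backward_outL[0]` is the backward state after reading only the last token. To pair both directions per token, the backward list has to be reversed before the two are combined. The sketch below reproduces the same control flow without Chainer so it can run on its own. `ToyRNN`, `entity_proj` and the sample vectors are hypothetical stand-ins (a simple elementwise tanh recurrence instead of an LSTM, and no dropout), not der-network's actual code.

```python
# Framework-free sketch of the control flow in encode_tokens().
# Assumptions (hypothetical, not from der-network): vectors are plain lists
# of floats, ToyRNN stands in for the Chainer LSTM links, entity_proj stands
# in for self.W_dx, and dropout is omitted.
import math
from typing import Callable, Dict, List, Tuple

Vector = List[float]


class ToyRNN:
    """Hypothetical stateful cell: h_t = tanh(w * x_t + u * h_{t-1}), elementwise."""

    def __init__(self, dim: int, w: float = 0.5, u: float = 0.5) -> None:
        self.dim, self.w, self.u = dim, w, u
        self.reset_state()

    def reset_state(self) -> None:
        # Clear the recurrent state before a new pass, like LSTM.reset_state().
        self.h = [0.0] * self.dim

    def __call__(self, x: Vector) -> Vector:
        # Build a new list each step so earlier outputs are never mutated.
        self.h = [math.tanh(self.w * xi + self.u * hi) for xi, hi in zip(x, self.h)]
        return self.h


def encode_tokens(
    embeddings: List[Vector],
    entity_states: Dict[int, Vector],
    entity_proj: Callable[[Vector], Vector],
    f_rnn: ToyRNN,
    b_rnn: ToyRNN,
) -> Tuple[List[Vector], List[Vector]]:
    # Copy so the caller's embedding list is left untouched.
    h0 = list(embeddings)
    # Overwrite entity positions with a projection of the entity state.
    for i, state in entity_states.items():
        h0[i] = entity_proj(state)
    # Forward pass, left to right.
    f_rnn.reset_state()
    forward = [f_rnn(x) for x in h0]
    # Backward pass, right to left; the result is in reverse token order.
    b_rnn.reset_state()
    backward = [b_rnn(x) for x in reversed(h0)]
    return forward, backward


# Usage: three 2-d token embeddings, with token 1 treated as an entity.
emb = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
fwd, bwd = encode_tokens(emb, {1: [1.0, 1.0]}, lambda s: [0.5 * v for v in s], ToyRNN(2), ToyRNN(2))
# Re-reverse the backward outputs so index t refers to token t in both lists,
# then concatenate to get one bidirectional feature vector per token.
bwd_aligned = bwd[::-1]
features = [f + b for f, b in zip(fwd, bwd_aligned)]
print(features)
```

Concatenating aligned forward and backward states is a common way to build per-token bidirectional features; whether der-network's caller combines the two lists exactly this way is not shown in this excerpt.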