model_common.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:LSTMVAE 作者: ashwatthaman 项目源码 文件源码
def predict(self, batch, randFlag, n_steps=51, bos_id=1, eos="</s>"):
        """Decode ``batch`` token sequences step by step and print them.

        Every sequence starts from the single token id ``bos_id``.  At each
        step the decoder output is projected with ``self.h2w``, one token id
        is chosen per sequence, and the chosen ids are embedded and fed back
        as the next decoder input.  After ``n_steps`` steps, each sequence is
        mapped back to strings with ``self.vocab.itos``.  A sequence is printed
        only if it contains ``eos``, and it is truncated just before the first
        ``eos``.  Sequences that never produce ``eos`` are not printed.

        Relies on module-level names defined outside this method: ``xp``
        (presumably numpy or cupy — TODO confirm), ``F`` (provides
        ``softmax``; presumably ``chainer.functions``) and ``predictRandom``.

        Args:
            batch: Number of sequences to generate.
            randFlag: If true, choose each token with ``predictRandom`` applied
                to ``F.softmax`` of the projected decoder output (presumably
                random sampling).  Otherwise take ``argmax`` over the last row
                of each projected output (greedy decoding).
            n_steps: Total number of decoder steps, i.e. tokens generated per
                sequence.  The default of 51 reproduces the previous fixed
                behaviour (one initial step followed by 50 loop iterations).
            bos_id: Vocabulary id used as the first input of every sequence.
                NOTE(review): assumed to be the start-of-sentence id; confirm
                against the vocabulary.
            eos: Token string that ends a printed sequence.

        Returns:
            None.  Generated text is written to stdout.
        """
        # Initial decoder input: one single-token sequence per batch element.
        feed = self.makeEmbedBatch([[bos_id] for _ in range(batch)])
        steps = []  # steps[i] holds the ids chosen at step i, one per sequence
        for _ in range(n_steps):
            ys_w = [self.h2w(y) for y in self.dec(feed, train=False)]
            if randFlag:
                tokens = [predictRandom(F.softmax(y_each)) for y_each in ys_w]
            else:
                # Greedy: best-scoring id in the last row of each output
                # (presumably the most recent time step).
                tokens = [y_each.data[-1].argmax(0) for y_each in ys_w]
            steps.append(tokens)
            # Embed each chosen id separately to form the next decoder input.
            feed = [self.embed(xp.array([tok], dtype=xp.int32)) for tok in tokens]
        # (n_steps, batch) -> (batch, n_steps): one row of ids per sequence.
        for ids in xp.array(steps).T:
            words = [self.vocab.itos(i) for i in ids]
            if eos in words:
                print("     Gen:{}".format("".join(words[:words.index(eos)])))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号