noTraceBaseline.py 文件源码

python
阅读 26 收藏 0 点赞 0 评论 0

项目:TikZ 作者: ellisk42 项目源码 文件源码
def sample(self, features):
        result = ["START"]

        # (1,1,F)
        features = features.view(-1).unsqueeze(0).unsqueeze(0)
        #features: 1x1x2560

        states = None

        while True:
            e = self.embedding(variable([symbolToIndex[result[-1]]]).view((1,-1)))
            recurrentInput = torch.cat((features,e),2)
            output, states = self.rnn(recurrentInput,states)
            distribution = self.tokenPrediction(output).view(-1)
            distribution = F.log_softmax(distribution).data.exp()
            draw = torch.multinomial(distribution,1)[0]
            c = LEXICON[draw]
            if len(result) > 20 or c == "END":
                return result[1:]
            else:
                result.append(c)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号