def sample(self, features):
    result = ["START"]
    # flatten the image features to shape (1, 1, F), e.g. 1x1x2560
    features = features.view(-1).unsqueeze(0).unsqueeze(0)
    states = None
    while True:
        # embed the most recently generated token as a (1, 1) index tensor
        e = self.embedding(variable([symbolToIndex[result[-1]]]).view((1, -1)))
        # condition the RNN on the image features concatenated with the token embedding
        recurrentInput = torch.cat((features, e), 2)
        output, states = self.rnn(recurrentInput, states)
        # project the hidden state onto the lexicon and sample the next token
        distribution = self.tokenPrediction(output).view(-1)
        distribution = F.log_softmax(distribution, dim=0).data.exp()
        draw = torch.multinomial(distribution, 1).item()
        c = LEXICON[draw]
        if len(result) > 20 or c == "END":
            # drop the leading "START" token; "END" itself is never appended
            return result[1:]
        else:
            result.append(c)
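
# --- A minimal sketch (not from the original post) of the pieces that sample()
# relies on: the LEXICON / symbolToIndex tables, the variable() helper, and a
# decoder module exposing embedding, rnn, and tokenPrediction. All sizes and
# token names below are assumptions for illustration, apart from the 2560-dim
# feature vector noted in the comment above.
import torch
import torch.nn as nn
import torch.nn.functional as F

LEXICON = ["START", "END", "circle", "rectangle", "line"]   # hypothetical token set
symbolToIndex = {s: i for i, s in enumerate(LEXICON)}

def variable(x):
    # thin helper assumed to wrap raw indices into a long tensor
    return torch.tensor(x, dtype=torch.long)

class Decoder(nn.Module):
    def __init__(self, featureSize=2560, embeddingSize=64, hiddenSize=256):
        super().__init__()
        self.embedding = nn.Embedding(len(LEXICON), embeddingSize)
        # the RNN consumes the image features concatenated with the token embedding
        self.rnn = nn.GRU(featureSize + embeddingSize, hiddenSize)
        self.tokenPrediction = nn.Linear(hiddenSize, len(LEXICON))

    # sample() as defined above would be a method of this class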