train.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:char-rnn 作者: hiepph 项目源码 文件源码
def generate(self, prime_str, predict_len=100, temperature=0.8):
        """Sample ``predict_len`` characters from the decoder, seeded by ``prime_str``.

        The prime string is fed through ``self.decoder`` one character at a
        time to build up the hidden state. Characters are then sampled one by
        one, and each sampled character is fed back as the next input.

        Args:
            prime_str: Non-empty seed text. Each character must be encodable
                by ``char_tensor`` (defined elsewhere in this module).
            predict_len: Number of characters to sample after the prime.
            temperature: Positive divisor applied to the decoder output
                before exponentiation. Values below 1 sharpen the sampling
                distribution; values above 1 flatten it.

        Returns:
            ``prime_str`` followed by the sampled characters.

        Raises:
            ValueError: If ``prime_str`` is empty or ``temperature`` is not
                positive.
        """
        if not prime_str:
            # The last prime character seeds the sampling loop, so one must exist.
            raise ValueError("prime_str must be a non-empty string")
        if temperature <= 0:
            # div(0) or a negative divisor would yield inf/nan or inverted
            # weights, and torch.multinomial would fail or misbehave.
            raise ValueError("temperature must be positive")

        # NOTE(review): on PyTorch >= 0.4, wrapping the work below in
        # torch.no_grad() would stop the autograd graph from growing with
        # every step through the recurrent hidden state. Confirm the
        # project's PyTorch version before adding it.
        hidden = self.decoder.init_hidden()
        prime_input = char_tensor(prime_str, self.decoder.gpu)

        # Use all but the last prime character to build up hidden state;
        # the last one becomes the first input of the sampling loop.
        for p in range(len(prime_str) - 1):
            _, hidden = self.decoder(prime_input[p], hidden)

        pieces = [prime_str]
        inp = prime_input[-1]
        for _ in range(predict_len):
            out, hidden = self.decoder(inp, hidden)

            # Treat the flattened output as unnormalized log-weights
            # (presumably logits; confirm against the decoder) and sample
            # one index. Subtracting the max before exp() prevents overflow
            # to inf without changing the relative weights multinomial uses.
            scores = out.data.view(-1).div(temperature)
            weights = (scores - scores.max()).exp()
            # int() works whether multinomial indexing yields a Python int
            # (old PyTorch) or a 0-dim tensor (newer PyTorch).
            top_i = int(torch.multinomial(weights, 1)[0])

            # Add the predicted character to the output and use it as next input.
            predicted_char = all_characters[top_i]
            pieces.append(predicted_char)
            inp = char_tensor(predicted_char, self.decoder.gpu)

        return "".join(pieces)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号