generate.py 文件源码

python
阅读 52 收藏 0 点赞 0 评论 0

项目:Tree-LSTM-LM 作者: vgene 项目源码 文件源码
def generate(model, prime_str='A', predict_len=100, temperature=0.8, cuda=False):
    """Sample ``predict_len`` characters from ``model`` after priming it with ``prime_str``.

    The priming string is fed through the model one character at a time to
    build up the hidden state. After that, each sampled character is fed back
    in as the next input.

    Args:
        model: Character-level model. It must provide ``init_hidden(batch)``,
            be callable as ``model(input, hidden) -> (output, hidden)``, and
            expose ``mapping``. ``mapping`` is passed to ``char_tensor`` and
            is also indexed by an integer to recover a character. Its
            ``seq_length`` attribute is overwritten with 1.
        prime_str: Non-empty string used to warm up the hidden state. It is
            also the prefix of the returned text.
        predict_len: Number of characters to sample.
        temperature: Positive sampling temperature. The model output is
            divided by it before exponentiation, so lower values make
            sampling greedier and higher values make it more random.
        cuda: If True, move the hidden state and every input to the GPU.

    Returns:
        ``prime_str`` followed by ``predict_len`` sampled characters.

    Raises:
        ValueError: If ``prime_str`` is empty or ``temperature`` is not positive.
    """
    if not prime_str:
        raise ValueError("prime_str must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    hidden = model.init_hidden(1)
    tensor = char_tensor(prime_str, model.mapping)
    prime_input = Variable(tensor.unsqueeze(0))
    if cuda:
        # NOTE(review): assumes hidden is an iterable of tensors
        # (e.g. an LSTM (h, c) tuple) -- confirm against init_hidden.
        hidden = tuple(h.cuda() for h in hidden)
        prime_input = prime_input.cuda()
    # NOTE(review): presumably the model uses seq_length to shape its input.
    # We feed a single character per step -- confirm against the model.
    model.seq_length = 1

    # Feed all but the last priming character to build up the hidden state.
    # The last character becomes the first input of the sampling loop.
    for p in range(len(prime_str) - 1):
        _, hidden = model(prime_input[:, p], hidden)

    inp = prime_input[:, -1]
    pieces = [prime_str]  # collected here and joined once, avoiding O(n^2) +=

    for _ in range(predict_len):
        output, hidden = model(inp, hidden)

        # Treat the output as unnormalised log-weights scaled by temperature.
        # Subtracting the max before exp() prevents overflow to inf. The
        # resulting distribution is unchanged because torch.multinomial
        # normalises its weights.
        scores = output.data.view(-1).div(temperature)
        weights = (scores - scores.max()).exp()
        # Indexing the multinomial result may yield a Python int (older
        # PyTorch) or a 0-d tensor (newer PyTorch). int() makes it a plain
        # index, which is also safe for dict-based mappings.
        top_i = int(torch.multinomial(weights, 1)[0])

        # Add the predicted character to the output and use it as the next input.
        predicted_char = model.mapping[top_i]
        pieces.append(predicted_char)
        inp = Variable(char_tensor(predicted_char, model.mapping).unsqueeze(0))
        if cuda:
            inp = inp.cuda()

    return ''.join(pieces)

# Run as standalone script
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号