train.py source code

python

Project: autotrump  Author: Rabrg
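# Imports this method relies on. read_file, RNN, n_characters and Generator
# are defined elsewhere in the project and are not shown in this excerpt.
import os

import torch
import torch.nn as nn

def __init__(self, training_file='../res/trump_tweets.txt', model_file='../res/model.pt', n_epochs=1000000,
             hidden_size=256, n_layers=2, learning_rate=0.001, chunk_len=140):
    # Store file locations and hyperparameters.
    self.training_file = training_file
    self.model_file = model_file
    self.n_epochs = n_epochs
    self.hidden_size = hidden_size
    self.n_layers = n_layers
    self.learning_rate = learning_rate
    self.chunk_len = chunk_len
    # Load the training corpus once, together with its length.
    self.file, self.file_len = read_file(training_file)
    # Resume from a saved model if one exists; otherwise build a new network.
    if os.path.isfile(model_file):
        # The file holds a whole pickled model object, so only load files you trust.
        # PyTorch 2.6 and later default to weights_only=True, so this call may
        # need weights_only=False there.
        self.decoder = torch.load(model_file)
        print('Loaded old model!')
    else:
        # Input and output sizes are both n_characters (character in, character out).
        self.decoder = RNN(n_characters, hidden_size, n_characters, n_layers)
        print('Constructed new model!')
    self.decoder_optimizer = torch.optim.Adam(self.decoder.parameters(), lr=learning_rate)
    self.criterion = nn.CrossEntropyLoss()
    # Wrap the model for text generation.
    self.generator = Generator(self.decoder)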
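
The constructor calls read_file(training_file) and unpacks two values, but the helper itself is not shown in this excerpt. The sketch below is only a guess at its contract (full text plus character count), not the project's actual code. The random_chunk helper is hypothetical as well: it illustrates one common way a chunk_len setting such as 140 might be used to sample training windows, with one extra character so that inputs and targets can be offset by one position.

```python
import random


def read_file(path):
    """Assumed contract: return the file's text and its length in characters."""
    with open(path, encoding='utf-8') as f:
        text = f.read()
    return text, len(text)


def random_chunk(text, chunk_len=140):
    """Hypothetical helper: pick a random window of chunk_len + 1 characters."""
    if len(text) <= chunk_len:
        raise ValueError('text must be longer than chunk_len')
    # randint is inclusive at both ends, so the window never runs past the text.
    start = random.randint(0, len(text) - chunk_len - 1)
    return text[start:start + chunk_len + 1]
```

With a window of chunk_len + 1 characters, the first chunk_len characters can serve as the input sequence and the last chunk_len as the targets, which is the usual setup when a character-level model is trained with a cross-entropy loss.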