def buildmodelparams(self, hyper, checkpointdir=None):
    """Build a fresh model from ``hyper`` and optionally checkpoint it.

    The model class registered under ``self.modeltype`` in
    ``self.modeltypes`` is instantiated with ``hyper`` and stored on
    ``self.model``. ``hyper`` presumably carries the charset size along
    with the other hyperparameters (TODO confirm against the model classes).

    When ``checkpointdir`` is truthy, the model's training functions are
    compiled, a rough baseline loss is computed on ``self.valid`` and an
    initial checkpoint is saved to ``checkpointdir`` via
    ``self.newcheckpoint``. Progress is reported on ``stderr``.

    Args:
        hyper: Hyperparameters passed unchanged to the model constructor.
        checkpointdir: Where to save the initial checkpoint. When falsy
            (the default), only the model object is built: no training
            functions are compiled and no checkpoint is taken.
    """
    useclass = self.modeltypes[self.modeltype]
    self.model = useclass(hyper)
    if not checkpointdir:
        return

    # The baseline only needs to be rough, so it is kept short.
    # NOTE(review): the available count is requested for batches of 16, but
    # calc_loss below runs with batchsize=8. If batchepoch(n) returns the
    # number of size-n batches, the cap is computed for a different batch
    # size than the one used. Confirm whether these two should match.
    epoch_batchsize = 16
    loss_batchsize = 8
    max_loss_examples = 20

    # Compile training functions (only needed when a checkpoint is taken).
    self.model._build_t()

    stderr.write("Calculating initial loss estimate...\n")
    data_len = self.valid.batchepoch(epoch_batchsize)
    # Never ask for more than the validation data provides.
    loss_len = min(max_loss_examples, data_len)
    loss = self.model.calc_loss(
        self.valid, 0, batchsize=loss_batchsize, num_examples=loss_len
    )
    stderr.write("Initial loss: {0:.3f}\n".format(loss))
    stderr.write("Initial log loss: {0:.3f}\n".format(log(loss)))

    # Persist the freshly built model together with its baseline loss.
    self.newcheckpoint(loss, savedir=checkpointdir)
评论列表
文章目录