def compute_perplexity(model, buckets, batchsize=100):
    """Return the mean per-minibatch perplexity for each bucket.

    Each dataset in ``buckets`` is split into minibatches of at most
    ``batchsize`` rows; ``compute_perplexity_batch`` is evaluated on every
    minibatch and the per-bucket average is collected.  Progress is written
    to stdout on a single self-overwriting line.

    Args:
        model: model handle forwarded unchanged to ``compute_perplexity_batch``.
        buckets: sequence of array-like datasets (one per bucket); each must
            support ``len`` and ``np.split`` along axis 0.
        batchsize: maximum number of rows per minibatch.

    Returns:
        list of floats, one mean perplexity per bucket (same order as
        ``buckets``).
    """
    averages = []
    total_buckets = len(buckets)
    for bucket_no, dataset in enumerate(buckets, start=1):
        # Split into minibatches; split points every `batchsize` rows give
        # full-size chunks plus one possibly-short tail chunk.
        if len(dataset) > batchsize:
            split_points = list(range(batchsize, len(dataset), batchsize))
            minibatches = np.split(dataset, split_points, axis=0)
        else:
            minibatches = [dataset]
        batch_ppls = []
        for batch_no, minibatch in enumerate(minibatches, start=1):
            sys.stdout.write("\rcomputing perplexity ... bucket {}/{} (batch {}/{})".format(bucket_no, total_buckets, batch_no, len(minibatches)))
            sys.stdout.flush()
            batch_ppls.append(compute_perplexity_batch(model, minibatch))
        averages.append(sum(batch_ppls) / len(batch_ppls))
        # NOTE(review): source indentation was lost; the progress-line clear is
        # placed inside the bucket loop (reset line after each bucket) — confirm
        # against the original. `stdout.CLEAR` comes from elsewhere in the file.
        sys.stdout.write("\r" + stdout.CLEAR)
        sys.stdout.flush()
    return averages
# (stray page-scrape residue, not code: "评论列表" = comment list,
#  "文章目录" = article table of contents)