def calculate_batchsize_maxlen(texts, *, percentile=80.0, maxlen_threshold=100,
                               short_batch_size=250, long_batch_size=50):
    """ Calculates the maximum length in the provided texts and a suitable
    batch size. Rounds up maxlen to the nearest multiple of ten.

    The "maximum length" is a percentile (80th by default) of the
    tokenized lengths rather than the true maximum, so a few very long
    outliers do not dictate the padding length.

    # Arguments:
        texts: Iterable of inputs; each one is passed to `tokenize`.
        percentile: Percentile (0-100) of the token lengths used as maxlen
            before rounding. Defaults to 80.0.
        maxlen_threshold: Largest (rounded) maxlen that still uses
            `short_batch_size`. Defaults to 100.
        short_batch_size: Batch size used when maxlen <= maxlen_threshold.
            Defaults to 250.
        long_batch_size: Batch size used for longer sequences. Defaults
            to 50.

    # Returns:
        Batch size,
        max length

    # Raises:
        ValueError: If `texts` is empty.
    """
    def roundup(x):
        # Round up to the next multiple of ten; exact multiples are kept.
        return int(math.ceil(x / 10.0)) * 10

    # Calculate max length of sequences considered.
    lengths = [len(tokenize(t)) for t in texts]
    if not lengths:
        # Without this guard np.percentile / math.ceil would fail with an
        # unhelpful error on empty input.
        raise ValueError("texts must contain at least one input")

    maxlen = roundup(np.percentile(lengths, percentile))
    # Adjust batch_size accordingly to prevent GPU overflow: longer
    # sequences get the smaller batch size.
    if maxlen <= maxlen_threshold:
        batch_size = short_batch_size
    else:
        batch_size = long_batch_size
    return batch_size, maxlen
# (web-page residue) 评论列表 = "Comment list"
# (web-page residue) 文章目录 = "Table of contents"