def adjust_length(pred_line, lineN, max_length):
    """
    Trim trailing zeros from a prediction line and reconcile its length with the
    length already recorded for the same line number.

    Handles predictions for the same example having different lengths, which may
    happen when a different batch size is used for each model. Normally it
    shouldn't be needed.

    Behaviour:

    * The first time a line number is seen, its trimmed length is appended to
      the length table.
    * If the trimmed line is shorter than the recorded length, it is padded
      with trailing zeros back up to that length.
    * If it is longer, a warning is printed and the recorded length is raised
      to the new length.

    :param pred_line: 1-D sequence of predictions for one example.
    :param lineN: 1-based line number of the example. It is expected to be at
        most ``len(max_length) + 1``, i.e. lines are registered in order.
    :param max_length: sequence whose entry ``lineN - 1`` holds the length
        recorded so far for line ``lineN``. The object passed in is not
        modified; use the returned table for the next call.
    :return: tuple ``(pred_line, max_length)`` with the adjusted line and the
        updated length table.
    :raises IndexError: if ``lineN`` is more than one past the end of
        ``max_length``.
    """
    pred_line = numpy.trim_zeros(pred_line, trim='b')
    current_len = len(pred_line)

    # First time this line number is seen: record its trimmed length.
    # numpy.append returns a new array, so the caller's object is untouched.
    if lineN > len(max_length):
        max_length = numpy.append(max_length, current_len)

    # The table may hold floats (numpy.append on an empty list yields float64),
    # so cast before using the value as a size.
    recorded_len = int(max_length[lineN - 1])

    if current_len < recorded_len:
        # The following takes care of lines that are shorter than the ones for
        # previous files due to 0-trimming: restore the missing tail zeros in
        # a single step instead of appending them one at a time.
        padding = numpy.zeros(recorded_len - current_len, dtype=int)
        pred_line = numpy.append(pred_line, padding)
    elif current_len > recorded_len:
        print('!!! Warning: Line ' + str(lineN) + ' is longer than the corresponding lines of previous files.')
        # Copy before writing so the caller's table is never mutated in place.
        max_length = numpy.array(max_length)
        max_length[lineN - 1] = current_len

    return pred_line, max_length
评论列表
文章目录