def write_predictions(self, inputs):
    '''
    Outputs predictions in a file named <model_name_prefix>.predictions.

    One adjusted prediction is written per line, in the same order as
    ``inputs``. Each raw prediction is the argmax (along axis 1) of
    ``self.model.predict(inputs)``; it is then shifted left by the number of
    leading all-zero entries in the corresponding input row and made 1-based.

    Args:
        inputs: Padded model inputs. They are passed unchanged to
            ``self.model.predict`` and also iterated row by row to count
            leading padding. (Presumably a numpy array of word indices
            where 0 is the padding value -- confirm against the caller.)

    Returns:
        None. The only effect is writing the predictions file.
    '''
    predictions = numpy.argmax(self.model.predict(inputs), axis=1)
    # Use a context manager so the file is flushed and closed even if a
    # write fails part-way through.
    with open("%s.predictions" % self.model_name_prefix, "w") as test_output_file:
        for input_indices, prediction in zip(inputs, predictions):
            # The predictions are indices of words in padded sentences. We need to readjust them.
            padding_length = 0
            for index in input_indices:
                # numpy.all lets an entry be either a scalar or an array;
                # only leading all-zero entries count as padding.
                if not numpy.all(index == 0):
                    break
                padding_length += 1
            prediction = prediction - padding_length + 1  # +1 because the indices start at 1.
            # file.write produces the same output as the Python 2
            # ``print >>file`` statement and also works on Python 3.
            test_output_file.write("%s\n" % prediction)
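# (Scraped page navigation label, not part of the code: "Comment list")
# (Scraped page navigation label, not part of the code: "Article table of contents")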