def detect_nan(i, node, fn):
'''
x = theano.tensor.dscalar('x')
f = theano.function([x], [theano.tensor.log(x) * x],
mode=theano.compile.MonitorMode(post_func=detect_nan))
'''
nan_detected = False
for output in fn.outputs:
if np.isnan(output[0]).any():
nan_detected = True
np.set_printoptions(threshold=np.nan) # Print the whole arrays
print '*** NaN detected ***'
print '--------------------------NODE DESCRIPTION:'
theano.printing.debugprint(node)
print '--------------------------Variables:'
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
break
if nan_detected:
exit()
# Scraped web-page navigation labels (not code):
# 评论列表 = "Comment list", 文章目录 = "Table of contents"