def myloss(random_int, type):  # noqa: A002 -- `type` name kept for caller compatibility
    """Build a Keras loss whose contribution is gated by a random integer tensor.

    ``random_int[0]`` is compared against the module-level thresholds
    ``rand_threshold[1]`` and ``rand_threshold[0]`` (tested in that order).
    This splits its range into three bands:

    * high: ``random_int[0] > rand_threshold[1]``
    * mid:  ``random_int[0] > rand_threshold[0]`` and not high
    * low:  everything else

    Exactly one band yields the real loss
    ``K.mean(K.square(y_pred - y_true), axis=-1)``. The other bands yield
    ``0 * K.mean(K.square(y_true), axis=-1)``, a zero of the same shape.

    * ``type == 'cls'``  -> real loss in the high band
    * ``type == 'roi'``  -> real loss in the mid band
    * any other ``type`` -> real loss in the low band (labelled 'pts')

    NOTE(review): the bands are presumably meant to be disjoint with
    ``rand_threshold[0] < rand_threshold[1]`` -- confirm where
    ``rand_threshold`` is defined. The caller must also make sure ``type``
    is one of the expected strings; unknown values silently behave like 'pts'.

    Args:
        random_int: tensor indexed with ``[0]``; presumably a per-step random
            draw used to decide which head is trained -- TODO confirm.
        type: ``'cls'``, ``'roi'``, or anything else (treated as 'pts').

    Returns:
        A ``lossfun(y_true, y_pred)`` closure usable as a Keras loss.
    """
    if print_progress:
        random_int = tf.Print(random_int, ['random in cls', random_int])
    condition1 = random_int[0] > tf.constant(rand_threshold[1])
    condition0 = random_int[0] > tf.constant(rand_threshold[0])

    # Which band produces the real loss for this head, and the debug label.
    if type == 'cls':
        active_band, label = 'high', 'cls'
    elif type == 'roi':
        active_band, label = 'roi' and 'mid', 'roi'
    else:
        active_band, label = 'low', 'pts'

    def lossfun(y_true, y_pred):
        def real_loss():
            return K.mean(K.square(y_pred - y_true), axis=-1)

        def zero_loss():
            # Multiplying by 0 keeps shape/dtype identical to the real loss.
            return 0 * K.mean(K.square(y_true), axis=-1)

        branch = {'high': zero_loss, 'mid': zero_loss, 'low': zero_loss}
        branch[active_band] = real_loss

        cond_high = condition1
        if print_progress:
            # tf.Print is an identity op that only logs when its output is
            # consumed, so route the predicate through it rather than
            # discarding the result.
            cond_high = tf.Print(cond_high, ['rand int[0]:', random_int[0],
                                             ' tf.constant:', tf.constant(rand_threshold),
                                             ' condition:', condition1])

        # An ordered list (not a dict) guarantees condition1 is tested before
        # condition0. With a dict and exclusive=False, graph-mode TF orders
        # predicates by tensor name, so the precedence held only by accident.
        val = tf.case([(cond_high, branch['high']),
                       (condition0, branch['mid'])],
                      default=branch['low'],
                      exclusive=False)
        # tf.case loses the static shape; restore it before any further ops.
        val.set_shape(K.mean(K.square(y_true), axis=-1).shape)
        if print_progress:
            val = tf.Print(val, [label + ' loss out:', val,
                                 ' rand int received:', random_int,
                                 'condition', condition1])
        return val

    return lossfun
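# "Comment list" (评论列表) -- appears to be leftover navigation text from the web page this code was copied from
# "Table of contents" (文章目录) -- appears to be leftover navigation text from the same page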