def weighted_cross_entropy(p,t,weight_arr,sec_arr,weigh_flag=True):
print("p:{}".format(p.data.shape))
b = np.zeros(p.shape,dtype=np.float32)
b[np.arange(p.shape[0]), t] = 1
soft_arr = F.softmax(p)
log_arr = -F.log(soft_arr)
xent = b*log_arr
#
# print("sec_arr:{}".format(sec_arr))
# print("xent_shape:{}".format(xent.data.shape))
xent = F.split_axis(xent,sec_arr,axis=0)
print([xent_e.data.shape[0] for xent_e in xent])
x_sum = [F.reshape(F.sum(xent_e)/xent_e.data.shape[0],(1,1)) for xent_e in xent]
# print("x_sum:{}".format([x_e.data for x_e in x_sum]))
xent = F.concat(x_sum,axis=0)
#
# print("xent1:{}".format(xent.data))
xent = F.max(xent,axis=1)/p.shape[0]
# print("xent2:{}".format(xent.data))
if not weigh_flag:
return F.sum(xent)
# print("wei_arr:{}".format(weight_arr))
# print("wei_arr:{}".format(weight_arr.data.shape))
print("xent3:{}".format(xent.data.shape))
wxent= F.matmul(weight_arr,xent,transa=True)
wxent = F.sum(F.sum(wxent,axis=0),axis=0)
print("wxent:{}".format(wxent.data))
return wxent
评论列表
文章目录