calc_vector.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:LSTMVAE 作者: ashwatthaman 项目源码 文件源码
def weighted_cross_entropy(p,t,weight_arr,sec_arr,weigh_flag=True):
    print("p:{}".format(p.data.shape))
    b = np.zeros(p.shape,dtype=np.float32)
    b[np.arange(p.shape[0]), t] = 1
    soft_arr = F.softmax(p)
    log_arr = -F.log(soft_arr)
    xent = b*log_arr

    #
    # print("sec_arr:{}".format(sec_arr))
    # print("xent_shape:{}".format(xent.data.shape))
    xent = F.split_axis(xent,sec_arr,axis=0)
    print([xent_e.data.shape[0] for xent_e in xent])
    x_sum = [F.reshape(F.sum(xent_e)/xent_e.data.shape[0],(1,1)) for xent_e in xent]
    # print("x_sum:{}".format([x_e.data for x_e in x_sum]))
    xent = F.concat(x_sum,axis=0)
    #
    # print("xent1:{}".format(xent.data))
    xent = F.max(xent,axis=1)/p.shape[0]
    # print("xent2:{}".format(xent.data))
    if not weigh_flag:
        return F.sum(xent)
    # print("wei_arr:{}".format(weight_arr))
    # print("wei_arr:{}".format(weight_arr.data.shape))

    print("xent3:{}".format(xent.data.shape))
    wxent= F.matmul(weight_arr,xent,transa=True)
    wxent = F.sum(F.sum(wxent,axis=0),axis=0)
    print("wxent:{}".format(wxent.data))
    return wxent
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号