def _parse_pair(param, key, default=None):
    """Return the 2-tuple stored as a string (e.g. "(3, 3)") under ``param[key]``.

    When ``default`` is given it is returned if the key is absent or the
    stored value parses to an empty tuple. Without a default, a missing key
    raises KeyError just like a plain ``param[key]`` lookup.
    """
    import ast

    if key not in param and default is not None:
        return default
    # ast.literal_eval accepts only Python literals, so unlike eval() a
    # crafted model JSON cannot execute code while its attributes are parsed.
    value = tuple(ast.literal_eval(param[key]))
    if not value and default is not None:
        return default
    return value


def _vh_factorize(weight, rank):
    """Split a conv kernel into a vertical and a horizontal kernel via truncated SVD.

    Args:
        weight: numpy array unpacked as (N, C, y, x) = (out channels,
            in channels, kernel height, kernel width).
        rank: number K of intermediate channels kept from the SVD.

    Returns:
        ``(V, H)`` with V of shape (K, C, y, 1) and H of shape (N, K, 1, x).
        ``sum_k V[k, c, i, 0] * H[n, k, 0, j]`` is the best rank-K
        (Frobenius-norm) approximation of ``weight[n, c, i, j]``, and it is
        exact once K reaches the rank of the rearranged kernel matrix.

    Raises:
        ValueError: if ``rank`` is not in ``1..min(C*y, N*x)``.
    """
    N, C, y, x = weight.shape
    # Rows index (in_channel, kernel_row); columns index (out_channel, kernel_col).
    # A vertical conv followed by a horizontal conv computes exactly
    # out = sum_k V[(c, i), k] * H[(n, j), k], i.e. a rank-K factorization of it.
    mat = weight.transpose((1, 2, 0, 3)).reshape((C * y, N * x))
    max_rank = min(mat.shape)
    if not 1 <= rank <= max_rank:
        raise ValueError(
            'K must be between 1 and %d for a kernel of shape %s, got %r'
            % (max_rank, weight.shape, rank))
    U, D, Q = np.linalg.svd(mat, full_matrices=False)
    # sqrt of the diagonal singular-value matrix is just the elementwise sqrt;
    # it is split evenly between the two factors.
    sqrt_d = np.sqrt(D[:rank])
    V = U[:, :rank] * sqrt_d          # (C*y, K)
    H = Q[:rank, :].T * sqrt_d        # (N*x, K)
    V = V.T.reshape(rank, C, y, 1)
    H = H.reshape(N, x, 1, rank).transpose((0, 3, 2, 1))
    return V, H


def conv_vh_decomposition(model, args):
    """Replace conv layer ``args.layer`` by a vertical then a horizontal convolution.

    The kernel is factorized with a rank-``args.K`` truncated SVD. The first
    stage is a ``ky x 1`` convolution named ``<layer>_v`` with ``K`` filters and
    zero bias. The second stage is a ``1 x kx`` convolution named ``<layer>_h``
    with the original number of filters, and it carries the original bias.

    Args:
        model: object whose ``arg_params`` maps parameter names to arrays that
            expose ``.asnumpy()``; it is passed on to ``utils.replace_conv_layer``.
        args: namespace with ``layer`` (name of the conv to replace) and ``K``
            (number of intermediate channels).

    Returns:
        The model returned by ``utils.replace_conv_layer``.

    Raises:
        ValueError: if ``args.K`` is not in ``1..min(C*ky, N*kx)``.
    """
    weight = model.arg_params[args.layer + '_weight'].asnumpy()
    num_out = weight.shape[0]
    bias_key = args.layer + '_bias'
    if bias_key in model.arg_params:
        bias = model.arg_params[bias_key].asnumpy()
    else:
        # The source conv has no bias. The replacement convolutions do have
        # bias parameters, so a zero bias keeps the output unchanged.
        bias = np.zeros((num_out,))

    w_v, w_h = _vh_factorize(weight, args.K)
    b_v = np.zeros((args.K,))  # the intermediate (vertical) stage adds no bias
    b_h = bias

    def sym_handle(data, node):
        """Build the ``<name>_v`` -> ``<name>_h`` convolution pair replacing ``node``."""
        param = node['param']
        kernel = _parse_pair(param, 'kernel')
        pad = _parse_pair(param, 'pad', (0, 0))
        stride = _parse_pair(param, 'stride', (1, 1))
        name = node['name']
        # Each stage takes the vertical or horizontal half of kernel, pad and
        # stride. Dropping the stride would change the output spatial size.
        v_opts = dict(kernel=(kernel[0], 1), pad=(pad[0], 0), stride=(stride[0], 1))
        h_opts = dict(kernel=(1, kernel[1]), pad=(0, pad[1]), stride=(1, stride[1]))
        if 'dilate' in param:
            # Forwarded only when the node records it, so nothing extra is
            # passed to Convolution for models that never set a dilation.
            dilate = _parse_pair(param, 'dilate', (1, 1))
            v_opts['dilate'] = (dilate[0], 1)
            h_opts['dilate'] = (1, dilate[1])
        sym_v = mx.symbol.Convolution(
            data=data, num_filter=w_v.shape[0], name=name + '_v', **v_opts)
        return mx.symbol.Convolution(
            data=sym_v, num_filter=w_h.shape[0], name=name + '_h', **h_opts)

    def arg_handle(arg_shape_dic, arg_params):
        """Store the factorized parameters in ``arg_params``.

        Every tensor is checked against the shapes in ``arg_shape_dic`` before
        any of them is written, so a mismatch leaves ``arg_params`` untouched.
        """
        name_v = args.layer + '_v'
        name_h = args.layer + '_h'
        weight_v = mx.ndarray.array(w_v)
        bias_v = mx.ndarray.array(b_v)
        weight_h = mx.ndarray.array(w_h)
        bias_h = mx.ndarray.array(b_h)
        assert weight_v.shape == arg_shape_dic[name_v + '_weight'], 'weight1'
        assert weight_h.shape == arg_shape_dic[name_h + '_weight'], 'weight2'
        assert bias_v.shape == arg_shape_dic[name_v + '_bias'], 'bias1'
        assert bias_h.shape == arg_shape_dic[name_h + '_bias'], 'bias2'

        arg_params[name_v + '_weight'] = weight_v
        arg_params[name_v + '_bias'] = bias_v
        arg_params[name_h + '_weight'] = weight_h
        arg_params[name_h + '_bias'] = bias_h

    return utils.replace_conv_layer(args.layer, model, sym_handle, arg_handle)
评论列表
文章目录