def SGD(x, default_rating=5):
    """Run SGD matrix-factorization updates over every block in ``x``.

    Each element of ``x`` is read as ``(row_block_id, (v_iter, w_iter,
    h_iter))`` (the block id itself is not used):

    * ``v_iter`` yields observed entries as ``(i, j)`` index pairs, or as
      ``(i, j, rating)`` triples;
    * ``w_iter`` yields ``(i, w_i)`` pairs (rows of W);
    * ``h_iter`` yields ``(j, h_j)`` pairs (columns of H).

    For each observed entry the error is ``rating - dot(w_i, h_j)``, where
    ``rating`` is ``default_rating`` for bare ``(i, j)`` pairs. Then
    ``w_i`` and ``h_j`` both take one gradient step on the L2-regularised
    squared error, computed from their values *before* the step. The step
    size is ``step_size.value`` and the regularisation weight is
    ``reg.value``.

    Module-level state used:
      * ``step_size``, ``reg``: objects exposing ``.value``. These are
        presumably Spark broadcast variables; confirm against the driver.
      * ``mse``: incremented by every squared error.
      * ``n_updates_acc``: incremented once per processed entry.
      * ``mse`` and ``n_updates_acc`` are presumably Spark accumulators.

    Args:
        x: iterable of blocks shaped as described above.
        default_rating: target value for entries given as ``(i, j)`` with
            no explicit rating. Defaults to 5.

    Returns:
        A tuple of ``(key, value)`` items. It contains
        ``(('W', i), (i, w_i))`` for every row of W and
        ``(('H', j), (j, h_j))`` for every column of H seen in any block.
        If a key appears in several blocks, the last value wins. An empty
        ``x`` gives ``()``.

    Raises:
        KeyError: an entry references an index absent from W or H.
        ValueError: an entry of ``v_iter`` is neither a pair nor a triple.
    """
    global n_updates_acc
    global mse
    # Loop-invariant hyper-parameters: read once, not once per rating.
    lr = step_size.value
    lam = reg.value
    output = {}
    for val in x:
        v_iter, w_iter, h_iter = val[1][0], val[1][1], val[1][2]
        # Index -> vector maps. The vectors are updated in place below, so
        # these dicts end up holding the updated factors.
        w = {xw[0]: xw[1] for xw in w_iter}
        h = {xh[0]: xh[1] for xh in h_iter}
        for v_ij in v_iter:
            if len(v_ij) == 3:
                i, j, rating = v_ij
            else:
                # Anything other than a pair raises ValueError here.
                i, j = v_ij
                rating = default_rating
            w_i = w[i]
            h_j = h[j]
            error = rating - np.dot(w_i, h_j)
            mse += error ** 2
            # Compute BOTH gradient steps from the current factors before
            # mutating either one. This way the W step uses the pre-update
            # h_j, giving a true simultaneous SGD step.
            h_update = lr * (-2 * error * w_i + 2.0 * lam * h_j)
            w_update = lr * (-2 * error * h_j + 2.0 * lam * w_i)
            # NOTE(review): earlier code also built
            # ma.masked_array(h_update, mask.value) here but never applied
            # it. If some H components are meant to stay fixed, apply that
            # mask to h_update explicitly.
            # In-place so the vectors held in w/h (and returned) change.
            # Assumes float numpy arrays; TODO confirm.
            h_j -= h_update
            w_i -= w_update
            n_updates_acc += 1
        # Collect this block's factors. Accumulating across blocks means no
        # block in the partition is silently dropped.
        for row_index, w_row in w.items():
            output[('W', row_index)] = (row_index, w_row)
        for col_index, h_col in h.items():
            output[('H', col_index)] = (col_index, h_col)
    return tuple(output.items())
MFSideData.py 文件源码
python
阅读 19
收藏 0
点赞 0
评论 0
评论列表
文章目录