import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# FLAGS (command-line flags) and the tsgl helper module are assumed to be
# defined elsewhere in mnist_tsgl.py.


def train(group):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir,
                                      one_hot=True,
                                      fake_data=FLAGS.fake_data)
    tr_data, tr_label = mnist.train.next_batch(mnist.train.num_examples)

    # Dictionary initialization (M features per image, FLAGS.P atoms)
    M = len(tr_data[0])
    D = tsgl.dict_initializer(M, FLAGS.P)

    # Learning
    lr = FLAGS.learning_rate
    pre_mse = 10  # previous-iteration MSE (arbitrary initial value)

    for i in range(1, FLAGS.max_steps + 1):
        # Shuffle the data and draw a mini-batch (columns = samples)
        idx = np.random.permutation(len(tr_data))
        batch = tr_data[idx[:FLAGS.batch_num]].transpose()

        # Learning rate decay
        if i % FLAGS.decay_num == 0:
            lr = lr / float(FLAGS.decay_rate)

        # Sparse coding: compute coefficients A for the current dictionary
        A = tsgl.sparse_coding(D, batch, FLAGS, group)
        print(A[:, 0])
        print(A[:, 1])

        # Dictionary learning: update D given the sparse codes
        D = tsgl.dictionary_learning(D, batch, A, lr, FLAGS)

        # Mean reconstruction error over the batch
        loss = np.linalg.norm(np.matmul(D, A) - batch, axis=0)
        mse = np.mean(loss)
        print(str(i) + "th MSE: " + str(mse))

        # Stop early once the change in MSE falls below the threshold
        mse_diff = abs(mse - pre_mse)
        if mse_diff < FLAGS.mse_diff_threshold:
            print("Learning Done")
            return
        pre_mse = mse

    print("Max Iterations Done")
mnist_tsgl.py source code (Python)