mnist_tsgl.py source code

python

Project: tree-structured-group-lasso  Author: jaesik817
import numpy as np  # used below; input_data, tsgl and FLAGS are defined elsewhere in mnist_tsgl.py


def train(group):
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir,
                                    one_hot=True,
                                    fake_data=FLAGS.fake_data)
  # Fetch the whole training set as a single batch
  tr_data, tr_label = mnist.train.next_batch(mnist.train.num_examples)
  # Dictionary Initialization
  M = len(tr_data[0])  # dimension of one sample
  D = tsgl.dict_initializer(M, FLAGS.P)
  # Learning
  lr = FLAGS.learning_rate
  pre_mse = 10
  for i in range(1,FLAGS.max_steps+1):
    # Data Shuffle
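    # np.random.permutation works on Python 3, where range() cannot be shuffled in place
    idx = np.random.permutation(len(tr_data))
    batch = tr_data[idx[:FLAGS.batch_num]].transpose()  # one sample per column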
    # Learning Rate Decay
    if i % FLAGS.decay_num == 0:
      lr = lr / float(FLAGS.decay_rate)
    # Sparse Coding
    A = tsgl.sparse_coding(D, batch, FLAGS, group)
    # Debug output: codes of the first two samples in the batch
    print(A[:, 0])
    print(A[:, 1])
    # Dictionary Learning
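    D = tsgl.dictionary_learning(D, batch, A, lr, FLAGS)
    # Mean per-sample L2 reconstruction error (reported as "MSE", though it is not squared)
    loss = np.linalg.norm(np.matmul(D, A) - batch, axis=0)
    mse = np.mean(loss)
    print(str(i) + "th MSE: " + str(mse))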
    mse_diff = abs(mse - pre_mse)
    if mse_diff < FLAGS.mse_diff_threshold:
      print("Learning Done")
      return  # converged; the original exit(1) would report a failure status
    pre_mse = mse
  print("Max Iterations Done")
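
The listing above depends on the repository's `tsgl` module, TensorFlow's MNIST loader and a `FLAGS` object, none of which are shown here. As a rough, self-contained illustration of the same loop structure (sample a batch, decay the learning rate, sparse-code, update the dictionary, stop when the error stops changing), the sketch below uses NumPy only. Every helper in it (`group_soft_threshold`, `sparse_coding`, `dictionary_step`, `train_sketch`) and every default value is a hypothetical stand-in, not the author's implementation. It assumes sparse coding is done by proximal gradient descent, with the tree-structured penalty handled by applying group soft-thresholding to the groups in leaves-to-root order.

```python
import numpy as np


def group_soft_threshold(A, groups, lam):
    """Shrink each group of rows of A (column by column) toward zero.

    For a tree-structured hierarchy, passing the groups ordered from leaves
    to root makes this composition the proximal operator of the summed
    group norms (equal weight lam per group is an assumption).
    """
    A = A.copy()
    for g in groups:
        norms = np.linalg.norm(A[g, :], axis=0)  # one norm per column
        scale = np.maximum(0.0, 1.0 - lam / np.maximum(norms, 1e-12))
        A[g, :] *= scale
    return A


def sparse_coding(D, X, groups, lam, n_iter=50):
    """Hypothetical stand-in for tsgl.sparse_coding: proximal gradient on
    0.5 * ||D A - X||^2 + lam * sum_g ||A_g||."""
    L = np.linalg.norm(D, 2) ** 2  # Lipschitz constant of the smooth part
    A = np.zeros((D.shape[1], X.shape[1]))
    for _ in range(n_iter):
        grad = D.T @ (D @ A - X)
        A = group_soft_threshold(A - grad / L, groups, lam / L)
    return A


def dictionary_step(D, X, A, lr):
    """Hypothetical stand-in for tsgl.dictionary_learning: one gradient step,
    then rescale any atom whose norm exceeds 1."""
    grad = (D @ A - X) @ A.T / X.shape[1]
    D = D - lr * grad
    D /= np.maximum(np.linalg.norm(D, axis=0, keepdims=True), 1.0)
    return D


def train_sketch(X_all, P, groups, lam=0.1, lr=0.1, batch_num=32,
                 max_steps=20, decay_num=10, decay_rate=2.0, tol=1e-4, seed=0):
    rng = np.random.default_rng(seed)
    M = X_all.shape[1]
    D = rng.standard_normal((M, P))
    D /= np.linalg.norm(D, axis=0, keepdims=True)  # unit-norm atoms
    prev_err, history = np.inf, []
    for i in range(1, max_steps + 1):
        idx = rng.permutation(len(X_all))[:batch_num]
        batch = X_all[idx].T  # one sample per column
        if i % decay_num == 0:
            lr /= decay_rate
        A = sparse_coding(D, batch, groups, lam)
        D = dictionary_step(D, batch, A, lr)
        err = np.mean(np.linalg.norm(D @ A - batch, axis=0))
        history.append(err)
        if abs(err - prev_err) < tol:
            break  # error has stopped changing
        prev_err = err
    return D, history


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.random((100, 16))  # random data standing in for MNIST
    # A two-level binary tree over 8 atoms, listed leaves first, root last
    groups = [[0, 1], [2, 3], [4, 5], [6, 7],
              [0, 1, 2, 3], [4, 5, 6, 7], list(range(8))]
    D, history = train_sketch(X, P=8, groups=groups)
    print(D.shape, len(history))
```

Returning from the loop (here via `break`) instead of calling `exit` keeps the function usable from other code, and tracking the error in a list makes the stopping rule easy to inspect.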