pruneVGG.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:nn-compression 作者: anithapk 项目源码 文件源码
def fineTuneNet(X_train,y_train,BATCH_SIZE,images,y,LR,RC,train_mode,learnRate,regConst,mask,layList,KP,keepProb,indMask):
    """Run one epoch of masked fine-tuning over the training set.

    The training data is shuffled and split into mini-batches of
    ``BATCH_SIZE``. For every batch, one or more gradient-application ops
    are run in the default TensorFlow session, each fed with its
    corresponding mask.

    Which ops run depends on ``len(mask)``:

    * 1 -> ``applygrad0`` (fed ``Mask0``), using an adjusted keep probability;
    * 2 -> ``applygrad1`` then ``applygrad0``;
    * 4 -> ``applygrad3``, ``applygrad2``, ``applygrad1``, ``applygrad0``.

    Any other length prints ``"wrong number of layers passed"`` and returns
    before any data is shuffled or loaded. ``applygradN`` / ``MaskN`` are
    module-level names not visible in this block; presumably they are the
    training ops and mask placeholders built with the graph (confirm).

    Args:
        X_train: Training inputs. Each slice is passed to ``utils.load_image``
            (presumably image paths; confirm against the caller).
        y_train: Training labels, aligned with ``X_train``.
        BATCH_SIZE: Number of samples per mini-batch.
        images, y, LR, RC, KP, train_mode: Graph placeholders fed with,
            respectively, the batch images, batch labels, learning rate,
            regularisation constant, dropout keep probability and ``True``.
        learnRate: Value fed to ``LR``.
        regConst: Value fed to ``RC``.
        mask: List of mask arrays; ``mask[i]`` is fed to ``Mask<i>``.
        layList: Unused; kept for interface compatibility.
        keepProb: Base dropout keep probability fed to ``KP``.
        indMask: Per-layer index data. Only ``indMask[0][0]`` is read, and its
            length is treated as the number of removed entries of ``mask[0]``
            (assumption; verify against the pruning code).

    Returns:
        None.
    """
    nLayers = len(mask)
    print(["Function fineTune",nLayers])

    # Resolve which ops to run once, before touching the data. Each entry is
    # (gradient op, mask placeholder, index into `mask`). The order matches
    # the original: highest layer index first.
    if nLayers == 1:
        steps = [(applygrad0, Mask0, 0)]
    elif nLayers == 2:
        steps = [(applygrad1, Mask1, 1), (applygrad0, Mask0, 0)]
    elif nLayers == 4:
        steps = [(applygrad3, Mask3, 3), (applygrad2, Mask2, 2),
                 (applygrad1, Mask1, 1), (applygrad0, Mask0, 0)]
    else:
        # Fail fast. Previously this was detected only inside the batch loop,
        # after shuffling the data and loading the first batch of images.
        print("wrong number of layers passed")
        return

    # The keep probability depends only on mask/indMask, so compute it once
    # instead of once per batch.
    if nLayers == 1:
        # The dropout adjustment is applied only in the single-layer case
        # (described originally as the fully connected case). keepProb is
        # scaled by sqrt of the fraction of mask[0] entries not counted in
        # indMask[0][0].
        total = float(np.prod(mask[0].shape))
        rat = (total - len(indMask[0][0])) / total
        kp_value = keepProb * np.sqrt(rat)
    else:
        kp_value = keepProb

    sess = tf.get_default_session()
    n_train = len(y_train)
    X_train, y_train = shuffle(X_train, y_train)
    for offset in range(0, n_train, BATCH_SIZE):
        end = offset + BATCH_SIZE
        batch_x = utils.load_image(X_train[offset:end])
        batch_y = y_train[offset:end]
        base_feed = {images: batch_x, y: batch_y, LR: learnRate,
                     RC: regConst, KP: kp_value, train_mode: True}
        for grad_op, mask_ph, idx in steps:
            # Fresh dict per run, so each op is fed exactly one mask
            # placeholder, as in the original.
            feed = dict(base_feed)
            feed[mask_ph] = mask[idx]
            sess.run(grad_op, feed_dict=feed)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号