combined_merge_predict.py source code

python

Project: video-action-recognition   Author: ap916
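import gc
import pickle
import random

import numpy as np
from keras.layers import Activation, Dense, Merge
from keras.models import Sequential
from keras.optimizers import SGD
from keras.regularizers import l2

# prepareTemporalModel, prepareFeaturesModel, chunks, t_getTrainData,
# f_getTrainData and totalCorrectPred are helpers defined elsewhere in the
# video-action-recognition project. Merge and W_regularizer are Keras 1.x APIs.


def CNN():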
    input_frames = 10
    batch_size = 10      # unused in this prediction script
    nb_classes = 101
    nb_epoch = 10        # unused in this prediction script
    img_rows, img_cols = 150, 150
    img_channels = 2 * input_frames
    chunk_size = 8
    requiredLines = 1000
    total_predictions = 0
    correct_predictions = 0

    print('Loading dictionary...')
    with open('../dataset/temporal_test_data.pickle','rb') as f1:
        temporal_test_data=pickle.load(f1)

    t_model = prepareTemporalModel(img_channels,img_rows,img_cols,nb_classes)
    f_model = prepareFeaturesModel(nb_classes,requiredLines)

    merged_layer = Merge([t_model, f_model], mode='ave')
    model = Sequential()
    model.add(merged_layer)
    model.add(Dense(nb_classes, W_regularizer=l2(0.01)))
    model.add(Activation('softmax'))
    model.load_weights('combined_merge_model.h5')

    print('Compiling model...')
    gc.collect()
    sgd = SGD(lr=0.0001, decay=1e-6, momentum=0.9, nesterov=True,clipnorm=0.1)
    model.compile(loss='hinge',optimizer=sgd, metrics=['accuracy'])

    # dict.keys() returns a view in Python 3, so copy it into a list before shuffling.
    keys = list(temporal_test_data.keys())
    random.shuffle(keys)

    # Run the merged model over the shuffled test keys chunk by chunk and score the predictions.
    for chunk in chunks(keys, chunk_size):

        tX_test, tY_test = t_getTrainData(chunk, nb_classes, img_rows, img_cols)
        fX_test, fY_test = f_getTrainData(chunk, nb_classes, requiredLines)
        if tX_test is not None and fX_test is not None:
            preds = model.predict([tX_test, fX_test])
            print(preds)
            print('-' * 40)
            print(tY_test)

            total_predictions += fX_test.shape[0]
            correct_predictions += totalCorrectPred(preds, tY_test)

            # Record the score the model assigned to each sample's true class.
            comparisons = []
            maximum = np.argmax(tY_test, axis=1)
            for i, j in enumerate(maximum):
                comparisons.append(preds[i][j])
            with open('compare.txt', 'a') as compare_file:
                compare_file.write(str(comparisons))
                compare_file.write('\n\n')
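    if total_predictions > 0:
        # Multiply by 100.0 so Python 2 does not truncate the percentage with integer division.
        print("\nAccuracy (%): " + str(correct_predictions * 100.0 / total_predictions))
    else:
        print("\nNo test chunk produced data, so accuracy could not be computed.")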
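The loop depends on a `chunks` helper that is not shown on this page. A minimal sketch of what it presumably does, assuming it splits the shuffled key list into consecutive groups of at most `chunk_size` keys. The name `chunks` comes from the code above, but this body is a hypothetical stand-in, not the project's implementation:

```python
def chunks(items, size):
    """Yield consecutive slices of `items` with at most `size` elements each.

    Hypothetical stand-in for the project's own chunks() helper.
    """
    if size <= 0:
        raise ValueError("size must be positive")
    for start in range(0, len(items), size):
        yield items[start:start + size]


# Illustrative keys; the real ones come from temporal_test_data.
keys = ['v_a', 'v_b', 'v_c', 'v_d', 'v_e']
for chunk in chunks(keys, 2):
    print(chunk)
```

The last chunk is simply shorter when the number of keys is not a multiple of the chunk size, which matches how the loop above counts samples per chunk via `fX_test.shape[0]`.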
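`totalCorrectPred` is also defined elsewhere in the project. Assuming it counts the samples whose highest-scoring class matches the one-hot label, a hypothetical equivalent (`total_correct_pred` is an illustrative name, and the arrays are made up) together with the floating-point accuracy calculation could look like this:

```python
import numpy as np


def total_correct_pred(preds, labels):
    """Count rows where the predicted class matches the one-hot label.

    Hypothetical stand-in for the project's totalCorrectPred(); assumes
    `preds` holds per-class scores and `labels` holds one-hot rows of the
    same shape.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    return int(np.sum(np.argmax(preds, axis=1) == np.argmax(labels, axis=1)))


preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1]])
labels = np.array([[1, 0, 0],
                   [0, 0, 1],
                   [0, 1, 0]])

correct = total_correct_pred(preds, labels)
# 100.0 keeps the division in floating point under Python 2 as well.
accuracy = 100.0 * correct / preds.shape[0]
print(correct, accuracy)
```

With two integer operands, Python 2's `/` performs floor division, which is why the original `correct_predictions*100/total_predictions` could report a truncated percentage.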
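`Merge([t_model, f_model], mode='ave')` takes the element-wise mean of the two branch outputs, so both branches must produce tensors of the same shape before the `Dense` layer sees them. A NumPy illustration of that averaging step; the function name and sample arrays are made up for illustration, and this does not call Keras:

```python
import numpy as np


def average_merge(branch_outputs):
    """Element-wise mean of same-shaped branch outputs, mirroring what
    Keras 1's Merge(mode='ave') computes for each sample."""
    shapes = {np.shape(out) for out in branch_outputs}
    if len(shapes) != 1:
        raise ValueError("all branch outputs must have the same shape")
    # Stack along a new leading axis, then average across the branches.
    return np.mean(np.stack(branch_outputs, axis=0), axis=0)


# Illustrative outputs for a batch of 2 samples and 2 classes.
temporal_out = np.array([[0.2, 0.8], [0.6, 0.4]])
features_out = np.array([[0.4, 0.6], [0.2, 0.8]])
merged = average_merge([temporal_out, features_out])
print(merged)
```

Averaging (rather than concatenating) keeps the merged width equal to each branch's width, so the following `Dense(nb_classes)` layer sees the same input size regardless of how many branches are merged.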
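The `comparisons` loop picks, for each sample, the score the model assigned to its true class. The same values can be extracted in one step with NumPy integer-array indexing; `true_class_scores` is a hypothetical helper name and the arrays are illustrative:

```python
import numpy as np


def true_class_scores(preds, labels):
    """Return the score each row of `preds` assigns to its true class.

    Vectorised equivalent of the comparisons loop: for each row i,
    take preds[i][argmax(labels[i])].
    """
    preds = np.asarray(preds)
    true_idx = np.argmax(np.asarray(labels), axis=1)
    # Pair row i with column true_idx[i].
    return preds[np.arange(preds.shape[0]), true_idx]


preds = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
labels = np.array([[1, 0, 0],
                   [0, 1, 0]])
print(true_class_scores(preds, labels).tolist())
```

Calling `.tolist()` before `str()` would also make the lines written to `compare.txt` plain Python lists rather than NumPy reprs.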