main.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:SentimentAnalysis 作者: Conchylicultor 项目源码 文件源码
def main(outputName):
    """Run the RNTN training grid search.

    Seeds both random generators, initialises the vocabulary, loads the
    training/testing/validation trees and then calls ``train.train`` once
    for every combination of the module-level hyperparameter lists.

    Args:
        outputName: Label shown in the welcome banner and forwarded to
            ``train.train`` (presumably the name under which results are
            recorded -- confirm against ``train.train``).
    """
    print("Welcome into RNTN implementation 0.6 (recording will be on ", outputName, ")")

    # Fixed seeds so that a run can be replicated
    random.seed("MetaMind")
    np.random.seed(7)

    print("Parsing dataset, creating dictionary...")
    vocabulary.initVocab()  # Dictionary initialisation

    # (dataset key, source file, message printed once that file is loaded)
    sources = (
        ('training', "trees/train.txt", "Training loaded !"),
        ('testing', "trees/test.txt", "Testing loaded !"),
        ('validating', "trees/dev.txt", "Validation loaded !"),
    )
    datasets = {}
    for key, path, doneMessage in sources:
        datasets[key] = utils.loadDataset(path)
        print(doneMessage)

    print("Datasets loaded !")
    print("Nb of words", vocabulary.vocab.length())

    # No data transformation (normalisation, outlier removal, ...) is done here

    # Grid search on the hyperparameters (a complete k-fold cross validation
    # would take too long, so only train/test is used). The generator is lazy
    # and yields the combinations in the same order as four nested loops.
    grid = (
        (batchSize, resetNbIter, rate, regTerm)
        for batchSize in miniBatchSize
        for resetNbIter in adagradResetNbIter
        for rate in learningRate
        for regTerm in regularisationTerm
    )
    for batchSize, resetNbIter, rate, regTerm in grid:
        params = {
            "nbEpoch": nbEpoch,
            "learningRate": rate,
            "regularisationTerm": regTerm,
            "adagradResetNbIter": resetNbIter,
            "miniBatchSize": batchSize,
        }
        # The vocabulary values do not need resetting (they are contained in
        # model.L so they are reset automatically); same for the training and
        # testing sets (output values are recomputed at each iteration).
        model, errors = train.train(outputName, datasets, params)

    # TODO: Plot the cross-validation curve
    # TODO: Plot a heat map of the hyperparameters cost to help tuning them ?

    # Validation on the last computed model is disabled; it is only meant for
    # the final training:
    #   print("Training complete, validating...")
    #   vaError = model.computeError(datasets['validating'], True)
    #   print("Validation error: ", vaError)

    print("The End. Thank you for using this program!")
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号