problem3.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:Machine-Learning 作者: zjuzpz 项目源码 文件源码
def _read_feature_rows(path):
    """Read a whitespace-separated numeric text file as a list of float rows.

    Blank (or whitespace-only) lines are skipped. The file is closed as soon
    as reading finishes, including when a value fails to parse.
    """
    with open(path) as f:
        return [[float(v) for v in line.split()] for line in f if line.strip()]


def solution3():
    """Fit and report an OLS model per hashtag, then draw scatter plots.

    For every hashtag the function reads 'problem 3 <tag> data.txt' from the
    current directory, regresses the first column of row i+1 (named
    "next hour tweets" here) on all columns of row i using statsmodels OLS,
    prints the fit summary and calls ``scatterPlot`` with three
    hashtag-specific integers.

    Relies on ``data_preparation``, ``scatterPlot``, ``np`` and ``sm`` being
    available at module level.
    """
    print("Start problem 3")
    print("Start data preparation for problem 3")
    # firstly prepare data for problem 3, all text files will be saved at current path
    data_preparation()
    print("Results: ")

    # (tag, three integers forwarded to scatterPlot). The integers are
    # presumably the column indices of X to plot -- confirm against
    # scatterPlot's definition. A list of pairs keeps the processing order
    # (and therefore idx) identical to the original tag list.
    plot_columns = [
        ('tweets_#gohawks', (0, 1, 2)),
        ('tweets_#gopatriots', (1, 3, 4)),
        ('tweets_#nfl', (0, 1, 2)),
        ('tweets_#patriots', (0, 1, 2)),
        ('tweets_#sb49', (0, 1, 5)),
        ('tweets_#superbowl', (0, 2, 5)),
    ]

    # idx is the position of the tag in the list above
    for idx, (tag, columns) in enumerate(plot_columns):
        rows = _read_feature_rows('problem 3 ' + tag + ' data.txt')

        # Align features of row i with the first column of row i+1:
        # drop the first target value and the last feature row.
        tweets = [row[0] for row in rows]
        del tweets[0]
        next_hour_tweets = np.array(tweets)
        rows.pop()
        X = np.matrix(rows)

        res = sm.OLS(next_hour_tweets, X).fit()
        print("Result of " + tag)
        print(res.summary())

        first, second, third = columns
        scatterPlot(idx, first, second, third, X, next_hour_tweets, tag)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号