helpers.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:Quant-StockPredictor 作者: echapuis 项目源码 文件源码
def _make_orders_df(orders_file):
    """Return the hard-coded order table selected by ``orders_file``, indexed by parsed Date."""
    import io  # local import: only needed to parse the in-memory order table

    if orders_file == "train.csv":
        order_lines = ("2006-01-03,IBM,BUY,500\n", "2009-12-31,IBM,SELL,500\n")
    else:
        order_lines = ("2010-01-04,IBM,BUY,500\n", "2010-12-31,IBM,SELL,500\n")
    csv_text = 'Date,Symbol,Order,Shares\n' + ''.join(order_lines)
    # Parse from memory instead of writing/reading/deleting a file on disk, so a
    # real file at `orders_file` is never clobbered or removed.
    return pd.read_csv(io.StringIO(csv_text), index_col='Date',
                       parse_dates=True, na_values=['nan'])


def compute_portvals2(orders_file="train.csv", start_val=100000):
    """Simulate a fixed IBM buy/sell order set and return daily portfolio values.

    The orders are hard-coded, not read from disk. ``orders_file`` only selects
    which set is used:

    * ``"train.csv"``: buy 500 IBM on 2006-01-03, sell 500 on 2009-12-31.
    * anything else: buy 500 IBM on 2010-01-04, sell 500 on 2010-12-31.

    No file is created, read, or deleted.

    Parameters
    ----------
    orders_file : str
        Selector for the hard-coded order set (see above).
    start_val : float
        Starting cash balance (default 100000).

    Returns
    -------
    pandas.DataFrame
        A single column holding total value (cash + market value of holdings).
        It is indexed by the dates from ``get_data`` on which the first
        symbol's price is finite.
    """
    orders_df = _make_orders_df(orders_file)

    start_date = orders_df.index[0]
    end_date = orders_df.index[-1]
    # Re-express each order date as an offset from the first order. Its `.days`
    # then indexes rows of the calendar-daily price matrix directly.
    orders_df.index = orders_df.index - start_date
    # Orders falling on this date are ignored (behavior kept from the original).
    skip_day = (datetime(2011, 6, 15) - start_date).days

    symbols = orders_df['Symbol'].unique()
    col_of = {sym: i for i, sym in enumerate(symbols)}
    prices = get_data(symbols, pd.date_range(start_date, end_date), addSPY=False)
    price_matrix = prices.values
    # NOTE(review): assumes get_data leaves NaN on non-trading calendar days;
    # those rows are filtered out of the result below.
    trading_rows = np.isfinite(price_matrix[:, 0])

    n_days = orders_df.index[-1].days + 1
    share_deltas = np.zeros((n_days, len(symbols)))
    cash = np.full(n_days, start_val, dtype=float)

    for offset, order in orders_df.iterrows():
        day = offset.days
        if day == skip_day:
            continue
        col = col_of[order['Symbol']]
        shares = order['Shares']
        trade_value = shares * price_matrix[day][col]
        if order['Order'] == 'BUY':
            share_deltas[day][col] += shares
            cash[day:] -= trade_value
        else:  # any non-BUY order is treated as a SELL
            share_deltas[day][col] -= shares
            cash[day:] += trade_value

    # Running share count per symbol on each day.
    holdings = np.cumsum(share_deltas, axis=0)
    total_vals = np.sum(np.multiply(holdings, price_matrix), axis=1) + cash
    return pd.DataFrame(data=total_vals[trading_rows], index=prices.index[trading_rows])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号