actor2.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:student_simulator_policy 作者: kolchinski 项目源码 文件源码
def _mean_position(actions, action, horizon=50):
    """Return the mean time index at which `action` appears in `actions`.

    Occurrences of `action` are counted along axis 0 of `actions`, and only
    the first `horizon` entries of the resulting counts are considered.
    NOTE(review): assumes `actions` is a 2-D array laid out as
    (batch, time) -- confirm against Actor.test_on_batch.

    Args:
        actions: NumPy array of chosen action ids.
        action: The action id whose average position is wanted.
        horizon: Maximum number of leading time steps to include.

    Returns:
        The count-weighted mean time index as a float, or NaN when `action`
        never occurs within the horizon.
    """
    counts = np.sum(actions == action, axis=0)[:horizon]
    total = np.sum(counts)
    if total == 0:
        # No occurrence at all: the average is undefined rather than 0.
        return float("nan")
    # Size the index vector from the counts so arrays with fewer than
    # `horizon` time steps still work.
    positions = np.arange(len(counts))
    # Divide by a float: with two integer operands Python 2's `/` floors,
    # which silently truncated the averages to whole numbers.
    return np.sum(positions * counts) / float(total)


def main(_):
    """Train an Actor on generated sequences and report action timing.

    Builds data with fake_sequences(20000, 3), trains for 200 consecutive
    slices of 100 items, then runs test_on_batch on the last slice and prints
    the average time index at which each of the actions 0, 1 and 2 is chosen
    (first 50 time steps only). Finally calls embed() -- presumably an
    interactive IPython shell for inspection; confirm against the imports.

    Args:
        _: Unused; present so the function matches the caller's expected
            signature.
    """
    print("Testing actor")
    topics, answers, masks, seq_lens, rewards = fake_sequences(20000, 3)
    actor = Actor(3, HIDDEN_SIZE, MAX_LENGTH)

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for i in range(200):
            s, e = 100 * i, 100 * (i + 1)
            obj = actor.train_on_batch(
                session, rewards[s:e], seq_lens[s:e], masks[s:e],
                answers[s:e], topics[s:e])
            print(obj)

        # Evaluation deliberately reuses s and e from the final training
        # slice, exactly as the loop left them.
        actions = actor.test_on_batch(
            session, rewards[s:e], seq_lens[s:e], masks[s:e],
            answers[s:e], topics[s:e])
        actionsArray = np.array(actions[0])

        avgZeroPos = _mean_position(actionsArray, 0)
        avgOnePos = _mean_position(actionsArray, 1)
        avgTwoPos = _mean_position(actionsArray, 2)
        # %-formatting gives the same space-separated output as the Python 2
        # print statement while remaining valid syntax on Python 3.
        print("%s %s %s" % (avgZeroPos, avgOnePos, avgTwoPos))
        embed()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号