MIDS.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目：MIDS | 作者：freegraphics | 项目源码 | 文件源码
def train_rates():
    """Train the rate model of a ``RecommenderSystem`` with periodic checkpoints.

    Creates a fresh ``Consts`` and a ``RecommenderSystem`` seeded from a new
    ``numpy.random.RandomState``, then runs 100000 training cycles.  In each
    cycle it:

    * calls ``rs.train_rates`` ``consts.ids_move_count`` times with
      ``consts.result_learning_rate``, rewriting a single progress line on
      stdout at most about once per second;
    * reports the last training loss and the validation figures through
      ``trace_rates``;
    * every ``consts.save_cycles`` cycles, saves checkpoint number
      ``cycle // save_cycles + consts.load_from_ids``;
    * every ``consts.validate_cycles`` cycles, runs ``rs.validate_rates`` and,
      when the result is the lowest seen so far, saves it as checkpoint 0;
    * calls ``consts.update_index`` with the global cycle index.

    The global cycle index is the local counter plus
    ``consts.load_from_ids * consts.save_cycles`` -- presumably so a run that
    resumes from checkpoint ``load_from_ids`` continues the numbering
    (TODO confirm against ``Consts``).

    Returns None.
    """
    consts = Consts()
    rng = numpy.random.RandomState()
    theano_rng = RandomStreams(rng.randint(2 ** 30))
    rs = RecommenderSystem(rng=rng, theano_rng=theano_rng, consts=consts)

    validate_loss = 0
    # Best validation loss so far; 0 is displayed/traced until the first
    # validation has run.  A separate flag records whether a best value
    # exists, so a genuine validation loss of exactly 0 is not mistaken for
    # "unset" (which would let every later validation overwrite checkpoint 0).
    validate_loss_min = 0
    have_validate_min = False

    for i in range(100000):
        last_report = time.time()
        for _ in numpy.arange(consts.ids_move_count):
            loss_rates = rs.train_rates(learning_rate=consts.result_learning_rate)
            now = time.time()
            if now > last_report + 1:
                # Blank the previous progress line with tabs, then overwrite
                # it in place.
                sys.stdout.write("\t\t\t\t\t\t\t\t\t\r")
                sys.stdout.write(
                    "[%d] loss = %f , val = %f valmin = %f\r"
                    % (i, loss_rates, validate_loss, validate_loss_min)
                )
                # "\r"-terminated output is not flushed by line buffering.
                sys.stdout.flush()
                # Re-arm from the current time: advancing by exactly one
                # second would let a slow step trigger a burst of reports.
                last_report = now

        trace_rates(
            i + consts.load_from_ids * consts.save_cycles,
            loss_rates,
            validate_loss_min,
            validate_loss,
            consts.trace_rates_file_name,
        )

        if i % consts.save_cycles == 0:
            # Floor division keeps the checkpoint id integral on Python 3 as
            # well; i is a multiple of save_cycles here, so the value matches
            # the old true-division result.
            rs.save_rates(i // consts.save_cycles + consts.load_from_ids, consts)

        if i % consts.validate_cycles == 0:
            validate_loss = rs.validate_rates(consts=consts)
            if not have_validate_min or validate_loss < validate_loss_min:
                validate_loss_min = validate_loss
                have_validate_min = True
                # Checkpoint 0 holds the best-validating parameters.
                rs.save_rates(0, consts)

        consts.update_index(i + consts.load_from_ids * consts.save_cycles)
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号