s2train.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DmsMsgRcg 作者: bshao001 项目源码 文件源码
def train_lss(model, model_scope, num_epoches, result_file):
    """Train a 4-class ``ImgConvNets`` classifier on the LSS feature sets.

    Loads four feature arrays with ``read_features_lss``. Each one is labelled
    with its index (0-3) in load order; the log message below names these
    classes zero toll, closed, normal and congested. The arrays are stacked
    into a single training set and passed to ``ImgConvNets.train``. Results
    are written under ``PROJECT_ROOT/Data/Result``.

    Args:
        model: Forwarded unchanged to the ``ImgConvNets`` constructor.
        model_scope: Forwarded unchanged to the ``ImgConvNets`` constructor.
        num_epoches: Forwarded to ``ImgConvNets`` as ``num_epoches``
            (presumably the training epoch count; see ``ImgConvNets``).
        result_file: Forwarded to ``ImgConvNets.train`` as ``result_file``.
    """
    height, width = FEATURE_HEIGHT, FEATURE_WIDTH

    feats0, feats1, feats2, feats3 = read_features_lss(height, width)
    feature_sets = (feats0, feats1, feats2, feats3)

    # The i-th feature set gets class label i as an (n_i, 1) float32 column,
    # matching the class_count=4 classifier configured below.
    labels = [np.full((feats.shape[0], 1), label, dtype=np.float32)
              for label, feats in enumerate(feature_sets)]

    # A single concatenate copies the data once. The nested np.append calls
    # it replaces re-copied the growing array at every step.
    all_feats = np.concatenate(feature_sets, axis=0)
    all_y = np.concatenate(labels, axis=0)

    print("all_feats shapes: zero toll = {}, closed = {}, normal = {}, congested = {},  all = {}; "
          "and dtype = {}".format(feats0.shape, feats1.shape, feats2.shape, feats3.shape,
                                  all_feats.shape, all_feats.dtype))
    print("all_y shape: {}; and dtype={}".format(all_y.shape, all_y.dtype))

    res_dir = os.path.join(PROJECT_ROOT, 'Data', 'Result')
    img_cnn = ImgConvNets(model, model_scope, height, width, class_count=4, keep_prob=0.5,
                          batch_size=32, learning_rate=1e-4, lr_adaptive=True, num_epoches=num_epoches)

    img_cnn.train(all_feats, all_y, res_dir, result_file=result_file)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号