def train_lss(model, model_scope, num_epoches, result_file):
    """Train an ImgConvNets classifier on the four LSS feature classes.

    Feature arrays for the four classes are loaded with ``read_features_lss``,
    stacked along axis 0, and labelled with their class index as a float32
    column vector of shape ``(n_samples, 1)``. The class order follows the
    diagnostic print: 0 = zero toll, 1 = closed, 2 = normal, 3 = congested.
    The stacked data is then passed to ``ImgConvNets.train`` with the result
    directory ``<PROJECT_ROOT>/Data/Result``.

    Args:
        model: Forwarded unchanged as the first argument of ``ImgConvNets``.
        model_scope: Forwarded unchanged as the second argument of
            ``ImgConvNets``.
        num_epoches: Number of training epochs, forwarded to ``ImgConvNets``.
        result_file: Forwarded to ``ImgConvNets.train`` as ``result_file``.
    """
    height, width = FEATURE_HEIGHT, FEATURE_WIDTH
    # Unpack explicitly so a wrong number of feature sets still fails loudly here.
    feats0, feats1, feats2, feats3 = read_features_lss(height, width)
    feature_sets = (feats0, feats1, feats2, feats3)

    # A single concatenate copies the (potentially large) feature data once,
    # instead of the three intermediate copies made by nested np.append calls.
    all_feats = np.concatenate(feature_sets, axis=0)
    # Class k is labelled with the value k, in the same order as the features.
    all_y = np.concatenate(
        [np.full((feats.shape[0], 1), label, dtype=np.float32)
         for label, feats in enumerate(feature_sets)],
        axis=0)

    print("all_feats shapes: zero toll = {}, closed = {}, normal = {}, congested = {}, all = {}; "
          "and dtype = {}".format(feats0.shape, feats1.shape, feats2.shape, feats3.shape,
                                  all_feats.shape, all_feats.dtype))
    print("all_y shape: {}; and dtype={}".format(all_y.shape, all_y.dtype))

    res_dir = os.path.join(PROJECT_ROOT, 'Data', 'Result')
    img_cnn = ImgConvNets(model, model_scope, height, width, class_count=len(feature_sets),
                          keep_prob=0.5, batch_size=32, learning_rate=1e-4, lr_adaptive=True,
                          num_epoches=num_epoches)
    img_cnn.train(all_feats, all_y, res_dir, result_file=result_file)
评论列表
文章目录