def train_tas(model, model_scope, num_epoches, result_file):
    """Train a two-class ``ImgConvNets`` model on the TAS feature sets.

    ``read_features_tas`` returns two feature arrays. They are stacked
    row-wise. Rows from the first array get label 0.0 and rows from the
    second get label 1.0; both label sets are float32 column vectors of
    shape ``(rows, 1)``. The log line below names the two groups "toll" and
    "closed", in that order.

    Args:
        model: Forwarded unchanged as the first positional argument of
            ``ImgConvNets``.
        model_scope: Forwarded unchanged as the second positional argument
            of ``ImgConvNets``.
        num_epoches: Passed to ``ImgConvNets`` as ``num_epoches``.
        result_file: Passed to ``ImgConvNets.train`` as ``result_file``.

    Returns:
        None. ``ImgConvNets.train`` receives ``<PROJECT_ROOT>/Data/Result``
        as its result directory.
    """
    height, width = FEATURE_HEIGHT, FEATURE_WIDTH
    toll_feats, closed_feats = read_features_tas(height, width)

    # Binary targets: 0.0 for every "toll" row and 1.0 for every "closed" row.
    toll_labels = np.zeros((toll_feats.shape[0], 1), dtype=np.float32)
    closed_labels = np.ones((closed_feats.shape[0], 1), dtype=np.float32)

    # Stack along the sample axis. NOTE(review): this assumes both feature
    # arrays agree on every dimension after axis 0, since concatenate
    # requires it. Confirm against read_features_tas.
    features = np.concatenate((toll_feats, closed_feats), axis=0)
    labels = np.concatenate((toll_labels, closed_labels), axis=0)

    feats_msg = ("all_feats shapes: toll = {}, closed = {}, all = {}; "
                 "and dtype = {}")
    print(feats_msg.format(toll_feats.shape, closed_feats.shape,
                           features.shape, features.dtype))
    print("all_y shape: {}; and dtype={}".format(labels.shape, labels.dtype))

    result_dir = os.path.join(PROJECT_ROOT, 'Data', 'Result')
    net_options = dict(
        class_count=2,
        keep_prob=0.5,
        batch_size=32,
        learning_rate=1e-4,
        lr_adaptive=True,
        num_epoches=num_epoches,
    )
    img_cnn = ImgConvNets(model, model_scope, height, width, **net_options)
    img_cnn.train(features, labels, result_dir, result_file=result_file)
# [Scraped web-page residue, not code] Comment list
# [Scraped web-page residue, not code] Table of contents