def _fit_and_report(model, train_x, train_y, val_x, val_y, val_groups,
                    batch_size, epochs, epochs_to_stop, label):
    """Fit ``model`` from batch generators and print its best validation scores.

    Builds a training and a validation ``generator``, attaches a
    ``Checkpoint_balanced`` callback on the validation generator, runs
    ``model.fit_generator`` and prints the callback's ``best_acc`` and
    ``best_f1`` (as percentages) prefixed with ``label``.
    """
    g_train = generator(train_x, train_y, batch_size, val=False)
    g_val = generator(val_x, val_y, batch_size, val=True)
    cb = Checkpoint_balanced(g_val, verbose=1, groups=val_groups,
                             epochs_to_stop=epochs_to_stop, plot=True,
                             name='{}, {}'.format(model.name, 'testing'))
    model.fit_generator(g_train, g_train.n_batches, epochs=epochs,
                        callbacks=[cb], max_queue_size=1, verbose=0)
    print('{} Val acc: {:.1f}, Val F1: {:.1f}'.format(
        label, cb.best_acc * 100, cb.best_f1 * 100))


def train_models_feat(data, targets, groups, batch_size=512, epochs=250, epochs_to_stop=15):
    """
    Train an ANN and an RNN on feature data and return both models.

    One fold of a 5-fold ``GroupKFold`` split over ``groups`` (about 20% of
    the samples) is held out as the validation set; the same split is used
    for both models.

    Parameters
    ----------
    data : indexable sequence of per-sample feature arrays.
    targets : array with one row per sample; ``targets.shape[1]`` is taken
        as the number of classes, and it is indexed with integer arrays.
    groups : array of per-sample group labels, indexed with integer arrays.
    batch_size : batch size passed to both the training and validation
        generators.
    epochs : maximum number of epochs passed to ``fit_generator``.
    epochs_to_stop : forwarded to ``Checkpoint_balanced``
        (presumably the early-stopping patience -- confirm against that class).

    Returns
    -------
    (model, rnn_model) : the fitted ANN and RNN models.
    """
    n_classes = targets.shape[1]

    # The split depends only on `groups`, so it is computed once and shared
    # by the ANN and the RNN.
    train_idx, val_idx = next(GroupKFold(5).split(groups, groups, groups))
    train_data = [data[i] for i in train_idx]
    train_target = targets[train_idx]
    train_groups = groups[train_idx]
    val_data = [data[i] for i in val_idx]
    val_target = targets[val_idx]
    val_groups = groups[val_idx]

    # ANN training on the individual feature vectors.
    input_shape = list(np.array(data[0]).shape)
    model = models.ann(input_shape, n_classes)
    _fit_and_report(model, train_data, train_target, val_data, val_target,
                    val_groups, batch_size, epochs, epochs_to_stop, 'ANN')

    # RNN training on sequences built by tools.to_sequences with seqlen=6;
    # targets and groups are re-derived per sequence by that helper.
    train_data_seq, train_target_seq, train_groups_seq = tools.to_sequences(
        np.array(train_data), train_target, groups=train_groups, seqlen=6)
    val_data_seq, val_target_seq, val_groups_seq = tools.to_sequences(
        np.array(val_data), val_target, groups=val_groups, seqlen=6)
    input_shape = list(np.array(train_data_seq[0]).shape)
    print(input_shape)
    rnn_model = models.pure_rnn_do(input_shape, n_classes)
    _fit_and_report(rnn_model, train_data_seq, train_target_seq,
                    val_data_seq, val_target_seq, val_groups_seq,
                    batch_size, epochs, epochs_to_stop, 'RNN')

    return model, rnn_model
评论列表
文章目录