def train_models(data, targets, groups, model=None, cropsize=2800, batch_size=512, epochs=250, epochs_to_stop=15, rnn_epochs_to_stop=15):
    """
    Train a CNN (``models.cnn3adam_filter_l2`` by default) and then an LSTM
    on top of the CNN's extracted features, using a group-held-out 20%
    validation split (first fold of a 5-fold GroupKFold).

    :param data: list of per-record arrays; each record is cropped to
        `cropsize` samples by the batch generator
    :param targets: one-hot target matrix, shape (n_records, n_classes)
    :param groups: per-record group labels used for the GroupKFold split
    :param model: optional model factory taking (input_shape, n_classes);
        defaults to ``models.cnn3adam_filter_l2``
    :param cropsize: number of samples each record is cropped to
    :param batch_size: training batch size
    :param epochs: maximum epochs for both CNN and RNN training
    :param epochs_to_stop: early-stopping patience for the CNN
    :param rnn_epochs_to_stop: early-stopping patience for the RNN
    :returns: tuple ``(cnn_model, rnn_model)`` of the two trained models
    """
    input_shape = list(np.array(data[0]).shape)
    input_shape[0] = cropsize  # the generator crops each record to this length
    n_classes = targets.shape[1]

    # First fold of a 5-fold group split -> ~80/20 train/validation split
    # where no group appears in both sets.
    train_idx, val_idx = next(GroupKFold(5).split(groups, groups, groups))
    train_data = [data[i] for i in train_idx]
    train_target = targets[train_idx]
    train_groups = groups[train_idx]
    val_data = [data[i] for i in val_idx]
    val_target = targets[val_idx]
    val_groups = groups[val_idx]

    # --- CNN training ---
    model = models.cnn3adam_filter_l2(input_shape, n_classes) if model is None else model(input_shape, n_classes)
    g_train = generator(train_data, train_target, batch_size, val=False, cropsize=cropsize)
    g_val = generator(val_data, val_target, batch_size, val=True, cropsize=cropsize)
    cb = Checkpoint_balanced(g_val, verbose=1, groups=val_groups,
                             epochs_to_stop=epochs_to_stop, plot=True,
                             name='{}, {}'.format(model.name, 'testing'))
    model.fit_generator(g_train, g_train.n_batches, epochs=epochs, callbacks=[cb],
                        max_queue_size=1, verbose=0)
    val_acc = cb.best_acc
    val_f1 = cb.best_f1
    print('CNN Val acc: {:.1f}, Val F1: {:.1f}'.format(val_acc*100, val_f1*100))

    # --- LSTM training on features extracted from the CNN's 'fc1' layer ---
    rnn_modelfun = models.pure_rnn_do
    lname = 'fc1'
    seq = 6  # sequence length: number of consecutive records fed to the LSTM
    rnn_epochs = epochs
    stopafter_rnn = rnn_epochs_to_stop

    # Extract activations for train+val in one pass, then split them back
    # at the train/val boundary.
    features = get_activations(model, train_data + val_data, lname, batch_size*2, cropsize=cropsize)
    train_data_extracted = features[0:len(train_data)]
    val_data_extracted = features[len(train_data):]
    assert (len(train_data) == len(train_data_extracted)) and (len(val_data) == len(val_data_extracted))

    train_data_seq, train_target_seq, train_groups_seq = tools.to_sequences(
        train_data_extracted, train_target, groups=train_groups, seqlen=seq)
    val_data_seq, val_target_seq, val_groups_seq = tools.to_sequences(
        val_data_extracted, val_target, groups=val_groups, seqlen=seq)

    rnn_shape = list(np.array(train_data_seq[0]).shape)
    # Heuristic: LSTM hidden size scales with sqrt of the feature dimension.
    neurons = int(np.sqrt(rnn_shape[-1])*4)
    rnn_model = rnn_modelfun(rnn_shape, n_classes, layers=2, neurons=neurons, dropout=0.3)
    # BUG FIX: the original format string had only two placeholders for
    # three arguments, silently dropping the timestamp from the log line.
    print('Starting RNN model with input from layer fc1: {} ({}) at {}'.format(
        rnn_model.name, rnn_shape, time.ctime()))

    g_train = generator(train_data_seq, train_target_seq, batch_size, val=False)
    g_val = generator(val_data_seq, val_target_seq, batch_size, val=True)
    cb = Checkpoint_balanced(g_val, verbose=1, groups=val_groups_seq,
                             epochs_to_stop=stopafter_rnn, plot=True,
                             name='{}, {}'.format(rnn_model.name, 'fc1'))
    rnn_model.fit_generator(g_train, g_train.n_batches, epochs=rnn_epochs, verbose=0,
                            callbacks=[cb], max_queue_size=1)
    val_acc = cb.best_acc
    val_f1 = cb.best_f1
    print('LSTM Val acc: {:.1f}, Val F1: {:.1f}'.format(val_acc*100, val_f1*100))
    return model, rnn_model
# NOTE(review): removed stray scraped-page navigation text that was pasted
# after the function ("评论列表" / "文章目录" — "comment list" / "article TOC");
# it was not Python and broke parsing of the module.