keras_utils.py source code

python

Project: AutoSleepScorerDev · Author: skjerns
import time
import numpy as np
from sklearn.model_selection import GroupKFold
# generator, get_activations, Checkpoint_balanced, models and tools are
# project-local helpers defined or imported elsewhere in AutoSleepScorerDev

def train_models(data, targets, groups, model=None, cropsize=2800, batch_size=512, epochs=250, epochs_to_stop=15, rnn_epochs_to_stop=15):
    """
    trains a cnn3adam_filter_l2 model with a LSTM on top on 
    the given data with 20% validation set and returns the two models
    """
    input_shape = list(np.array(data[0]).shape)  # shape of a single sample
    input_shape[0] = cropsize                    # the generator crops each sample to `cropsize` time steps
    n_classes = targets.shape[1]
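    # group-aware 80/20 split: GroupKFold(5) keeps all samples of one subject in the same fold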
    train_idx, val_idx = next(GroupKFold(5).split(groups, groups, groups))
    train_data   = [data[i] for i in train_idx]
    train_target = targets[train_idx]
    train_groups = groups[train_idx]
    val_data     = [data[i] for i in val_idx]
    val_target   = targets[val_idx]
    val_groups   = groups[val_idx]
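    # `model` is an optional model-constructor; by default the cnn3adam_filter_l2 architecture is built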
    model = models.cnn3adam_filter_l2(input_shape, n_classes) if model is None else model(input_shape, n_classes)
    g_train = generator(train_data, train_target, batch_size, val=False, cropsize=cropsize)
    g_val   = generator(val_data, val_target, batch_size, val=True, cropsize=cropsize)
    cb  = Checkpoint_balanced(g_val, verbose=1, groups=val_groups,
                              epochs_to_stop=epochs_to_stop, plot=True, name='{}, {}'.format(model.name, 'testing'))
    model.fit_generator(g_train, g_train.n_batches, epochs=epochs, callbacks=[cb], max_queue_size=1, verbose=0)
    val_acc = cb.best_acc
    val_f1  = cb.best_f1
    print('CNN Val acc: {:.1f}, Val F1: {:.1f}'.format(val_acc*100, val_f1*100))

    # LSTM training
    rnn_modelfun = models.pure_rnn_do
    lname = 'fc1'
    seq = 6
    rnn_epochs = epochs
    stopafter_rnn = rnn_epochs_to_stop
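    # extract activations of CNN layer `lname` ('fc1') for all samples; these feature vectors feed the LSTM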
    features = get_activations(model, train_data + val_data, lname, batch_size*2, cropsize=cropsize)
    train_data_extracted = features[0:len(train_data)]
    val_data_extracted   = features[len(train_data):]
    assert (len(train_data)==len(train_data_extracted)) and (len(val_data)==len(val_data_extracted))
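    # build length-`seq` sequences of the extracted features for the LSTM (tools.to_sequences also carries the group labels along)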
    train_data_seq, train_target_seq, train_groups_seq = tools.to_sequences(train_data_extracted, train_target, groups=train_groups, seqlen=seq)
    val_data_seq, val_target_seq, val_groups_seq       = tools.to_sequences(val_data_extracted, val_target, groups=val_groups, seqlen=seq)
    rnn_shape = list(np.array(train_data_seq[0]).shape)
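    # heuristic: LSTM layer width scales with the square root of the feature dimension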
    neurons = int(np.sqrt(rnn_shape[-1])*4)
    rnn_model  = rnn_modelfun(rnn_shape, n_classes, layers=2, neurons=neurons, dropout=0.3)
    print('Starting RNN model with input from layer {}: {} (input shape {}) at {}'.format(lname, rnn_model.name, rnn_shape, time.ctime()))
    g_train = generator(train_data_seq, train_target_seq, batch_size, val=False)
    g_val   = generator(val_data_seq, val_target_seq, batch_size, val=True)
    cb = Checkpoint_balanced(g_val, verbose=1, groups=val_groups_seq,
                             epochs_to_stop=stopafter_rnn, plot=True, name='{}, {}'.format(rnn_model.name, 'fc1'))
    rnn_model.fit_generator(g_train, g_train.n_batches, epochs=rnn_epochs, verbose=0, callbacks=[cb], max_queue_size=1)
    val_acc = cb.best_acc
    val_f1  = cb.best_f1
    print('LSTM Val acc: {:.1f}, Val F1: {:.1f}'.format(val_acc*100, val_f1*100))

    return model, rnn_model
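
Usage sketch (not part of the project source): the call below assumes keras_utils.py is importable and that load_sleep_data() is a hypothetical loader returning a list of per-sample signal arrays (each at least `cropsize` time steps long), one-hot targets of shape (n_samples, n_classes) and a matching array of subject ids.

from keras_utils import train_models

# load_sleep_data() is a placeholder for the project's own data pipeline
data, targets, groups = load_sleep_data()

cnn_model, lstm_model = train_models(data, targets, groups,
                                     cropsize=2800, batch_size=512,
                                     epochs=250, epochs_to_stop=15)

# both returned objects are regular Keras models and can be saved as usual
cnn_model.save('cnn.h5')
lstm_model.save('lstm.h5')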