keras_utils.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:AutoSleepScorerDev 作者: skjerns 项目源码 文件源码
def _fit_and_report(model, train_x, train_y, val_x, val_y, val_groups,
                    batch_size, epochs, epochs_to_stop, label):
    """
    Fit `model` on batches from `generator` and print its best validation scores.

    Builds a training and a validation generator, attaches a
    `Checkpoint_balanced` callback on the validation generator, runs
    `fit_generator` and prints the callback's `best_acc` / `best_f1`
    (as percentages) prefixed with `label`.
    """
    g_train = generator(train_x, train_y, batch_size, val=False)
    g_val = generator(val_x, val_y, batch_size, val=True)
    # NOTE(review): early stopping is presumably implemented inside
    # Checkpoint_balanced via `epochs_to_stop` -- confirm against its definition.
    cb = Checkpoint_balanced(g_val, verbose=1, groups=val_groups,
                             epochs_to_stop=epochs_to_stop, plot=True,
                             name='{}, {}'.format(model.name, 'testing'))
    model.fit_generator(g_train, g_train.n_batches, epochs=epochs,
                        callbacks=[cb], max_queue_size=1, verbose=0)
    print('{} Val acc: {:.1f}, Val F1: {:.1f}'.format(label, cb.best_acc*100, cb.best_f1*100))


def train_models_feat(data, targets, groups, batch_size=512, epochs=250, epochs_to_stop=15):
    """
    Train an ANN and an RNN model on feature data and return both.

    The samples are split once with the first fold of a 5-fold GroupKFold,
    so about one fifth of the groups is held out for validation and no group
    appears on both sides. The ANN is trained on the individual samples, the
    RNN on sequences of length 6 built with `tools.to_sequences`.

    Parameters
    ----------
    data : sequence indexable by integer position, one entry per sample
    targets : array indexable by an index array; `targets.shape[1]` is used
        as the number of classes
    groups : array indexable by an index array, one group label per sample
    batch_size : batch size handed to both generators
    epochs : maximum number of epochs for each model
    epochs_to_stop : forwarded to `Checkpoint_balanced`

    Returns
    -------
    (model, rnn_model) : the trained ANN and RNN models
    """
    n_classes = targets.shape[1]

    # GroupKFold is deterministic, so one split serves both models.
    # Only the length of the first argument matters for the split.
    train_idx, val_idx = next(GroupKFold(5).split(groups, groups, groups))
    train_target = targets[train_idx]
    train_groups = groups[train_idx]
    val_target = targets[val_idx]
    val_groups = groups[val_idx]

    # ANN training (data is kept as plain lists of samples here)
    train_data = [data[i] for i in train_idx]
    val_data = [data[i] for i in val_idx]
    input_shape = list(np.array(data[0]).shape)
    model = models.ann(input_shape, n_classes)
    _fit_and_report(model, train_data, train_target, val_data, val_target,
                    val_groups, batch_size, epochs, epochs_to_stop, 'ANN')

    # LSTM training: same split, but samples are grouped into sequences
    train_data_seq, train_target_seq, train_groups_seq = tools.to_sequences(
        np.array(train_data), train_target, groups=train_groups, seqlen=6)
    val_data_seq, val_target_seq, val_groups_seq = tools.to_sequences(
        np.array(val_data), val_target, groups=val_groups, seqlen=6)

    input_shape = list(np.array(train_data_seq[0]).shape)
    print(input_shape)
    rnn_model = models.pure_rnn_do(input_shape, n_classes)
    _fit_and_report(rnn_model, train_data_seq, train_target_seq,
                    val_data_seq, val_target_seq, val_groups_seq,
                    batch_size, epochs, epochs_to_stop, 'RNN')

    return model, rnn_model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号