rnn.py 文件源码

python
阅读 486 收藏 0 点赞 0 评论 0

项目:smiles-neural-network 作者: PMitura 项目源码 文件源码
def train(model, nnInput, labels, validation, makePlot = True,
        labelIndexes = RP['label_idxs']):
    """Fit ``model`` on the training data and return the number of epochs run.

    Training uses early stopping on ``val_loss`` (patience
    ``RP['early_stop']``). It runs for at most ``RP['epochs']`` epochs with
    batch size ``RP['batch']``.

    Args:
        model: Keras model to train (``model.fit`` is called on it).
        nnInput: training inputs, passed unchanged to ``model.fit``.
        labels: training targets, passed unchanged to ``model.fit``.
        validation: indexable whose items 0 and 1 are the validation
            inputs and validation targets.
        makePlot: when True, plot per-epoch training and validation loss
            via ``utility.plotLoss``.
        labelIndexes: not used by this function; kept for interface
            compatibility. Note that its default is read from ``RP`` once,
            when the function is defined.

    Returns:
        The number of epochs actually run. This can be fewer than
        ``RP['epochs']`` if early stopping triggered.
    """
    print('  Training model...')

    early = keras.callbacks.EarlyStopping(monitor = 'val_loss',
            patience = RP['early_stop'])

    # NOTE(review): a keras.callbacks.LearningRateScheduler(learningRateDecayer)
    # used to be built here but was never added to ``callbacks``, so
    # learning-rate decay was never active. It stays disabled. Construct it
    # and append it to ``callbacks`` below to enable decay.

    # The logger has to be registered as a callback. Otherwise it never sees
    # training, and the histograms below are produced from an unused logger.
    # Assumes ModelLogger is a keras Callback - TODO confirm in visualization.
    modelLogger = visualization.ModelLogger()

    history = model.fit(nnInput, labels, nb_epoch = RP['epochs'],
            batch_size = RP['batch'], callbacks = [early, modelLogger],
            validation_data = (validation[0], validation[1]))

    lossPerEpoch = history.history['loss']

    if makePlot:
        # One row per epoch: column 0 is training loss, column 1 is validation
        # loss. astype(float) keeps the float64 dtype the plot helper has
        # always received.
        values = np.column_stack(
            (lossPerEpoch, history.history['val_loss'])).astype(float)
        utility.plotLoss(values)

    visualization.histograms(modelLogger)

    print('    Model weights:')
    # summary() prints the table itself and returns None. Wrapping it in
    # print() would emit an extra "None" line.
    model.summary()
    print('  ...done')
    return len(lossPerEpoch)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号