train.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:Cat-Segmentation 作者: ardamavi 项目源码 文件源码
def train_model(model, X, X_test, Y, Y_test):
    if not os.path.exists('Data/Checkpoints/'):
        os.makedirs('Data/Checkpoints/')
    checkpoints = []
    checkpoints.append(ModelCheckpoint('Data/Checkpoints/best_weights.h5', monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1))
    checkpoints.append(TensorBoard(log_dir='Data/Checkpoints/./logs', histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None))

    model.fit(X, Y, batch_size=batch_size, epochs=epochs, validation_data=(X_test, Y_test), shuffle=True, callbacks=checkpoints)

    return model
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号