train.py source code

python

Project: Jetson-RaceCar-AI    Author: ardamavi
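import os

# Keras 2.x API: newer Keras releases deprecate or remove `period` and `embeddings_layer_names`
from keras.callbacks import ModelCheckpoint, TensorBoard


def train_model(model, X_1, X_2, Y):
    # One sample per gradient step, ten passes over the data
    batch_size = 1
    epochs = 10

    checkpoint_dir = 'Data/Checkpoints/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoints = [
        # Save weights only when val_loss improves on the best value seen so far
        ModelCheckpoint(os.path.join(checkpoint_dir, 'best_weights.h5'), monitor='val_loss', verbose=0, save_best_only=True, save_weights_only=True, mode='auto', period=1),
        # Write TensorBoard logs: the graph only, no histograms, images or embeddings
        TensorBoard(log_dir=os.path.join(checkpoint_dir, 'logs'), histogram_freq=0, write_graph=True, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None),
    ]

    # Note: the training arrays are also passed as validation_data, so val_loss
    # measures fit on the training set rather than on held-out samples.
    model.fit([X_1, X_2], Y, batch_size=batch_size, epochs=epochs, validation_data=([X_1, X_2], Y), shuffle=True, callbacks=checkpoints)

    return model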
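The `ModelCheckpoint` callback above uses `save_best_only=True` with `monitor='val_loss'` and `mode='auto'`. The following pure-Python function is a simplified sketch of that decision rule, assuming Keras 2.x behaviour where `'auto'` maximises metrics whose name contains `acc` and minimises everything else. It is not the Keras implementation, and the helper name `best_checkpoint_epochs` is hypothetical. It returns the epochs at which weights would be written.

```python
import math


def best_checkpoint_epochs(values, monitor="val_loss"):
    """Return the 1-based epochs at which a save_best_only checkpoint would save.

    Hypothetical, simplified model of mode='auto': monitors containing 'acc'
    are maximised, all others (such as 'val_loss') are minimised.
    """
    maximise = "acc" in monitor
    best = -math.inf if maximise else math.inf
    saved = []
    for epoch, value in enumerate(values, start=1):
        improved = value > best if maximise else value < best
        if improved:
            # New best value: the checkpoint file would be overwritten here
            best = value
            saved.append(epoch)
    return saved


print(best_checkpoint_epochs([0.9, 0.7, 0.8, 0.6, 0.65]))  # [1, 2, 4]
```

Because each save overwrites `best_weights.h5`, only the weights from the last epoch in that list remain on disk.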
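Since `train_model` validates on its own training arrays, one option is to hold out part of the data first. With two model inputs, `X_1`, `X_2` and `Y` must be split with the same row indices so that samples stay aligned. A minimal sketch using NumPy, where the helper `split_paired` and its `val_fraction` and `seed` parameters are hypothetical and not part of the project:

```python
import numpy as np


def split_paired(X_1, X_2, Y, val_fraction=0.2, seed=0):
    """Split two input arrays and their targets with one shared permutation."""
    n = len(Y)
    if not (len(X_1) == len(X_2) == n):
        raise ValueError("X_1, X_2 and Y must have the same number of samples")
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)  # one permutation keeps rows of all arrays aligned
    n_val = max(1, int(round(n * val_fraction)))
    val_idx, train_idx = order[:n_val], order[n_val:]
    train = ([X_1[train_idx], X_2[train_idx]], Y[train_idx])
    val = ([X_1[val_idx], X_2[val_idx]], Y[val_idx])
    return train, val
```

The returned pairs could then be passed as `model.fit(train_x, train_y, ..., validation_data=(val_x, val_y))`. Keras's own `validation_split` argument is an alternative, but it takes the last samples of the arrays before shuffling, so it only gives a representative hold-out if the data is not ordered.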