main.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:keras_detect_tool_wear 作者: kidozh 项目源码 文件源码
def train():
    model = build_stateful_lstm_model(BATCH_SIZE, TIME_STEP, INPUT_DIM, OUTPUT_DIM, dropout=0.1)

    # model.fit(x_train,y_train,validation_data=(x_train[:10],y_train[:10]),epochs=5,callbacks=[TensorBoard()],batch_size=1)

    for index, y_dat in enumerate(y):
        print('Run test on %s' % (index))
        model.fit(np.array([x[index]]), y_dat.reshape(1, 3),
                  validation_data=(np.array([x[index]]), y_dat.reshape(1, 3)), epochs=10, callbacks=[TensorBoard()])
        model.save(MODEL_PATH)
        x_pred = model.predict(np.array([x[index]]))
        print(x_pred)
        print(y_dat)

    model.save(MODEL_PATH)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号