test_directed_timestep_LSTM.py 文件源码

python
阅读 60 收藏 0 点赞 0 评论 0

项目:keras_detect_tool_wear 作者: kidozh 项目源码 文件源码
def train():
    """Fit the model, spot-check one random sample, and save the model.

    The model comes from ``build_real_stateful_lstm_model_with_normalization``
    and is fitted for 30 epochs on the leading
    ``SAMPLE_NUM // BATCH_SIZE * BATCH_SIZE`` samples of ``x_train`` /
    ``y_train``, logging to TensorBoard under ``./stateful_lstm_fixed``.
    Afterwards the prediction and the target of one randomly chosen sample
    are printed and the model is saved to ``MODEL_PATH``.

    ``x_train`` and ``y_train`` are not defined inside this function; they
    must already exist in the enclosing (module) scope when it is called.
    """
    import random

    model = build_real_stateful_lstm_model_with_normalization(
        BATCH_SIZE, TIME_STEP, INPUT_DIM, OUTPUT_DIM)

    # Only train on a whole number of batches; the remainder is dropped.
    # NOTE(review): presumably required because the model is built for a
    # fixed batch size -- confirm against the model builder.
    usable_samples = SAMPLE_NUM // BATCH_SIZE * BATCH_SIZE

    model.fit(x_train[:usable_samples],
              y_train[:usable_samples],
              batch_size=BATCH_SIZE,
              validation_split=0,
              epochs=30,
              callbacks=[TensorBoard(log_dir='./stateful_lstm_fixed')])

    # random.randint(0, SAMPLE_NUM) includes SAMPLE_NUM itself, which is one
    # past the last valid index; randrange excludes the upper bound.
    sample_index = random.randrange(SAMPLE_NUM)

    print('Selecting %s as the sample' % (sample_index))

    # NOTE(review): this predicts on a single sample; if the model was built
    # with a fixed batch size other than 1 this call may be rejected --
    # verify against the model builder.
    pred = model.predict(x_train[sample_index:sample_index + 1])

    print(pred)

    print(y_train[sample_index])

    model.save(MODEL_PATH)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号