main_multi_cnn.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:keras_detect_tool_wear 作者: kidozh 项目源码 文件源码
def train():
    model = build_multi_1d_cnn_model(BATCH_SIZE,
                                     TIME_STEP,
                                     INPUT_DIM,
                                     OUTPUT_DIM,
                                     dropout=0.4,
                                     kernel_size=3,
                                     pooling_size=2,
                                     conv_dim=(128, 64, 32),
                                     stack_loop_num=2)

    # deal with x,y



    # x_train = x


    model.fit(x_train, y_train, validation_split=0, epochs=50, callbacks=[TensorBoard(log_dir='./cnn_dir')], batch_size=10)

    # for index,y_dat in enumerate(y):
    #     print('Run test on %s' %(index))
    #     # print(y_dat.reshape(3,1))
    #     model.fit(np.array([x[index]]),np.array([y_dat.reshape(1,3)]),validation_data=(np.array([x[index]]),np.array([y_dat.reshape(1,3)])),epochs=100,callbacks=[TensorBoard()])
    #     model.save(MODEL_PATH)
    #     x_pred = model.predict(np.array([x[index]]))
    #     print(x_pred,x_pred.shape)
    #     print(np.array([y_dat.reshape(1,3)]))

    import random

    randomIndex = random.randint(0, SAMPLE_NUM)

    print('Selecting %s as the sample' % (randomIndex))

    pred = model.predict(x_train[randomIndex:randomIndex + 1])

    print(pred)

    print(y_train[randomIndex])

    model.save(MODEL_PATH)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号