main_residual_network.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:keras_detect_tool_wear 作者: kidozh 项目源码 文件源码
def train():
    model = build_main_residual_network(BATCH_SIZE,MAX_TIME_STEP,INPUT_DIM,OUTPUT_DIM,loop_depth=DEPTH)

    # deal with x,y



    # x_train = x


    model.fit(x_train, y_train, validation_split=0.1, epochs=50  , callbacks=[TensorBoard(log_dir='./residual_cnn_dir_deep_%s_all'%(DEPTH))])

    import random

    randomIndex = random.randint(0, SAMPLE_NUM)

    print('Selecting- %s as the sample' % (randomIndex))

    pred = model.predict(x_train[randomIndex:randomIndex + 1])

    print(pred)

    print(y_train[randomIndex])

    model.save(MODEL_PATH)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号