LeaveOutValidationEpiModel.py — source code

python
Reads: 20 · Favorites: 0 · Likes: 0 · Comments: 0

Project: Siamese · Author: ascourge21 · project source · file source
def train_model(x_tr, y_tr, conv_f_n, dense_n,
                save_name='/home/nripesh/PycharmProjects/Siamese/siamese_supervised/shape_match_model_epi_sx4.h5',
                tr_epoch=20,
                plot_name='/home/nripesh/PycharmProjects/Siamese/siamese_supervised/epi_train_val_loss.png'):
    """Train a Siamese distance model on input pairs, save it, and plot its loss.

    Two inputs are passed through the *same* ``base_network`` instance built by
    ``create_cnn_network`` (so both branches share weights). Their outputs are
    combined by ``euclidean_distance`` in a ``Lambda`` layer. The model is
    trained with ``contrastive_loss`` and RMSprop. 30% of the data is held out
    for validation, and training stops early after 2 epochs without
    ``val_loss`` improvement.

    Args:
        x_tr: array of pairs; ``x_tr[:, 0]`` and ``x_tr[:, 1]`` feed the two
            branches and ``x_tr.shape[2:]`` is used as the per-item input shape
            (presumably ``(n_pairs, 2, *item_shape)`` — TODO confirm with caller).
        y_tr: per-pair targets consumed by ``contrastive_loss`` (the label
            convention is defined there, not here).
        conv_f_n: forwarded to ``create_cnn_network`` (presumably the number of
            conv filters — verify).
        dense_n: forwarded to ``create_cnn_network`` (presumably the dense layer
            width — verify).
        save_name: path passed to ``Model.save`` for the trained model.
        tr_epoch: maximum number of training epochs.
        plot_name: path the train/validation loss plot is written to.

    Returns:
        The trained Keras ``Model`` mapping ``[input_a, input_b]`` to a distance.
    """
    input_dim = x_tr.shape[2:]
    input_a = Input(shape=input_dim)
    input_b = Input(shape=input_dim)
    # One network object applied to both inputs -> the two branches share weights.
    base_network = create_cnn_network(input_dim, conv_f_n, dense_n)
    processed_a = base_network(input_a)
    processed_b = base_network(input_b)

    distance = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape)([processed_a, processed_b])

    model_tr = Model(inputs=[input_a, input_b], outputs=distance)

    # train
    # NOTE: `lr` is the older Keras keyword; Keras >= 2.3 renames it `learning_rate`.
    opt_func = RMSprop(lr=.003)
    model_tr.compile(loss=contrastive_loss, optimizer=opt_func)
    history = model_tr.fit([x_tr[:, 0], x_tr[:, 1]], y_tr, validation_split=.30,
                           batch_size=128, verbose=2, epochs=tr_epoch,
                           callbacks=[EarlyStopping(monitor='val_loss', patience=2)])

    # Persist the model before plotting so a plotting / file-system failure
    # cannot lose the result of training.
    model_tr.save(save_name)

    # Summarize history for loss on a fresh figure. A fixed figure number (1)
    # would reuse and draw over any figure 1 that is already open.
    fig = plt.figure()
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'val'], loc='upper left')
    plt.savefig(plot_name)
    plt.close(fig)
    return model_tr


# Test the model; also report which pair it was trained on and which pair it was tested on.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号