def train_model(x_tr, y_tr, conv_f_n, dense_n,
                save_name=('/home/nripesh/PycharmProjects/Siamese/'
                           'siamese_supervised/shape_match_model_endo_k3_new.h5'),
                tr_epoch=10):
    """Train a Siamese network on input pairs with a contrastive loss.

    Both members of each pair are passed through one shared network built by
    ``create_cnn_network``; the model outputs ``euclidean_distance`` between
    the two embeddings and is trained with ``contrastive_loss`` using RMSprop,
    with early stopping on validation loss.

    Args:
        x_tr: Array of pairs. ``x_tr[:, 0]`` and ``x_tr[:, 1]`` are the two
            members of each pair, and ``x_tr.shape[2:]`` is the per-sample
            input shape fed to the network.
        y_tr: One target per pair, consumed by ``contrastive_loss`` (label
            convention is defined there — confirm before relying on it).
        conv_f_n: Forwarded to ``create_cnn_network``; presumably the conv
            filter configuration — TODO confirm.
        dense_n: Forwarded to ``create_cnn_network``; presumably the dense
            layer size — TODO confirm.
        save_name: Path passed to ``Model.save`` after training. Defaults to
            the original hard-coded location; pass ``None`` to skip saving.
        tr_epoch: Maximum number of training epochs (early stopping may end
            training sooner).

    Returns:
        The trained Keras ``Model`` mapping ``[input_a, input_b]`` to the
        pairwise distance.
    """
    input_dim = x_tr.shape[2:]
    input_a = Input(shape=input_dim)
    input_b = Input(shape=input_dim)

    # A single network object is applied to both inputs, so the two branches
    # of the Siamese model share the same weights.
    base_network = create_cnn_network(input_dim, conv_f_n, dense_n)
    processed_a = base_network(input_a)
    processed_b = base_network(input_b)

    distance = Lambda(euclidean_distance,
                      output_shape=eucl_dist_output_shape)([processed_a, processed_b])
    model_tr = Model(inputs=[input_a, input_b], outputs=distance)

    model_tr.compile(loss=contrastive_loss, optimizer=RMSprop())

    # NOTE: Keras takes the validation split from the *last* 30% of the
    # samples, before shuffling. If x_tr / y_tr are ordered (e.g. grouped by
    # label), shuffle them before calling this function.
    # NOTE: EarlyStopping does not restore the best weights by default, so the
    # saved/returned model holds the final-epoch weights, which may be up to
    # `patience` epochs past the best val_loss.
    early_stop = EarlyStopping(monitor='val_loss', patience=2)
    model_tr.fit([x_tr[:, 0], x_tr[:, 1]], y_tr,
                 validation_split=.30,
                 batch_size=128,
                 verbose=2,
                 epochs=tr_epoch,
                 callbacks=[early_stop])

    if save_name is not None:
        model_tr.save(save_name)
    return model_tr
# Test the model; also report which pair it was trained on and which pair it was tested on.
评论列表
文章目录