def test_conv3d(self):
    """Check that a Keras Conv3D layer and Conv3DNet agree after weight transfer.

    Builds a one-layer Keras ``Sequential`` model holding a ``Conv3D``
    layer named ``'conv'`` with 8 filters and a 5x5x5 kernel. The test then:

    * compiles the model with categorical cross-entropy and SGD;
    * copies its weights into a fresh ``Conv3DNet`` via ``self.transfer``;
    * asserts with ``self.assertEqualPrediction`` that both models give the
      same predictions on ``self.test_data_3d``, within ``delta=1e-4``.
    """
    # NOTE(review): a 5x5x5 'valid' convolution over input_shape=(3, 8, 8, 8)
    # only has a non-negative output size if the leading axis (3) is the
    # channel axis, i.e. Keras is set to channels_first. This is presumably
    # configured elsewhere in the module — confirm.
    keras_model = Sequential()
    conv_layer = Conv3D(
        8,
        (5, 5, 5),
        input_shape=(3, 8, 8, 8),
        name='conv',
    )
    keras_model.add(conv_layer)

    loss_fn = keras.losses.categorical_crossentropy
    sgd = keras.optimizers.SGD()
    keras_model.compile(loss=loss_fn, optimizer=sgd)

    torch_model = Conv3DNet()
    self.transfer(keras_model, torch_model)

    self.assertEqualPrediction(
        keras_model,
        torch_model,
        self.test_data_3d,
        delta=1e-4,
    )
评论列表
文章目录