Update default Tflite model OpResolver in BaseOptions.

PiperOrigin-RevId: 477873299
This commit is contained in:
MediaPipe Team 2022-09-30 01:43:38 +00:00 committed by Sebastian Schmidt
parent 8af4cca413
commit 382158298b
11 changed files with 33 additions and 125 deletions

View File

@ -23,6 +23,7 @@ cc_library(
srcs = ["base_options.cc"],
hdrs = ["base_options.h"],
deps = [
":mediapipe_builtin_op_resolver",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
@ -50,6 +51,21 @@ cc_library(
],
)
cc_library(
name = "mediapipe_builtin_op_resolver",
srcs = ["mediapipe_builtin_op_resolver.cc"],
hdrs = ["mediapipe_builtin_op_resolver.h"],
deps = [
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
"//mediapipe/util/tflite/operations:max_pool_argmax",
"//mediapipe/util/tflite/operations:max_unpooling",
"//mediapipe/util/tflite/operations:transform_landmarks",
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
"//mediapipe/util/tflite/operations:transpose_conv_bias",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
# supports TFLite-in-GMSCore.
cc_library(

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include "absl/memory/memory.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
@ -63,7 +64,7 @@ struct BaseOptions {
// A non-default OpResolver to support custom Ops or specify a subset of
// built-in Ops.
std::unique_ptr<tflite::OpResolver> op_resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
absl::make_unique<MediaPipeBuiltinOpResolver>();
};
// Converts a BaseOptions to a BaseOptionsProto.

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
@ -21,14 +21,11 @@ limitations under the License.
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
namespace tasks {
namespace vision {
SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
: BuiltinOpResolver() {
namespace core {
MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
AddCustom("MaxPoolingWithArgmax2D",
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D",
@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
/*version=*/2);
}
} // namespace vision
} // namespace core
} // namespace tasks
} // namespace mediapipe

View File

@ -13,25 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
namespace tasks {
namespace vision {
class SelfieSegmentationModelOpResolver
namespace core {
class MediaPipeBuiltinOpResolver
: public tflite::ops::builtin::BuiltinOpResolver {
public:
SelfieSegmentationModelOpResolver();
SelfieSegmentationModelOpResolver(
const SelfieSegmentationModelOpResolver& r) = delete;
MediaPipeBuiltinOpResolver();
MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete;
};
} // namespace vision
} // namespace core
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_

View File

@ -18,18 +18,6 @@ package(default_visibility = [
licenses(["notice"])
cc_library(
name = "hand_detector_op_resolver",
srcs = ["hand_detector_op_resolver.cc"],
hdrs = ["hand_detector_op_resolver.h"],
deps = [
"//mediapipe/util/tflite/operations:max_pool_argmax",
"//mediapipe/util/tflite/operations:max_unpooling",
"//mediapipe/util/tflite/operations:transpose_conv_bias",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
cc_library(
name = "hand_detector_graph",
srcs = ["hand_detector_graph.cc"],

View File

@ -35,11 +35,11 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
@ -121,8 +121,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >>
graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)];
return TaskRunner::Create(graph.GetConfig(),
absl::make_unique<HandDetectorOpResolver>());
return TaskRunner::Create(
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
}
HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) {

View File

@ -1,35 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
#include "mediapipe/util/tflite/operations/max_unpooling.h"
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
namespace mediapipe {
namespace tasks {
namespace vision {
HandDetectorOpResolver::HandDetectorOpResolver() {
AddCustom("MaxPoolingWithArgmax2D",
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D",
mediapipe::tflite_operations::RegisterMaxUnpooling2D());
AddCustom("Convolution2DTransposeBias",
mediapipe::tflite_operations::RegisterConvolution2DTransposeBias());
}
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -1,34 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {
namespace tasks {
namespace vision {
class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
public:
HandDetectorOpResolver();
HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete;
};
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_

View File

@ -33,7 +33,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
@ -73,19 +72,4 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "image_segmenter_op_resolvers",
srcs = ["image_segmenter_op_resolvers.cc"],
hdrs = ["image_segmenter_op_resolvers.h"],
deps = [
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
"//mediapipe/util/tflite/operations:max_pool_argmax",
"//mediapipe/util/tflite/operations:max_unpooling",
"//mediapipe/util/tflite/operations:transform_landmarks",
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
"//mediapipe/util/tflite/operations:transpose_conv_bias",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
# TODO: This test fails in OSS

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe {

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
options->base_options.op_resolver =
absl::make_unique<SelfieSegmentationModelOpResolver>();
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
options->base_options.op_resolver =
absl::make_unique<SelfieSegmentationModelOpResolver>();
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::NONE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,