From 382158298bc09dbe387043b4bc31715fdd881d10 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Fri, 30 Sep 2022 01:43:38 +0000 Subject: [PATCH] Update default Tflite model OpResolver in BaseOptions. PiperOrigin-RevId: 477873299 --- mediapipe/tasks/cc/core/BUILD | 16 +++++++++ mediapipe/tasks/cc/core/base_options.h | 3 +- .../mediapipe_builtin_op_resolver.cc} | 12 +++---- .../mediapipe_builtin_op_resolver.h} | 18 +++++----- mediapipe/tasks/cc/vision/hand_detector/BUILD | 12 ------- .../hand_detector/hand_detector_graph_test.cc | 6 ++-- .../hand_detector_op_resolver.cc | 35 ------------------- .../hand_detector/hand_detector_op_resolver.h | 34 ------------------ .../tasks/cc/vision/image_segmenter/BUILD | 16 --------- .../vision/image_segmenter/image_segmenter.h | 1 - .../image_segmenter/image_segmenter_test.cc | 5 --- 11 files changed, 33 insertions(+), 125 deletions(-) rename mediapipe/tasks/cc/{vision/image_segmenter/image_segmenter_op_resolvers.cc => core/mediapipe_builtin_op_resolver.cc} (87%) rename mediapipe/tasks/cc/{vision/image_segmenter/image_segmenter_op_resolvers.h => core/mediapipe_builtin_op_resolver.h} (65%) delete mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc delete mode 100644 mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 38030c525..8d19227f1 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -23,6 +23,7 @@ cc_library( srcs = ["base_options.cc"], hdrs = ["base_options.h"], deps = [ + ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", @@ -50,6 +51,21 @@ cc_library( ], ) +cc_library( + name = "mediapipe_builtin_op_resolver", + srcs = ["mediapipe_builtin_op_resolver.cc"], + hdrs = ["mediapipe_builtin_op_resolver.h"], + deps = [ + "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", + "//mediapipe/util/tflite/operations:max_pool_argmax", + "//mediapipe/util/tflite/operations:max_unpooling", + "//mediapipe/util/tflite/operations:transform_landmarks", + "//mediapipe/util/tflite/operations:transform_tensor_bilinear", + "//mediapipe/util/tflite/operations:transpose_conv_bias", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + # TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator # supports TFLite-in-GMSCore. cc_library( diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 67a03385b..4717ea50e 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" @@ -63,7 +64,7 @@ struct BaseOptions { // A non-default OpResolver to support custom Ops or specify a subset of // built-in Ops. std::unique_ptr op_resolver = - absl::make_unique(); + absl::make_unique(); }; // Converts a BaseOptions to a BaseOptionsProto. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc similarity index 87% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index cd3b5690f..62898a005 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" @@ -21,14 +21,11 @@ limitations under the License. #include "mediapipe/util/tflite/operations/transform_landmarks.h" #include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h" #include "mediapipe/util/tflite/operations/transpose_conv_bias.h" -#include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() - : BuiltinOpResolver() { +namespace core { +MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("MaxPoolingWithArgmax2D", mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); AddCustom("MaxUnpooling2D", @@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), /*version=*/2); } - -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h similarity index 65% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h index a0538a674..a7c28aa71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h @@ -13,25 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ +#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ #include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -class SelfieSegmentationModelOpResolver +namespace core { +class MediaPipeBuiltinOpResolver : public tflite::ops::builtin::BuiltinOpResolver { public: - SelfieSegmentationModelOpResolver(); - SelfieSegmentationModelOpResolver( - const SelfieSegmentationModelOpResolver& r) = delete; + MediaPipeBuiltinOpResolver(); + MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete; }; -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 23cf5f72d..c87cc50a6 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -18,18 +18,6 @@ package(default_visibility = [ licenses(["notice"]) -cc_library( - name = "hand_detector_op_resolver", - srcs = ["hand_detector_op_resolver.cc"], - hdrs = ["hand_detector_op_resolver.h"], - deps = [ - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - cc_library( name = "hand_detector_graph", srcs = ["hand_detector_graph.cc"], diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index a2fbd7c54..3fa97664e 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -35,11 +35,11 @@ limitations under the License. #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" @@ -121,8 +121,8 @@ absl::StatusOr> CreateTaskRunner( hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> graph[Output>(kHandNormRectsTag)]; - return TaskRunner::Create(graph.GetConfig(), - absl::make_unique()); + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); } HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) { diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc deleted file mode 100644 index 262fb2c75..000000000 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" - -#include "mediapipe/util/tflite/operations/max_pool_argmax.h" -#include "mediapipe/util/tflite/operations/max_unpooling.h" -#include "mediapipe/util/tflite/operations/transpose_conv_bias.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -HandDetectorOpResolver::HandDetectorOpResolver() { - AddCustom("MaxPoolingWithArgmax2D", - mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); - AddCustom("MaxUnpooling2D", - mediapipe::tflite_operations::RegisterMaxUnpooling2D()); - AddCustom("Convolution2DTransposeBias", - mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()); -} -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h deleted file mode 100644 index a55661fa3..000000000 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ - -#include "tensorflow/lite/kernels/register.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver { - public: - HandDetectorOpResolver(); - HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete; -}; - -} // namespace vision -} // namespace tasks -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 6af733657..6bdbf41da 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -33,7 +33,6 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) @@ -73,19 +72,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "image_segmenter_op_resolvers", - srcs = ["image_segmenter_op_resolvers.cc"], - hdrs = ["image_segmenter_op_resolvers.h"], - deps = [ - "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transform_landmarks", - "//mediapipe/util/tflite/operations:transform_tensor_bilinear", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index ce9cb104c..e2734c4e4 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -26,7 +26,6 @@ limitations under the License. #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 2f1c26a79..1d3f3e786 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::SOFTMAX; @@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter,