Update default Tflite model OpResolver in BaseOptions.
PiperOrigin-RevId: 477873299
This commit is contained in:
parent
8af4cca413
commit
382158298b
|
@ -23,6 +23,7 @@ cc_library(
|
|||
srcs = ["base_options.cc"],
|
||||
hdrs = ["base_options.h"],
|
||||
deps = [
|
||||
":mediapipe_builtin_op_resolver",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
|
||||
|
@ -50,6 +51,21 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mediapipe_builtin_op_resolver",
|
||||
srcs = ["mediapipe_builtin_op_resolver.cc"],
|
||||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
"//mediapipe/util/tflite/operations:transform_landmarks",
|
||||
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
|
||||
"//mediapipe/util/tflite/operations:transpose_conv_bias",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
|
||||
# supports TFLite-in-GMSCore.
|
||||
cc_library(
|
||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
#include <string>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
@ -63,7 +64,7 @@ struct BaseOptions {
|
|||
// A non-default OpResolver to support custom Ops or specify a subset of
|
||||
// built-in Ops.
|
||||
std::unique_ptr<tflite::OpResolver> op_resolver =
|
||||
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
|
||||
absl::make_unique<MediaPipeBuiltinOpResolver>();
|
||||
};
|
||||
|
||||
// Converts a BaseOptions to a BaseOptionsProto.
|
||||
|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
|
||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
||||
|
@ -21,14 +21,11 @@ limitations under the License.
|
|||
#include "mediapipe/util/tflite/operations/transform_landmarks.h"
|
||||
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
|
||||
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
|
||||
: BuiltinOpResolver() {
|
||||
namespace core {
|
||||
MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
||||
AddCustom("MaxPoolingWithArgmax2D",
|
||||
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
|
||||
AddCustom("MaxUnpooling2D",
|
||||
|
@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
|
|||
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
|
||||
/*version=*/2);
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace core
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,25 +13,23 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
|
||||
class SelfieSegmentationModelOpResolver
|
||||
namespace core {
|
||||
class MediaPipeBuiltinOpResolver
|
||||
: public tflite::ops::builtin::BuiltinOpResolver {
|
||||
public:
|
||||
SelfieSegmentationModelOpResolver();
|
||||
SelfieSegmentationModelOpResolver(
|
||||
const SelfieSegmentationModelOpResolver& r) = delete;
|
||||
MediaPipeBuiltinOpResolver();
|
||||
MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace core
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
|
|
@ -18,18 +18,6 @@ package(default_visibility = [
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "hand_detector_op_resolver",
|
||||
srcs = ["hand_detector_op_resolver.cc"],
|
||||
hdrs = ["hand_detector_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
"//mediapipe/util/tflite/operations:transpose_conv_bias",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hand_detector_graph",
|
||||
srcs = ["hand_detector_graph.cc"],
|
||||
|
|
|
@ -35,11 +35,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
|
@ -121,8 +121,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
|
|||
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)];
|
||||
|
||||
return TaskRunner::Create(graph.GetConfig(),
|
||||
absl::make_unique<HandDetectorOpResolver>());
|
||||
return TaskRunner::Create(
|
||||
graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
|
||||
}
|
||||
|
||||
HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) {
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
|
||||
|
||||
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
|
||||
#include "mediapipe/util/tflite/operations/max_unpooling.h"
|
||||
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
HandDetectorOpResolver::HandDetectorOpResolver() {
|
||||
AddCustom("MaxPoolingWithArgmax2D",
|
||||
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
|
||||
AddCustom("MaxUnpooling2D",
|
||||
mediapipe::tflite_operations::RegisterMaxUnpooling2D());
|
||||
AddCustom("Convolution2DTransposeBias",
|
||||
mediapipe::tflite_operations::RegisterConvolution2DTransposeBias());
|
||||
}
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -1,34 +0,0 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
|
||||
public:
|
||||
HandDetectorOpResolver();
|
||||
HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
|
|
@ -33,7 +33,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
@ -73,19 +72,4 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "image_segmenter_op_resolvers",
|
||||
srcs = ["image_segmenter_op_resolvers.cc"],
|
||||
hdrs = ["image_segmenter_op_resolvers.h"],
|
||||
deps = [
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
"//mediapipe/util/tflite/operations:max_pool_argmax",
|
||||
"//mediapipe/util/tflite/operations:max_unpooling",
|
||||
"//mediapipe/util/tflite/operations:transform_landmarks",
|
||||
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
|
||||
"//mediapipe/util/tflite/operations:transpose_conv_bias",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: This test fails in OSS
|
||||
|
|
|
@ -26,7 +26,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
|
|
@ -31,7 +31,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
|
|||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
|
||||
options->base_options.op_resolver =
|
||||
absl::make_unique<SelfieSegmentationModelOpResolver>();
|
||||
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
|
||||
|
||||
|
@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
|
|||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
|
||||
options->base_options.op_resolver =
|
||||
absl::make_unique<SelfieSegmentationModelOpResolver>();
|
||||
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
|
||||
options->activation = ImageSegmenterOptions::Activation::NONE;
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
|
||||
|
|
Loading…
Reference in New Issue
Block a user