diff --git a/Dockerfile b/Dockerfile index 0b096fc56..462dacbd4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,6 +51,7 @@ RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-8 100 --slave /u RUN pip3 install --upgrade setuptools RUN pip3 install wheel RUN pip3 install future +RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1 RUN pip3 install six==1.14.0 RUN pip3 install tensorflow==2.2.0 RUN pip3 install tf_slim diff --git a/docs/index.md b/docs/index.md index 1532e10cc..e10952bcd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ title: Home nav_order: 1 --- -![MediaPipe](images/mediapipe_small.png) +![MediaPipe](https://mediapipe.dev/images/mediapipe_small.png) -------------------------------------------------------------------------------- @@ -13,21 +13,21 @@ nav_order: 1 [MediaPipe](https://google.github.io/mediapipe/) offers cross-platform, customizable ML solutions for live and streaming media. -![accelerated.png](images/accelerated_small.png) | ![cross_platform.png](images/cross_platform_small.png) +![accelerated.png](https://mediapipe.dev/images/accelerated_small.png) | ![cross_platform.png](https://mediapipe.dev/images/cross_platform_small.png) :------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------: ***End-to-End acceleration***: *Built-in fast ML inference and processing accelerated even on common hardware* | ***Build once, deploy anywhere***: *Unified solution works across Android, iOS, desktop/cloud, web and IoT* -![ready_to_use.png](images/ready_to_use_small.png) | ![open_source.png](images/open_source_small.png) +![ready_to_use.png](https://mediapipe.dev/images/ready_to_use_small.png) | ![open_source.png](https://mediapipe.dev/images/open_source_small.png) ***Ready-to-use solutions***: *Cutting-edge ML solutions demonstrating full power of the framework* | ***Free and open source***: *Framework and solutions both under Apache 2.0, fully extensible and customizable* ## ML solutions in MediaPipe Face Detection | Face Mesh | Iris | Hands | Pose | Holistic :----------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :------: -[![face_detection](images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](images/mobile/holistic_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/holistic) 
+[![face_detection](https://mediapipe.dev/images/mobile/face_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_detection) | [![face_mesh](https://mediapipe.dev/images/mobile/face_mesh_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/face_mesh) | [![iris](https://mediapipe.dev/images/mobile/iris_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/iris) | [![hand](https://mediapipe.dev/images/mobile/hand_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hands) | [![pose](https://mediapipe.dev/images/mobile/pose_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/pose) | [![hair_segmentation](https://mediapipe.dev/images/mobile/holistic_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/holistic) Hair Segmentation | Object Detection | Box Tracking | Instant Motion Tracking | Objectron | KNIFT :-------------------------------------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------: | :---: -[![hair_segmentation](images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) | [![object_detection](images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) +[![hair_segmentation](https://mediapipe.dev/images/mobile/hair_segmentation_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/hair_segmentation) | [![object_detection](https://mediapipe.dev/images/mobile/object_detection_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/object_detection) | [![box_tracking](https://mediapipe.dev/images/mobile/object_tracking_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/box_tracking) | [![instant_motion_tracking](https://mediapipe.dev/images/mobile/instant_motion_tracking_android_small.gif)](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | [![objectron](https://mediapipe.dev/images/mobile/objectron_chair_android_gpu_small.gif)](https://google.github.io/mediapipe/solutions/objectron) | [![knift](https://mediapipe.dev/images/mobile/template_matching_android_cpu_small.gif)](https://google.github.io/mediapipe/solutions/knift) diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 99a698e4b..c378df7d0 100644 --- 
a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -216,6 +216,50 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "inference_runner", + hdrs = ["inference_runner.h"], + copts = select({ + # TODO: fix tensor.h not to require this, if possible + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "inference_interpreter_delegate_runner", + srcs = ["inference_interpreter_delegate_runner.cc"], + hdrs = ["inference_interpreter_delegate_runner.h"], + copts = select({ + # TODO: fix tensor.h not to require this, if possible + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":inference_runner", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/util/tflite:tflite_model_loader", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + cc_library( name = "inference_calculator_cpu", srcs = [ @@ -232,22 +276,64 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":inference_calculator_interface", + ":inference_calculator_utils", + ":inference_interpreter_delegate_runner", + ":inference_runner", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite:framework_stable", "@org_tensorflow//tensorflow/lite/c:c_api_types", ] + select({ - "//conditions:default": [ - "//mediapipe/util:cpu_util", - ], - }) + select({ "//conditions:default": [], "//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"], }), alwayslink = 1, ) +cc_library( + name = "inference_calculator_utils", + srcs = ["inference_calculator_utils.cc"], + hdrs = ["inference_calculator_utils.h"], + deps = [ + ":inference_calculator_cc_proto", + "//mediapipe/framework:port", + ] + select({ + "//conditions:default": [ + "//mediapipe/util:cpu_util", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_xnnpack", + srcs = [ + "inference_calculator_xnnpack.cc", + ], + copts = select({ + # TODO: fix tensor.h not to require this, if possible + "//mediapipe:apple": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + ":inference_calculator_interface", + ":inference_calculator_utils", + ":inference_interpreter_delegate_runner", + ":inference_runner", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + ], + alwayslink = 1, +) + cc_library( name = "inference_calculator_gl_if_compute_shader_available", visibility = ["//visibility:public"], diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 365f9f082..4ccdc07e1 100644 --- 
a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -59,6 +59,7 @@ class InferenceCalculatorSelectorImpl } } impls.emplace_back("Cpu"); + impls.emplace_back("Xnnpack"); for (const auto& suffix : impls) { const auto impl = absl::StrCat("InferenceCalculator", suffix); if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue; diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index 8e1c32e48..5df5f993f 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -141,6 +141,10 @@ struct InferenceCalculatorCpu : public InferenceCalculator { static constexpr char kCalculatorName[] = "InferenceCalculatorCpu"; }; +struct InferenceCalculatorXnnpack : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorXnnpack"; +}; + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index f330f99c2..2e90c7cc9 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -18,78 +18,21 @@ #include #include -#include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/calculators/tensor/inference_calculator_utils.h" +#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" +#include "mediapipe/calculators/tensor/inference_runner.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/interpreter_builder.h" #if defined(MEDIAPIPE_ANDROID) #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #endif // ANDROID - -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) -#include "mediapipe/util/cpu_util.h" -#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ - -#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" namespace mediapipe { namespace api2 { -namespace { - -int GetXnnpackDefaultNumThreads() { -#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \ - defined(__EMSCRIPTEN_PTHREADS__) - constexpr int kMinNumThreadsByDefault = 1; - constexpr int kMaxNumThreadsByDefault = 4; - return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault, - kMaxNumThreadsByDefault); -#else - return 1; -#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__ -} - -// Returns number of threads to configure XNNPACK delegate with. -// Returns user provided value if specified. Otherwise, tries to choose optimal -// number of threads depending on the device. -int GetXnnpackNumThreads( - const bool opts_has_delegate, - const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) { - static constexpr int kDefaultNumThreads = -1; - if (opts_has_delegate && opts_delegate.has_xnnpack() && - opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) { - return opts_delegate.xnnpack().num_threads(); - } - return GetXnnpackDefaultNumThreads(); -} - -template <typename T> -void CopyTensorBufferToInterpreter(const Tensor& input_tensor, - tflite::Interpreter* interpreter, - int input_tensor_index) { - auto input_tensor_view = input_tensor.GetCpuReadView(); - auto input_tensor_buffer = input_tensor_view.buffer<T>(); - T* local_tensor_buffer = - interpreter->typed_input_tensor<T>(input_tensor_index); - std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes()); -} - -template <typename T> -void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, - int output_tensor_index, - Tensor* output_tensor) { - auto output_tensor_view = output_tensor->GetCpuWriteView(); - auto output_tensor_buffer = output_tensor_view.buffer<T>(); - T* local_tensor_buffer = - interpreter->typed_output_tensor<T>(output_tensor_index); - std::memcpy(output_tensor_buffer, local_tensor_buffer, - output_tensor->bytes()); -} - -} // namespace - class InferenceCalculatorCpuImpl : public NodeImpl<InferenceCalculatorCpu, InferenceCalculatorCpuImpl> { public: @@ -100,16 +43,11 @@ class InferenceCalculatorCpuImpl absl::Status Close(CalculatorContext* cc) override; private: - absl::Status InitInterpreter(CalculatorContext* cc); - absl::Status LoadDelegate(CalculatorContext* cc, - tflite::InterpreterBuilder* interpreter_builder); - absl::Status AllocateTensors(); + absl::StatusOr<std::unique_ptr<InferenceRunner>> CreateInferenceRunner( + CalculatorContext* cc); + absl::StatusOr<TfLiteDelegatePtr> MaybeCreateDelegate(CalculatorContext* cc); - // TfLite requires us to keep the model alive as long as the interpreter is. - Packet<TfLiteModelPtr> model_packet_; - std::unique_ptr<tflite::Interpreter> interpreter_; - TfLiteDelegatePtr delegate_; - TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType; + std::unique_ptr<InferenceRunner> inference_runner_; }; absl::Status InferenceCalculatorCpuImpl::UpdateContract( @@ -122,7 +60,8 @@ } absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { - return InitInterpreter(cc); + ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc)); + return absl::OkStatus(); } absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { @@ -131,123 +70,32 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { } const auto& input_tensors = *kInTensors(cc); RET_CHECK(!input_tensors.empty()); - auto output_tensors = absl::make_unique<std::vector<Tensor>>(); - if (input_tensor_type_ == kTfLiteNoType) { - input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type; - } - - // Read CPU input into tensors. - for (int i = 0; i < input_tensors.size(); ++i) { - switch (input_tensor_type_) { - case TfLiteType::kTfLiteFloat16: - case TfLiteType::kTfLiteFloat32: { - CopyTensorBufferToInterpreter<float>(input_tensors[i], - interpreter_.get(), i); - break; - } - case TfLiteType::kTfLiteUInt8: { - CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i], - interpreter_.get(), i); - break; - } - case TfLiteType::kTfLiteInt8: { - CopyTensorBufferToInterpreter<int8_t>(input_tensors[i], - interpreter_.get(), i); - break; - } - case TfLiteType::kTfLiteInt32: { - CopyTensorBufferToInterpreter<int32_t>(input_tensors[i], - interpreter_.get(), i); - break; - } - default: - return absl::InvalidArgumentError( - absl::StrCat("Unsupported input tensor type:", input_tensor_type_)); - } - } - - // Run inference. - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); - - // Output result tensors (CPU). - const auto& tensor_indexes = interpreter_->outputs(); - output_tensors->reserve(tensor_indexes.size()); - for (int i = 0; i < tensor_indexes.size(); ++i) { - TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); - Tensor::Shape shape{std::vector<int>{ - tensor->dims->data, tensor->dims->data + tensor->dims->size}}; - switch (tensor->type) { - case TfLiteType::kTfLiteFloat16: - case TfLiteType::kTfLiteFloat32: - output_tensors->emplace_back(Tensor::ElementType::kFloat32, shape); - CopyTensorBufferFromInterpreter<float>(interpreter_.get(), i, - &output_tensors->back()); - break; - case TfLiteType::kTfLiteUInt8: - output_tensors->emplace_back( - Tensor::ElementType::kUInt8, shape, - Tensor::QuantizationParameters{tensor->params.scale, - tensor->params.zero_point}); - CopyTensorBufferFromInterpreter<uint8_t>(interpreter_.get(), i, - &output_tensors->back()); - break; - case TfLiteType::kTfLiteInt8: - output_tensors->emplace_back( - Tensor::ElementType::kInt8, shape, - Tensor::QuantizationParameters{tensor->params.scale, - tensor->params.zero_point}); - CopyTensorBufferFromInterpreter<int8_t>(interpreter_.get(), i, - &output_tensors->back()); - break; - case TfLiteType::kTfLiteInt32: - output_tensors->emplace_back(Tensor::ElementType::kInt32, shape); - CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i, - &output_tensors->back()); - break; - default: - return absl::InvalidArgumentError( - absl::StrCat("Unsupported output tensor type:", - TfLiteTypeGetName(tensor->type))); - } - } + ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors, + inference_runner_->Run(input_tensors)); kOutTensors(cc).Send(std::move(output_tensors)); return absl::OkStatus(); } absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) { - interpreter_ = nullptr; - delegate_ = nullptr; + inference_runner_ = nullptr; return absl::OkStatus(); } -absl::Status InferenceCalculatorCpuImpl::InitInterpreter( - CalculatorContext* cc) { - ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); - const auto& model = *model_packet_.Get(); +absl::StatusOr<std::unique_ptr<InferenceRunner>> +InferenceCalculatorCpuImpl::CreateInferenceRunner(CalculatorContext* cc) { + ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc)); ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); - const auto& op_resolver = op_resolver_packet.Get(); - tflite::InterpreterBuilder interpreter_builder(model, op_resolver); - MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder)); -#if defined(__EMSCRIPTEN__) - interpreter_builder.SetNumThreads(1); -#else - interpreter_builder.SetNumThreads( - cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); -#endif // __EMSCRIPTEN__ - - RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk); - RET_CHECK(interpreter_); - return AllocateTensors(); + const int interpreter_num_threads = + cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread(); + ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, MaybeCreateDelegate(cc)); + return CreateInferenceInterpreterDelegateRunner( + std::move(model_packet), std::move(op_resolver_packet), + std::move(delegate), interpreter_num_threads); } -absl::Status InferenceCalculatorCpuImpl::AllocateTensors() { - RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - return absl::OkStatus(); -} - -absl::Status InferenceCalculatorCpuImpl::LoadDelegate( - CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) { +absl::StatusOr<TfLiteDelegatePtr> +InferenceCalculatorCpuImpl::MaybeCreateDelegate(CalculatorContext* cc) { const auto& calculator_opts = cc->Options<mediapipe::InferenceCalculatorOptions>(); auto opts_delegate = calculator_opts.delegate(); @@ -268,7 +116,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate( calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty(); if (opts_has_delegate && opts_delegate.has_tflite()) { // Default tflite inference requested - no need to modify graph. - return absl::OkStatus(); + return nullptr; } #if defined(MEDIAPIPE_ANDROID) @@ -288,10 +136,8 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate( options.accelerator_name = nnapi.has_accelerator_name() ? nnapi.accelerator_name().c_str() : nullptr; - delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), - [](TfLiteDelegate*) {}); - interpreter_builder->AddDelegate(delegate_.get()); - return absl::OkStatus(); + return TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), + [](TfLiteDelegate*) {}); } #endif // MEDIAPIPE_ANDROID @@ -305,12 +151,11 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate( auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault(); xnnpack_opts.num_threads = GetXnnpackNumThreads(opts_has_delegate, opts_delegate); - delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), - &TfLiteXNNPackDelegateDelete); - interpreter_builder->AddDelegate(delegate_.get()); + return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), + &TfLiteXNNPackDelegateDelete); } - return absl::OkStatus(); + return nullptr; } } // namespace api2 diff --git a/mediapipe/calculators/tensor/inference_calculator_utils.cc b/mediapipe/calculators/tensor/inference_calculator_utils.cc new file mode 100644 index 000000000..11ded02bc --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_utils.cc @@ -0,0 +1,53 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "mediapipe/calculators/tensor/inference_calculator_utils.h" + +#include <algorithm> + +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/port.h" // NOLINT: provides MEDIAPIPE_ANDROID/IOS + +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) +#include "mediapipe/util/cpu_util.h" +#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ + +namespace mediapipe { + +namespace { + +int GetXnnpackDefaultNumThreads() { +#if defined(MEDIAPIPE_ANDROID) || defined(MEDIAPIPE_IOS) || \ + defined(__EMSCRIPTEN_PTHREADS__) + constexpr int kMinNumThreadsByDefault = 1; + constexpr int kMaxNumThreadsByDefault = 4; + return std::clamp(NumCPUCores() / 2, kMinNumThreadsByDefault, + kMaxNumThreadsByDefault); +#else + return 1; +#endif // MEDIAPIPE_ANDROID || MEDIAPIPE_IOS || __EMSCRIPTEN_PTHREADS__ +} + +} // namespace + +int GetXnnpackNumThreads( + const bool opts_has_delegate, + const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate) { + static constexpr int kDefaultNumThreads = -1; + if (opts_has_delegate && opts_delegate.has_xnnpack() && + opts_delegate.xnnpack().num_threads() != kDefaultNumThreads) { + return opts_delegate.xnnpack().num_threads(); + } + return GetXnnpackDefaultNumThreads(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_utils.h b/mediapipe/calculators/tensor/inference_calculator_utils.h new file mode 100644 index 000000000..64b590e93 --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_utils.h @@ -0,0 +1,31 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_ + +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" + +namespace mediapipe { + +// Returns the number of threads to configure the XNNPACK delegate with: +// the user-provided value if specified, otherwise a number of threads chosen +// to suit the device. +int GetXnnpackNumThreads( + const bool opts_has_delegate, + const mediapipe::InferenceCalculatorOptions::Delegate& opts_delegate); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_UTILS_H_ diff --git a/mediapipe/calculators/tensor/inference_calculator_xnnpack.cc b/mediapipe/calculators/tensor/inference_calculator_xnnpack.cc new file mode 100644 index 000000000..384c753ff --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_xnnpack.cc @@ -0,0 +1,122 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/calculators/tensor/inference_calculator_utils.h" +#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" +#include "mediapipe/calculators/tensor/inference_runner.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { +namespace api2 { + +class InferenceCalculatorXnnpackImpl + : public NodeImpl<InferenceCalculatorXnnpack, InferenceCalculatorXnnpackImpl> { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::StatusOr<std::unique_ptr<InferenceRunner>> CreateInferenceRunner( + CalculatorContext* cc); + absl::StatusOr<TfLiteDelegatePtr> CreateDelegate(CalculatorContext* cc); + + std::unique_ptr<InferenceRunner> inference_runner_; +}; + +absl::Status InferenceCalculatorXnnpackImpl::UpdateContract( + CalculatorContract* cc) { + const auto& options = cc->Options<mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorXnnpackImpl::Open(CalculatorContext* cc) { + ASSIGN_OR_RETURN(inference_runner_, CreateInferenceRunner(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorXnnpackImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + + ASSIGN_OR_RETURN(std::vector<Tensor> output_tensors, + inference_runner_->Run(input_tensors)); + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorXnnpackImpl::Close(CalculatorContext* cc) { + inference_runner_ = nullptr; + return absl::OkStatus(); +} + +absl::StatusOr<std::unique_ptr<InferenceRunner>> +InferenceCalculatorXnnpackImpl::CreateInferenceRunner(CalculatorContext* cc) { + ASSIGN_OR_RETURN(auto model_packet, GetModelAsPacket(cc)); + ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc)); + const int interpreter_num_threads = + cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread(); + ASSIGN_OR_RETURN(TfLiteDelegatePtr delegate, CreateDelegate(cc)); + return CreateInferenceInterpreterDelegateRunner( + std::move(model_packet), std::move(op_resolver_packet), + std::move(delegate), interpreter_num_threads); +} + +absl::StatusOr<TfLiteDelegatePtr> +InferenceCalculatorXnnpackImpl::CreateDelegate(CalculatorContext* cc) { + const auto& calculator_opts = + cc->Options<mediapipe::InferenceCalculatorOptions>(); + auto opts_delegate = calculator_opts.delegate(); + if (!kDelegate(cc).IsEmpty()) { + const mediapipe::InferenceCalculatorOptions::Delegate& + input_side_packet_delegate = kDelegate(cc).Get(); + RET_CHECK( + input_side_packet_delegate.has_xnnpack() || + input_side_packet_delegate.delegate_case() == + mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET) + << "inference_calculator_xnnpack only supports delegate input side packet " + << "for TFLite, XNNPack"; + opts_delegate.MergeFrom(input_side_packet_delegate); + } + const bool opts_has_delegate = + calculator_opts.has_delegate() || !kDelegate(cc).IsEmpty(); + + auto xnnpack_opts = TfLiteXNNPackDelegateOptionsDefault(); + xnnpack_opts.num_threads = + GetXnnpackNumThreads(opts_has_delegate, opts_delegate); + return TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), + &TfLiteXNNPackDelegateDelete); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc new file mode 100644 index 000000000..81edb34e0 --- /dev/null +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc @@ -0,0 +1,181 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" + +namespace mediapipe { + +namespace { + +template <typename T> +void CopyTensorBufferToInterpreter(const Tensor& input_tensor, + tflite::Interpreter* interpreter, + int input_tensor_index) { + auto input_tensor_view = input_tensor.GetCpuReadView(); + auto input_tensor_buffer = input_tensor_view.buffer<T>(); + T* local_tensor_buffer = + interpreter->typed_input_tensor<T>(input_tensor_index); + std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes()); +} + +template <typename T> +void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, + int output_tensor_index, + Tensor* output_tensor) { + auto output_tensor_view = output_tensor->GetCpuWriteView(); + auto output_tensor_buffer = output_tensor_view.buffer<T>(); + T* local_tensor_buffer = + interpreter->typed_output_tensor<T>(output_tensor_index); + std::memcpy(output_tensor_buffer, local_tensor_buffer, + output_tensor->bytes()); +} + +} // namespace + +class InferenceInterpreterDelegateRunner : public InferenceRunner { + public: + InferenceInterpreterDelegateRunner( + api2::Packet<TfLiteModelPtr> model, + std::unique_ptr<tflite::Interpreter> interpreter, + TfLiteDelegatePtr delegate) + : model_(std::move(model)), + interpreter_(std::move(interpreter)), + delegate_(std::move(delegate)) {} + + absl::StatusOr<std::vector<Tensor>> Run( + const std::vector<Tensor>& input_tensors) override; + + private: + api2::Packet<TfLiteModelPtr> model_; + std::unique_ptr<tflite::Interpreter> interpreter_; + TfLiteDelegatePtr delegate_; +}; + +absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run( + const std::vector<Tensor>& input_tensors) { + // Read CPU input into tensors. + RET_CHECK_EQ(interpreter_->inputs().size(), input_tensors.size()); + for (int i = 0; i < input_tensors.size(); ++i) { + const TfLiteType input_tensor_type = + interpreter_->tensor(interpreter_->inputs()[i])->type; + switch (input_tensor_type) { + case TfLiteType::kTfLiteFloat16: + case TfLiteType::kTfLiteFloat32: { + CopyTensorBufferToInterpreter<float>(input_tensors[i], + interpreter_.get(), i); + break; + } + case TfLiteType::kTfLiteUInt8: { + CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i], + interpreter_.get(), i); + break; + } + case TfLiteType::kTfLiteInt8: { + CopyTensorBufferToInterpreter<int8_t>(input_tensors[i], + interpreter_.get(), i); + break; + } + case TfLiteType::kTfLiteInt32: { + CopyTensorBufferToInterpreter<int32_t>(input_tensors[i], + interpreter_.get(), i); + break; + } + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported input tensor type:", input_tensor_type)); + } + } + + // Run inference. + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + + // Output result tensors (CPU). + const auto& tensor_indexes = interpreter_->outputs(); + std::vector<Tensor> output_tensors; + output_tensors.reserve(tensor_indexes.size()); + for (int i = 0; i < tensor_indexes.size(); ++i) { + TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); + Tensor::Shape shape{std::vector<int>{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}}; + switch (tensor->type) { + case TfLiteType::kTfLiteFloat16: + case TfLiteType::kTfLiteFloat32: + output_tensors.emplace_back(Tensor::ElementType::kFloat32, shape); + CopyTensorBufferFromInterpreter<float>(interpreter_.get(), i, + &output_tensors.back()); + break; + case TfLiteType::kTfLiteUInt8: + output_tensors.emplace_back( + Tensor::ElementType::kUInt8, shape, + Tensor::QuantizationParameters{tensor->params.scale, + tensor->params.zero_point}); + CopyTensorBufferFromInterpreter<uint8_t>(interpreter_.get(), i, + &output_tensors.back()); + break; + case TfLiteType::kTfLiteInt8: + output_tensors.emplace_back( + Tensor::ElementType::kInt8, shape, + Tensor::QuantizationParameters{tensor->params.scale, + tensor->params.zero_point}); + CopyTensorBufferFromInterpreter<int8_t>(interpreter_.get(), i, + &output_tensors.back()); + break; + case TfLiteType::kTfLiteInt32: + output_tensors.emplace_back(Tensor::ElementType::kInt32, shape); + CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i, + &output_tensors.back()); + break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported output tensor type:", + TfLiteTypeGetName(tensor->type))); + } + } + return output_tensors; +} + +absl::StatusOr<std::unique_ptr<InferenceRunner>> +CreateInferenceInterpreterDelegateRunner( + api2::Packet<TfLiteModelPtr> model, + api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate, + int interpreter_num_threads) { + tflite::InterpreterBuilder interpreter_builder(*model.Get(), + op_resolver.Get()); + if (delegate) { + interpreter_builder.AddDelegate(delegate.get()); + } +#if defined(__EMSCRIPTEN__) + interpreter_builder.SetNumThreads(1); +#else + interpreter_builder.SetNumThreads(interpreter_num_threads); +#endif // __EMSCRIPTEN__ + std::unique_ptr<tflite::Interpreter> interpreter; + RET_CHECK_EQ(interpreter_builder(&interpreter), kTfLiteOk); + RET_CHECK(interpreter); + RET_CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk); + return std::make_unique<InferenceInterpreterDelegateRunner>( + std::move(model), std::move(interpreter), std::move(delegate)); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h new file mode 100644 index 000000000..bfe27868e ---
/dev/null +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h @@ -0,0 +1,46 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/inference_runner.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/util/tflite/tflite_model_loader.h" +#include "tensorflow/lite/core/api/op_resolver.h" +#include "tensorflow/lite/interpreter.h" + +namespace mediapipe { + +using TfLiteDelegatePtr = + std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>; + +// Creates an inference runner that runs inference using a newly initialized +// interpreter and the provided `delegate`. +// +// `delegate` can be nullptr; in that case, the newly initialized interpreter +// will use whatever is available by default. +absl::StatusOr<std::unique_ptr<InferenceRunner>> +CreateInferenceInterpreterDelegateRunner( + api2::Packet<TfLiteModelPtr> model, + api2::Packet<tflite::OpResolver> op_resolver, TfLiteDelegatePtr delegate, + int interpreter_num_threads); + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_INTERPRETER_DELEGATE_RUNNER_H_ diff --git a/mediapipe/calculators/tensor/inference_runner.h b/mediapipe/calculators/tensor/inference_runner.h new file mode 100644 index 000000000..ec9d17b8b --- /dev/null +++ b/mediapipe/calculators/tensor/inference_runner.h @@ -0,0 +1,19 @@ +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_ + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/tensor.h" + +namespace mediapipe { + +// Common interface to implement inference runners in MediaPipe. +class InferenceRunner { + public: + virtual ~InferenceRunner() = default; + virtual absl::StatusOr<std::vector<Tensor>> Run( + const std::vector<Tensor>& inputs) = 0; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_RUNNER_H_ diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png index d28effc39..8e4073050 100644 Binary files a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png and b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad.png differ diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad_2x.png b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad_2x.png new file mode 100644 index 000000000..d28effc39 Binary files /dev/null and b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/76_c_Ipad_2x.png differ diff --git a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json index 8ae934c76..3ed9f5238 100644 --- a/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json +++ b/mediapipe/examples/ios/common/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -84,6 +84,7 @@ { "idiom" : "ipad", "size" : "76x76", + "filename" : "76_c_Ipad_2x.png", "scale" : "2x" }, { diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index e2c1820e7..28e42da97 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -324,6 +324,63 @@ TEST(BuilderTest, GraphIndexes) { EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } +class AnyAndSameTypeCalculator : public NodeIntf { + public: + static constexpr Input<AnyType> kAnyTypeInput{"INPUT"}; + static constexpr Output<AnyType> kAnyTypeOutput{"ANY_OUTPUT"}; + static constexpr Output<SameType<kAnyTypeInput>> kSameTypeOutput{ + "SAME_OUTPUT"}; + + static constexpr Input<int> kIntInput{"INT_INPUT"}; + // `SameType` usage for this output is only for testing purposes. + // + // `SameType` is designed to work with inputs of `AnyType` and, normally, you + // would not use `Output<SameType<kIntInput>>` in a real calculator. You + // should write `Output<int>` instead, since the type is known. + static constexpr Output<SameType<kIntInput>> kSameIntOutput{ + "SAME_INT_OUTPUT"}; + + MEDIAPIPE_NODE_INTERFACE(AnyTypeCalculator, kAnyTypeInput, kAnyTypeOutput, + kSameTypeOutput); +}; + +TEST(BuilderTest, AnyAndSameTypeHandledProperly) { + builder::Graph graph; + builder::Source<AnyType> any_input = + graph[Input<AnyType>{"GRAPH_ANY_INPUT"}]; + builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}]; + + auto& node = graph.AddNode("AnyAndSameTypeCalculator"); + any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; + int_input >> node[AnyAndSameTypeCalculator::kIntInput]; + + builder::Source<AnyType> any_type_output = + node[AnyAndSameTypeCalculator::kAnyTypeOutput]; + any_type_output.SetName("any_type_output"); + + builder::Source<AnyType> same_type_output = + node[AnyAndSameTypeCalculator::kSameTypeOutput]; + same_type_output.SetName("same_type_output"); + builder::Source<AnyType> same_int_output = + node[AnyAndSameTypeCalculator::kSameIntOutput]; + same_int_output.SetName("same_int_output"); + + CalculatorGraphConfig expected = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( + node { + calculator: "AnyAndSameTypeCalculator" + input_stream: "INPUT:__stream_0" + input_stream: "INT_INPUT:__stream_1" + output_stream: "ANY_OUTPUT:any_type_output" + output_stream: "SAME_INT_OUTPUT:same_int_output" + output_stream: "SAME_OUTPUT:same_type_output" + } + input_stream: "GRAPH_ANY_INPUT:__stream_0" + input_stream: "GRAPH_INT_INPUT:__stream_1" + )pb"); + EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); +} + } // namespace test } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 80d2307ae..4ff726da0 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -27,6 +27,12 @@ using HolderBase = mediapipe::packet_internal::HolderBase; template <typename T> class Packet; +struct DynamicType {}; + +struct AnyType : public DynamicType { + AnyType() = delete; +}; + // Type-erased packet. class PacketBase { public: @@ -148,9 +154,8 @@ inline void CheckCompatibleType(const HolderBase& holder, << " was requested."; } -struct Generic { - Generic() = delete; -}; +// TODO: remove usage of internal::Generic and simply use AnyType. +using Generic = ::mediapipe::api2::AnyType; template <typename T, typename U> struct IsCompatibleType : std::false_type {}; diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index b972359e9..a408831bc 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -77,10 +77,6 @@ struct NoneType { NoneType() = delete; }; -struct DynamicType {}; - -struct AnyType : public DynamicType {}; - template <auto& P> class SameType : public DynamicType { public: diff --git a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java index 42fcdb4d8..c375aa61f 100644 --- a/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java +++ b/mediapipe/java/com/google/mediapipe/components/ExternalTextureConverter.java @@ -103,6 +103,18 @@ public class ExternalTextureConverter implements TextureFrameProducer { } } + /** + * Re-renders the current frame. Notifies all consumers as if it were a new frame. This should not + * typically be used but can be useful for cases where the consumer has lost ownership of the most + * recent frame and needs to get it again. This does nothing if no frame has yet been received.
+ */ + public void rerenderCurrentFrame() { + SurfaceTexture surfaceTexture = getSurfaceTexture(); + if (thread != null && surfaceTexture != null && thread.getHasReceivedFirstFrame()) { + thread.onFrameAvailable(surfaceTexture); + } + } + /** * Sets the new buffer pool size. This is safe to set at any time. * @@ -278,6 +290,7 @@ public class ExternalTextureConverter implements TextureFrameProducer { private volatile SurfaceTexture internalSurfaceTexture = null; private int[] textures = null; private final List<TextureFrameConsumer> consumers; + private volatile boolean hasReceivedFirstFrame = false; private final Queue<SurfaceTexture> framesAvailable = new ArrayDeque<>(); private int framesInUse = 0; @@ -335,6 +348,7 @@ } public void setSurfaceTexture(SurfaceTexture texture, int width, int height) { + hasReceivedFirstFrame = false; if (surfaceTexture != null) { surfaceTexture.setOnFrameAvailableListener(null); } @@ -381,6 +395,10 @@ return surfaceTexture != null ? surfaceTexture : internalSurfaceTexture; } + public boolean getHasReceivedFirstFrame() { + return hasReceivedFirstFrame; + } + @Override public void onFrameAvailable(SurfaceTexture surfaceTexture) { handler.post(() -> renderNext(surfaceTexture)); @@ -427,6 +445,7 @@ // pending on the handler. When that happens, we should simply disregard the call. return; } + hasReceivedFirstFrame = true; try { synchronized (consumers) { boolean frameUpdated = false; diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 24e5c228c..48c9b181a 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -69,7 +69,8 @@ objc_library( "-Wno-shorten-64-to-32", ], sdk_frameworks = ["Accelerate"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + # This build rule is public to allow external customers to build their own iOS apps. + visibility = ["//visibility:public"], deps = [ ":CFHolder", ":util", @@ -124,7 +125,8 @@ objc_library( "CoreVideo", "Foundation", ], - visibility = ["//mediapipe/framework:mediapipe_internal"], + # This build rule is public to allow external customers to build their own iOS apps. + visibility = ["//visibility:public"], ) objc_library( @@ -166,7 +168,8 @@ objc_library( "Foundation", "GLKit", ], - visibility = ["//mediapipe/framework:mediapipe_internal"], + # This build rule is public to allow external customers to build their own iOS apps. + visibility = ["//visibility:public"], deps = [ ":mediapipe_framework_ios", ":mediapipe_gl_view_renderer", diff --git a/mediapipe/objc/MPPCameraInputSource.h b/mediapipe/objc/MPPCameraInputSource.h index 9bb7439b5..3438b3aeb 100644 --- a/mediapipe/objc/MPPCameraInputSource.h +++ b/mediapipe/objc/MPPCameraInputSource.h @@ -23,7 +23,7 @@ @property(nonatomic, getter=isAuthorized, readonly) BOOL authorized; /// Session preset to use for capturing. -@property(nonatomic) NSString *sessionPreset; +@property(nonatomic, nullable) NSString *sessionPreset; /// Which camera on an iOS device to use, assuming iOS device with more than one camera.
@property(nonatomic) AVCaptureDevicePosition cameraPosition; diff --git a/mediapipe/python/solutions/download_utils.py b/mediapipe/python/solutions/download_utils.py index 3b69074b0..f582d5c64 100644 --- a/mediapipe/python/solutions/download_utils.py +++ b/mediapipe/python/solutions/download_utils.py @@ -17,21 +17,21 @@ import os import shutil import urllib.request -_OSS_URL_PREFIX = 'https://github.com/google/mediapipe/raw/master/' +_GCS_URL_PREFIX = 'https://storage.googleapis.com/mediapipe-assets/' def download_oss_model(model_path: str): - """Downloads the oss model from the MediaPipe GitHub repo if it doesn't exist in the package.""" + """Downloads the oss model from Google Cloud Storage if it doesn't exist in the package.""" mp_root_path = os.sep.join(os.path.abspath(__file__).split(os.sep)[:-4]) model_abspath = os.path.join(mp_root_path, model_path) if os.path.exists(model_abspath): return - model_url = _OSS_URL_PREFIX + model_path + model_url = _GCS_URL_PREFIX + model_path.split('/')[-1] print('Downloading model to ' + model_abspath) with urllib.request.urlopen(model_url) as response, open(model_abspath, 'wb') as out_file: if response.code != 200: raise ConnectionError('Cannot download ' + model_path + - ' from the MediaPipe Github repo.') + ' from Google Cloud Storage.') shutil.copyfileobj(response, out_file) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto index 1ecf8e072..9dd65a265 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.proto @@ -25,7 +25,7 @@ message AudioClassifierOptions { extend mediapipe.CalculatorOptions { optional AudioClassifierOptions ext = 451755788; } - // Base options for configuring Task library, such as specifying the TfLite + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. 
optional core.proto.BaseOptions base_options = 1; diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index c8985c98b..8b553dea4 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -43,3 +43,73 @@ cc_library( ], alwayslink = 1, ) + +mediapipe_proto_library( + name = "score_calibration_calculator_proto", + srcs = ["score_calibration_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "score_calibration_calculator", + srcs = ["score_calibration_calculator.cc"], + deps = [ + ":score_calibration_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc:common", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], + alwayslink = 1, +) + +cc_test( + name = "score_calibration_calculator_test", + srcs = ["score_calibration_calculator_test.cc"], + deps = [ + ":score_calibration_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "score_calibration_utils", + srcs = ["score_calibration_utils.cc"], + hdrs = ["score_calibration_utils.h"], + deps = [ + ":score_calibration_calculator_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "score_calibration_utils_test", + srcs = ["score_calibration_utils_test.cc"], + deps = [ + ":score_calibration_calculator_cc_proto", + ":score_calibration_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.cc new file mode 100644 index 000000000..c689cc255 --- /dev/null +++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.cc @@ -0,0 +1,259 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" + +namespace mediapipe { +namespace api2 { + +using ::absl::StatusCode; +using ::mediapipe::tasks::CreateStatusWithPayload; +using ::mediapipe::tasks::MediaPipeTasksStatus; +using ::mediapipe::tasks::ScoreCalibrationCalculatorOptions; + +namespace { +// Used to prevent log(<=0.0) in ClampedLog() calls. +constexpr float kLogScoreMinimum = 1e-16; + +// Returns the following, depending on x: +// x => threshold: log(x) +// x < threshold: 2 * log(thresh) - log(2 * thresh - x) +// This form (a) is anti-symmetric about the threshold and (b) has continuous +// value and first derivative. This is done to prevent taking the log of values +// close to 0 which can lead to floating point errors and is better than simple +// clamping since it preserves order for scores less than the threshold. +float ClampedLog(float x, float threshold) { + if (x < threshold) { + return 2.0 * std::log(static_cast(threshold)) - + log(2.0 * threshold - x); + } + return std::log(static_cast(x)); +} +} // namespace + +// Applies score calibration to a tensor of score predictions, typically applied +// to the output of a classification or object detection model. +// +// See corresponding options for more details on the score calibration +// parameters and formula. +// +// Inputs: +// SCORES - std::vector +// A vector containing a single Tensor `x` of type kFloat32, representing +// the scores to calibrate. By default (i.e. if INDICES is not connected), +// x[i] will be calibrated using the sigmoid provided at index i in the +// options. +// INDICES - std::vector @Optional +// An optional vector containing a single Tensor `y` of type kFloat32 and +// same size as `x`. If provided, x[i] will be calibrated using the sigmoid +// provided at index y[i] (casted as an integer) in the options. `x` and `y` +// must contain the same number of elements. Typically used for object +// detection models. +// +// Outputs: +// CALIBRATED_SCORES - std::vector +// A vector containing a single Tensor of type kFloat32 and of the same size +// as the input tensors. Contains the output calibrated scores. +class ScoreCalibrationCalculator : public Node { + public: + static constexpr Input> kScoresIn{"SCORES"}; + static constexpr Input>::Optional kIndicesIn{"INDICES"}; + static constexpr Output> kScoresOut{"CALIBRATED_SCORES"}; + MEDIAPIPE_NODE_CONTRACT(kScoresIn, kIndicesIn, kScoresOut); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + ScoreCalibrationCalculatorOptions options_; + std::function score_transformation_; + + // Computes the calibrated score for the provided index. Does not check for + // out-of-bounds index. + float ComputeCalibratedScore(int index, float score); + // Same as above, but does check for out-of-bounds index. + absl::StatusOr SafeComputeCalibratedScore(int index, float score); +}; + +absl::Status ScoreCalibrationCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + // Sanity checks. 
+ if (options_.sigmoids_size() == 0) { + return CreateStatusWithPayload(StatusCode::kInvalidArgument, + "Expected at least one sigmoid, found none.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + for (const auto& sigmoid : options_.sigmoids()) { + if (sigmoid.has_scale() && sigmoid.scale() < 0.0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("The scale parameter of the sigmoids must be " + "positive, found %f.", + sigmoid.scale()), + MediaPipeTasksStatus::kInvalidArgumentError); + } + } + // Set score transformation function once and for all. + switch (options_.score_transformation()) { + case tasks::ScoreCalibrationCalculatorOptions::IDENTITY: + score_transformation_ = [](float x) { return x; }; + break; + case tasks::ScoreCalibrationCalculatorOptions::LOG: + score_transformation_ = [](float x) { + return ClampedLog(x, kLogScoreMinimum); + }; + break; + case tasks::ScoreCalibrationCalculatorOptions::INVERSE_LOGISTIC: + score_transformation_ = [](float x) { + return (ClampedLog(x, kLogScoreMinimum) - + ClampedLog(1.0 - x, kLogScoreMinimum)); + }; + break; + default: + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat( + "Unsupported ScoreTransformation type: %s", + ScoreCalibrationCalculatorOptions::ScoreTransformation_Name( + options_.score_transformation())), + MediaPipeTasksStatus::kInvalidArgumentError); + } + return absl::OkStatus(); +} + +absl::Status ScoreCalibrationCalculator::Process(CalculatorContext* cc) { + RET_CHECK_EQ(kScoresIn(cc)->size(), 1); + const auto& scores = (*kScoresIn(cc))[0]; + RET_CHECK(scores.element_type() == Tensor::ElementType::kFloat32); + auto scores_view = scores.GetCpuReadView(); + const float* raw_scores = scores_view.buffer(); + int num_scores = scores.shape().num_elements(); + + auto output_tensors = std::make_unique>(); + output_tensors->reserve(1); + output_tensors->emplace_back(scores.element_type(), scores.shape()); + auto calibrated_scores = &output_tensors->back(); + auto calibrated_scores_view = calibrated_scores->GetCpuWriteView(); + float* raw_calibrated_scores = calibrated_scores_view.buffer(); + + if (kIndicesIn(cc).IsConnected()) { + RET_CHECK_EQ(kIndicesIn(cc)->size(), 1); + const auto& indices = (*kIndicesIn(cc))[0]; + RET_CHECK(indices.element_type() == Tensor::ElementType::kFloat32); + if (num_scores != indices.shape().num_elements()) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + absl::StrFormat("Mismatch between number of elements in the input " + "scores tensor (%d) and indices tensor (%d).", + num_scores, indices.shape().num_elements()), + MediaPipeTasksStatus::kMetadataInconsistencyError); + } + auto indices_view = indices.GetCpuReadView(); + const float* raw_indices = indices_view.buffer(); + for (int i = 0; i < num_scores; ++i) { + // Use the "safe" flavor as we need to check that the externally provided + // indices are not out-of-bounds. 
+      ASSIGN_OR_RETURN(raw_calibrated_scores[i],
+                       SafeComputeCalibratedScore(
+                           static_cast<int>(raw_indices[i]), raw_scores[i]));
+    }
+  } else {
+    if (num_scores != options_.sigmoids_size()) {
+      return CreateStatusWithPayload(
+          StatusCode::kInvalidArgument,
+          absl::StrFormat("Mismatch between number of sigmoids (%d) and number "
+                          "of elements in the input scores tensor (%d).",
+                          options_.sigmoids_size(), num_scores),
+          MediaPipeTasksStatus::kMetadataInconsistencyError);
+    }
+    for (int i = 0; i < num_scores; ++i) {
+      // Use the "unsafe" flavor as we have already checked for out-of-bounds
+      // issues.
+      raw_calibrated_scores[i] = ComputeCalibratedScore(i, raw_scores[i]);
+    }
+  }
+  kScoresOut(cc).Send(std::move(output_tensors));
+  return absl::OkStatus();
+}
+
+float ScoreCalibrationCalculator::ComputeCalibratedScore(int index,
+                                                         float score) {
+  const auto& sigmoid = options_.sigmoids(index);
+
+  bool is_empty =
+      !sigmoid.has_scale() || !sigmoid.has_offset() || !sigmoid.has_slope();
+  bool is_below_min_score =
+      sigmoid.has_min_score() && score < sigmoid.min_score();
+  if (is_empty || is_below_min_score) {
+    return options_.default_score();
+  }
+
+  float transformed_score = score_transformation_(score);
+  float scale_shifted_score =
+      transformed_score * sigmoid.slope() + sigmoid.offset();
+  // For numerical stability use 1 / (1+exp(-x)) when scale_shifted_score >= 0
+  // and exp(x) / (1+exp(x)) when scale_shifted_score < 0.
+  float calibrated_score;
+  if (scale_shifted_score >= 0.0) {
+    calibrated_score =
+        sigmoid.scale() /
+        (1.0 + std::exp(static_cast<double>(-scale_shifted_score)));
+  } else {
+    float score_exp = std::exp(static_cast<double>(scale_shifted_score));
+    calibrated_score = sigmoid.scale() * score_exp / (1.0 + score_exp);
+  }
+  // Scale is non-negative (checked in SigmoidFromLabelAndLine), so
+  // calibrated_score should lie in the range [0, scale]. However, due to
+  // numerical stability issues, it may fall outside these bounds. Cap the
+  // value to [0, scale] instead.
+  return std::max(std::min(calibrated_score, sigmoid.scale()), 0.0f);
+}
+
+absl::StatusOr<float> ScoreCalibrationCalculator::SafeComputeCalibratedScore(
+    int index, float score) {
+  if (index < 0) {
+    return CreateStatusWithPayload(
+        StatusCode::kInvalidArgument,
+        absl::StrFormat("Expected positive indices, found %d.", index),
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  if (index >= options_.sigmoids_size()) {
+    return CreateStatusWithPayload(
+        StatusCode::kInvalidArgument,
+        absl::StrFormat("Unable to get score calibration parameters for index "
+                        "%d: only %d sigmoids were provided.",
+                        index, options_.sigmoids_size()),
+        MediaPipeTasksStatus::kMetadataInconsistencyError);
+  }
+  return ComputeCalibratedScore(index, score);
+}
+
+MEDIAPIPE_REGISTER_NODE(ScoreCalibrationCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
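The two sigmoid branches in ComputeCalibratedScore above are algebraically identical; the split exists only so that exp() is never called on a large positive argument. A minimal standalone sketch of the same trick (not part of the patch; names are illustrative):

    #include <cmath>

    // Numerically stable scaled sigmoid: scale / (1 + e^-z).
    // For z < 0, e^-z can overflow a double, so rewrite the expression as
    // scale * e^z / (1 + e^z), which only ever exponentiates a value <= 0.
    double StableScaledSigmoid(double z, double scale) {
      if (z >= 0.0) {
        return scale / (1.0 + std::exp(-z));
      }
      const double e = std::exp(z);
      return scale * e / (1.0 + e);
    }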
diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto
new file mode 100644
index 000000000..11d944c93
--- /dev/null
+++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto
@@ -0,0 +1,67 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe.tasks;
+
+import "mediapipe/framework/calculator.proto";
+
+message ScoreCalibrationCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional ScoreCalibrationCalculatorOptions ext = 470204318;
+  }
+
+  // Score calibration parameters for one individual category. The formula used
+  // to transform the uncalibrated score `x` is:
+  // * `f(x) = scale / (1 + e^-(slope * g(x) + offset))` if `x > min_score` or
+  //   if no min_score has been specified,
+  // * `f(x) = default_score` otherwise, or if no scale, slope or offset have
+  //   been specified.
+  //
+  // Where:
+  // * scale must be positive,
+  // * g(x) is a global (i.e. category-independent) transform specified using
+  //   the score_transformation field,
+  // * default_score is a global parameter defined below.
+  //
+  // There should be exactly one sigmoid per supported output category in the
+  // model, each with either:
+  // * no fields set,
+  // * scale, slope and offset set,
+  // * all fields set.
+  message Sigmoid {
+    optional float scale = 1;
+    optional float slope = 2;
+    optional float offset = 3;
+    optional float min_score = 4;
+  }
+  repeated Sigmoid sigmoids = 1;
+
+  // Score transformation that defines the `g(x)` function in the above
+  // formula.
+  enum ScoreTransformation {
+    UNSPECIFIED = 0;
+    // g(x) = x.
+    IDENTITY = 1;
+    // g(x) = log(x).
+    LOG = 2;
+    // g(x) = log(x) - log(1 - x).
+    INVERSE_LOGISTIC = 3;
+  }
+  optional ScoreTransformation score_transformation = 2 [default = IDENTITY];
+
+  // Default calibrated score, used when a category has an empty sigmoid or an
+  // uncalibrated score below min_score.
+  optional float default_score = 3;
+}
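To make the formula concrete, here is a hand computation using the parameters that appear in the tests below (scale = 0.9, slope = 0.5, offset = 0.1, IDENTITY transform) and an uncalibrated score x = 0.2:

    f(0.2) = 0.9 / (1 + e^-(0.5 * 0.2 + 0.1))
           = 0.9 / (1 + e^-0.2)
           = 0.9 / 1.8187308
           = 0.4948506...

which matches the expected value 0.4948505976 asserted in the IDENTITY test case.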
diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator_test.cc
new file mode 100644
index 000000000..8134d86d2
--- /dev/null
+++ b/mediapipe/tasks/cc/components/calculators/score_calibration_calculator_test.cc
@@ -0,0 +1,309 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+
+namespace mediapipe {
+namespace {
+
+using ::mediapipe::ParseTextProtoOrDie;
+using ::testing::HasSubstr;
+using ::testing::TestParamInfo;
+using ::testing::TestWithParam;
+using ::testing::Values;
+using Node = ::mediapipe::CalculatorGraphConfig::Node;
+
+// Builds the graph and feeds inputs.
+void BuildGraph(CalculatorRunner* runner, std::vector<float> scores,
+                std::optional<std::vector<int>> indices = std::nullopt) {
+  auto scores_tensors = std::make_unique<std::vector<Tensor>>();
+  scores_tensors->emplace_back(
+      Tensor::ElementType::kFloat32,
+      Tensor::Shape{1, static_cast<int>(scores.size())});
+  auto scores_view = scores_tensors->back().GetCpuWriteView();
+  float* scores_buffer = scores_view.buffer<float>();
+  ASSERT_NE(scores_buffer, nullptr);
+  for (int i = 0; i < scores.size(); ++i) {
+    scores_buffer[i] = scores[i];
+  }
+  auto& input_scores_packets = runner->MutableInputs()->Tag("SCORES").packets;
+  input_scores_packets.push_back(
+      mediapipe::Adopt(scores_tensors.release()).At(mediapipe::Timestamp(0)));
+
+  if (indices.has_value()) {
+    auto indices_tensors = std::make_unique<std::vector<Tensor>>();
+    indices_tensors->emplace_back(
+        Tensor::ElementType::kFloat32,
+        Tensor::Shape{1, static_cast<int>(indices->size())});
+    auto indices_view = indices_tensors->back().GetCpuWriteView();
+    float* indices_buffer = indices_view.buffer<float>();
+    ASSERT_NE(indices_buffer, nullptr);
+    for (int i = 0; i < indices->size(); ++i) {
+      indices_buffer[i] = static_cast<float>((*indices)[i]);
+    }
+    auto& input_indices_packets =
+        runner->MutableInputs()->Tag("INDICES").packets;
+    input_indices_packets.push_back(mediapipe::Adopt(indices_tensors.release())
+                                        .At(mediapipe::Timestamp(0)));
+  }
+}
+
+// Compares the provided tensor contents with the expected values.
+void ValidateResult(const Tensor& actual, const std::vector<float>& expected) {
+  EXPECT_EQ(actual.element_type(), Tensor::ElementType::kFloat32);
+  EXPECT_EQ(expected.size(), actual.shape().num_elements());
+  auto view = actual.GetCpuReadView();
+  auto buffer = view.buffer<float>();
+  for (int i = 0; i < expected.size(); ++i) {
+    EXPECT_FLOAT_EQ(expected[i], buffer[i]);
+  }
+}
+
+TEST(ScoreCalibrationCalculatorTest, FailsWithNoSigmoid) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {}
+    }
+  )pb"));
+
+  BuildGraph(&runner, {0.5, 0.5, 0.5});
+  auto status = runner.Run();
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(),
+              HasSubstr("Expected at least one sigmoid, found none"));
+}
+
+TEST(ScoreCalibrationCalculatorTest, FailsWithNegativeScale) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+        sigmoids { slope: 1 offset: 1 scale: -1 }
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {0.5, 0.5, 0.5});
+  auto status = runner.Run();
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr("The scale parameter of the sigmoids must be positive"));
+}
+
+TEST(ScoreCalibrationCalculatorTest, FailsWithUnspecifiedTransformation) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+        sigmoids { slope: 1 offset: 1 scale: 1 }
+        score_transformation: UNSPECIFIED
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {0.5, 0.5, 0.5});
+  auto status = runner.Run();
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(),
+              HasSubstr("Unsupported ScoreTransformation type"));
+}
+
+// Struct holding the parameters for the parameterized tests below.
+struct CalibrationTestParams {
+  // The score transformation to apply.
+  std::string score_transformation;
+  // The expected results.
+  std::vector<float> expected_results;
+};
+
+class CalibrationWithoutIndicesTest
+    : public TestWithParam<CalibrationTestParams> {};
+
+TEST_P(CalibrationWithoutIndicesTest, Succeeds) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(absl::StrFormat(
+      R"pb(
+        calculator: "ScoreCalibrationCalculator"
+        input_stream: "SCORES:scores"
+        output_stream: "CALIBRATED_SCORES:calibrated_scores"
+        options {
+          [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+            sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
+            sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
+            sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
+            sigmoids {}
+            score_transformation: %s
+            default_score: 0.2
+          }
+        }
+      )pb",
+      GetParam().score_transformation)));
+
+  BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5});
+  MP_ASSERT_OK(runner.Run());
+  const Tensor& results = runner.Outputs()
+                              .Get("CALIBRATED_SCORES", 0)
+                              .packets[0]
+                              .Get<std::vector<Tensor>>()[0];
+
+  ValidateResult(results, GetParam().expected_results);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    ScoreCalibrationCalculatorTest, CalibrationWithoutIndicesTest,
+    Values(CalibrationTestParams{.score_transformation = "IDENTITY",
+                                 .expected_results = {0.4948505976,
+                                                      0.5059588508, 0.2, 0.2}},
+           CalibrationTestParams{
+               .score_transformation = "LOG",
+               .expected_results = {0.2976901255, 0.3393665735, 0.2, 0.2}},
+           CalibrationTestParams{
+               .score_transformation = "INVERSE_LOGISTIC",
+               .expected_results = {0.3203217641, 0.3778080605, 0.2, 0.2}}),
+    [](const TestParamInfo<CalibrationWithoutIndicesTest::ParamType>& info) {
+      return info.param.score_transformation;
+    });
+
+TEST(ScoreCalibrationCalculatorTest, FailsWithMissingSigmoids) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
+        sigmoids {}
+        score_transformation: LOG
+        default_score: 0.2
+      }
+    }
+  )pb"));
+
+  BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5, 0.6});
+  auto status = runner.Run();
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(),
+              HasSubstr("Mismatch between number of sigmoids"));
+}
+
+TEST(ScoreCalibrationCalculatorTest, SucceedsWithIndices) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    input_stream: "INDICES:indices"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
+        sigmoids {}
+        score_transformation: IDENTITY
+        default_score: 0.2
+      }
+    }
+  )pb"));
+  std::vector<int> indices = {1, 2, 3, 0};
+
+  BuildGraph(&runner, {0.3, 0.4, 0.5, 0.2}, indices);
+  MP_ASSERT_OK(runner.Run());
+  const Tensor& results = runner.Outputs()
+                              .Get("CALIBRATED_SCORES", 0)
+                              .packets[0]
+                              .Get<std::vector<Tensor>>()[0];
+  ValidateResult(results, {0.5059588508, 0.2, 0.2, 0.4948505976});
+}
+
+TEST(ScoreCalibrationCalculatorTest, FailsWithNegativeIndex) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    input_stream: "INDICES:indices"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
+        sigmoids {}
+        score_transformation: IDENTITY
+        default_score: 0.2
+      }
+    }
+  )pb"));
+  std::vector<int> indices = {0, 1, 2, -1};
+
+  BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}, indices);
+  auto status = runner.Run();
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(), HasSubstr("Expected positive indices"));
+}
+
+TEST(ScoreCalibrationCalculatorTest, FailsWithOutOfBoundsIndex) {
+  CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
+    calculator: "ScoreCalibrationCalculator"
+    input_stream: "SCORES:scores"
+    input_stream: "INDICES:indices"
+    output_stream: "CALIBRATED_SCORES:calibrated_scores"
+    options {
+      [mediapipe.tasks.ScoreCalibrationCalculatorOptions.ext] {
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.3 }
+        sigmoids { scale: 0.9 slope: 0.5 offset: 0.1 min_score: 0.6 }
+        sigmoids {}
+        score_transformation: IDENTITY
+        default_score: 0.2
+      }
+    }
+  )pb"));
+  std::vector<int> indices = {0, 1, 5, 3};
+
+  BuildGraph(&runner, {0.2, 0.3, 0.4, 0.5}, indices);
+  auto status = runner.Run();
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr("Unable to get score calibration parameters for index"));
+}
+
+}  // namespace
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_utils.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.cc
new file mode 100644
index 000000000..120344be6
--- /dev/null
+++ b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.cc
@@ -0,0 +1,115 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
+
+#include <string>
+
+#include "absl/status/status.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/framework/port/status_macros.h"
+#include "mediapipe/tasks/cc/common.h"
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
+#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
+
+namespace mediapipe {
+namespace tasks {
+
+namespace {
+// Converts a ScoreTransformation type from TFLite Metadata to calculator
+// options.
+ScoreCalibrationCalculatorOptions::ScoreTransformation
+ConvertScoreTransformationType(tflite::ScoreTransformationType type) {
+  switch (type) {
+    case tflite::ScoreTransformationType_IDENTITY:
+      return ScoreCalibrationCalculatorOptions::IDENTITY;
+    case tflite::ScoreTransformationType_LOG:
+      return ScoreCalibrationCalculatorOptions::LOG;
+    case tflite::ScoreTransformationType_INVERSE_LOGISTIC:
+      return ScoreCalibrationCalculatorOptions::INVERSE_LOGISTIC;
+  }
+}
+
+// Parses a single line of the score calibration file into the provided
+// sigmoid.
+absl::Status FillSigmoidFromLine(
+    absl::string_view line,
+    ScoreCalibrationCalculatorOptions::Sigmoid* sigmoid) {
+  if (line.empty()) {
+    return absl::OkStatus();
+  }
+  std::vector<std::string> str_params = absl::StrSplit(line, ',');
+  if (str_params.size() != 3 && str_params.size() != 4) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Expected 3 or 4 parameters per line in score "
+                        "calibration file, got %d.",
+                        str_params.size()),
+        MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError);
+  }
+  std::vector<float> params(str_params.size());
+  for (int i = 0; i < str_params.size(); ++i) {
+    if (!absl::SimpleAtof(str_params[i], &params[i])) {
+      return CreateStatusWithPayload(
+          absl::StatusCode::kInvalidArgument,
+          absl::StrFormat(
+              "Could not parse score calibration parameter as float: %s.",
+              str_params[i]),
+          MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError);
+    }
+  }
+  if (params[0] < 0) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat(
+            "The scale parameter of the sigmoids must be positive, found %f.",
+            params[0]),
+        MediaPipeTasksStatus::kMetadataMalformedScoreCalibrationError);
+  }
+  sigmoid->set_scale(params[0]);
+  sigmoid->set_slope(params[1]);
+  sigmoid->set_offset(params[2]);
+  if (params.size() == 4) {
+    sigmoid->set_min_score(params[3]);
+  }
+  return absl::OkStatus();
+}
+}  // namespace
+
+absl::Status ConfigureScoreCalibration(
+    tflite::ScoreTransformationType score_transformation, float default_score,
+    absl::string_view score_calibration_file,
+    ScoreCalibrationCalculatorOptions* calculator_options) {
+  calculator_options->set_score_transformation(
+      ConvertScoreTransformationType(score_transformation));
+  calculator_options->set_default_score(default_score);
+
+  if (score_calibration_file.empty()) {
+    return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
+                                   "Expected non-empty score calibration file.",
+                                   MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  std::vector<absl::string_view> lines =
+      absl::StrSplit(score_calibration_file, '\n');
+  for (const auto& line : lines) {
+    auto* sigmoid = calculator_options->add_sigmoids();
+    MP_RETURN_IF_ERROR(FillSigmoidFromLine(line, sigmoid));
+  }
+
+  return absl::OkStatus();
+}
+
+}  // namespace tasks
+}  // namespace mediapipe
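Concretely, the calibration file parsed above is newline-separated CSV with one line per output category, each line holding `scale,slope,offset[,min_score]`, and empty lines producing empty sigmoids (so those categories fall back to default_score). A usage sketch with illustrative values:

    // One line per category: scale,slope,offset[,min_score].
    // The empty third line yields an empty sigmoid for that category.
    constexpr char kCalibrationFile[] =
        "0.9,0.5,0.1\n"
        "0.9,0.5,0.1,0.3\n"
        "\n";
    mediapipe::tasks::ScoreCalibrationCalculatorOptions options;
    absl::Status status = mediapipe::tasks::ConfigureScoreCalibration(
        tflite::ScoreTransformationType_IDENTITY, /*default_score=*/0.2,
        kCalibrationFile, &options);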
diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_utils.h b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.h
new file mode 100644
index 000000000..5c3d446ee
--- /dev/null
+++ b/mediapipe/tasks/cc/components/calculators/score_calibration_utils.h
@@ -0,0 +1,38 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
+#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
+
+namespace mediapipe {
+namespace tasks {
+
+// Populates ScoreCalibrationCalculatorOptions given a TFLite Metadata score
+// transformation type, default score and score calibration AssociatedFile
+// contents, as specified in `TENSOR_AXIS_SCORE_CALIBRATION`:
+// https://github.com/google/mediapipe/blob/master/mediapipe/tasks/metadata/metadata_schema.fbs
+absl::Status ConfigureScoreCalibration(
+    tflite::ScoreTransformationType score_transformation, float default_score,
+    absl::string_view score_calibration_file,
+    ScoreCalibrationCalculatorOptions* options);
+
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CALCULATORS_SCORE_CALIBRATION_UTILS_H_
diff --git a/mediapipe/tasks/cc/components/calculators/score_calibration_utils_test.cc b/mediapipe/tasks/cc/components/calculators/score_calibration_utils_test.cc
new file mode 100644
index 000000000..dc7fd90cd
--- /dev/null
+++ b/mediapipe/tasks/cc/components/calculators/score_calibration_utils_test.cc
@@ -0,0 +1,130 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
+
+#include "absl/strings/str_cat.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
+#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace {
+
+using ::testing::HasSubstr;
+
+TEST(ConfigureScoreCalibrationTest, SucceedsWithoutTrailingNewline) {
+  ScoreCalibrationCalculatorOptions options;
+  std::string score_calibration_file =
+      absl::StrCat("\n", "0.1,0.2,0.3\n", "0.4,0.5,0.6,0.7");
+
+  MP_ASSERT_OK(ConfigureScoreCalibration(
+      tflite::ScoreTransformationType_IDENTITY,
+      /*default_score=*/0.5, score_calibration_file, &options));
+
+  EXPECT_THAT(
+      options,
+      EqualsProto(ParseTextProtoOrDie<ScoreCalibrationCalculatorOptions>(R"pb(
+        score_transformation: IDENTITY
+        default_score: 0.5
+        sigmoids {}
+        sigmoids { scale: 0.1 slope: 0.2 offset: 0.3 }
+        sigmoids { scale: 0.4 slope: 0.5 offset: 0.6 min_score: 0.7 }
+      )pb")));
+}
+
+TEST(ConfigureScoreCalibrationTest, SucceedsWithTrailingNewline) {
+  ScoreCalibrationCalculatorOptions options;
+  std::string score_calibration_file =
+      absl::StrCat("\n", "0.1,0.2,0.3\n", "0.4,0.5,0.6,0.7\n");
+
+  MP_ASSERT_OK(ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
+                                         /*default_score=*/0.5,
+                                         score_calibration_file, &options));
+
+  EXPECT_THAT(
+      options,
+      EqualsProto(ParseTextProtoOrDie<ScoreCalibrationCalculatorOptions>(R"pb(
+        score_transformation: LOG
+        default_score: 0.5
+        sigmoids {}
+        sigmoids { scale: 0.1 slope: 0.2 offset: 0.3 }
+        sigmoids { scale: 0.4 slope: 0.5 offset: 0.6 min_score: 0.7 }
+        sigmoids {}
+      )pb")));
+}
+
+TEST(ConfigureScoreCalibrationTest, FailsWithEmptyFile) {
+  ScoreCalibrationCalculatorOptions options;
+
+  auto status =
+      ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
+                                /*default_score=*/0.5,
+                                /*score_calibration_file=*/"", &options);
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(),
+              HasSubstr("Expected non-empty score calibration file"));
+}
+
+TEST(ConfigureScoreCalibrationTest, FailsWithInvalidNumParameters) {
+  ScoreCalibrationCalculatorOptions options;
+  std::string score_calibration_file = absl::StrCat("0.1,0.2,0.3\n", "0.1,0.2");
+
+  auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
+                                          /*default_score=*/0.5,
+                                          score_calibration_file, &options);
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(status.message(),
+              HasSubstr("Expected 3 or 4 parameters per line"));
+}
+
+TEST(ConfigureScoreCalibrationTest, FailsWithNonParseableParameter) {
+  ScoreCalibrationCalculatorOptions options;
+  std::string score_calibration_file =
+      absl::StrCat("0.1,0.2,0.3\n", "0.1,foo,0.3\n");
+
+  auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
+                                          /*default_score=*/0.5,
+                                          score_calibration_file, &options);
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr("Could not parse score calibration parameter as float"));
+}
+
+TEST(ConfigureScoreCalibrationTest, FailsWithNegativeScaleParameter) {
+  ScoreCalibrationCalculatorOptions options;
+  std::string score_calibration_file =
+      absl::StrCat("0.1,0.2,0.3\n", "-0.1,0.2,0.3\n");
+
+  auto status = ConfigureScoreCalibration(tflite::ScoreTransformationType_LOG,
+                                          /*default_score=*/0.5,
+                                          score_calibration_file, &options);
+
+  EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
+  EXPECT_THAT(
+      status.message(),
+      HasSubstr("The scale parameter of the sigmoids must be positive"));
+}
+
+}  // namespace
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc
index 5e40d5d82..4ea41b163 100644
--- a/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc
+++ b/mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.cc
@@ -159,12 +159,12 @@ absl::Status TensorsToSegmentationCalculator::Process(
     std::tie(output_width, output_height) = kOutputSizeIn(cc).Get();
   }
   Shape output_shape = {
-      .height = output_height,
-      .width = output_width,
-      .channels = options_.segmenter_options().output_type() ==
-                      SegmenterOptions::CATEGORY_MASK
-                  ? 1
-                  : input_shape.channels};
+      /* height= */ output_height,
+      /* width= */ output_width,
+      /* channels= */ options_.segmenter_options().output_type() ==
+              SegmenterOptions::CATEGORY_MASK
+          ? 1
+          : input_shape.channels};
 
   std::vector<Image> segmented_masks = GetSegmentationResult(
       input_shape, output_shape, input_tensor.GetCpuReadView().buffer<float>());
diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.cc b/mediapipe/tasks/cc/components/classification_postprocessing.cc
index ebc34b8fc..fc28391bb 100644
--- a/mediapipe/tasks/cc/components/classification_postprocessing.cc
+++ b/mediapipe/tasks/cc/components/classification_postprocessing.cc
@@ -148,8 +148,9 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
             num_output_tensors, output_tensors_metadata->size()),
         MediaPipeTasksStatus::kMetadataInconsistencyError);
   }
-  return ClassificationHeadsProperties{.num_heads = num_output_tensors,
-                                       .quantized = num_quantized_tensors > 0};
+  return ClassificationHeadsProperties{
+      /* num_heads= */ num_output_tensors,
+      /* quantized= */ num_quantized_tensors > 0};
 }
 
 // Builds the label map from the tensor metadata, if available.
diff --git a/mediapipe/tasks/cc/components/image_preprocessing.cc b/mediapipe/tasks/cc/components/image_preprocessing.cc
index 835196877..18958a911 100644
--- a/mediapipe/tasks/cc/components/image_preprocessing.cc
+++ b/mediapipe/tasks/cc/components/image_preprocessing.cc
@@ -226,12 +226,14 @@ class ImagePreprocessingSubgraph : public Subgraph {
 
     // Connect outputs.
     return {
-        .tensors = image_to_tensor[Output<std::vector<Tensor>>(kTensorsTag)],
-        .matrix = image_to_tensor[Output<std::array<float, 16>>(kMatrixTag)],
-        .letterbox_padding =
-            image_to_tensor[Output<std::array<float, 4>>(kLetterboxPaddingTag)],
-        .image_size = image_size[Output<std::pair<int, int>>(kSizeTag)],
-        .image = pass_through[Output<Image>("")],
+        /* tensors= */ image_to_tensor[Output<std::vector<Tensor>>(
+            kTensorsTag)],
+        /* matrix= */
+        image_to_tensor[Output<std::array<float, 16>>(kMatrixTag)],
+        /* letterbox_padding= */
+        image_to_tensor[Output<std::array<float, 4>>(kLetterboxPaddingTag)],
+        /* image_size= */ image_size[Output<std::pair<int, int>>(kSizeTag)],
+        /* image= */ pass_through[Output<Image>("")],
     };
   }
 };
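A recurring change in this patch replaces designated initializers (`.field = value`) with positionally-ordered values annotated by `/* field= */` comments. Designated initializers for aggregates are only standard from C++20 onward, so the annotated positional form keeps the code buildable with older toolchains; that motivation is an inference, as the patch itself does not state one. For illustration, with a hypothetical struct:

    struct Shape { int height; int width; int channels; };

    // C++20 designated initializers name each field explicitly:
    //   Shape s = {.height = h, .width = w, .channels = c};
    // Portable pre-C++20 form used throughout this patch: positional
    // initialization, with comments preserving the field names for readers
    // (and for tooling that checks argument comments).
    Shape s = {/* height= */ h, /* width= */ w, /* channels= */ c};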
diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc
index 7e20d8ef4..8a219bb80
--- a/mediapipe/tasks/cc/core/external_file_handler.cc
+++ b/mediapipe/tasks/cc/core/external_file_handler.cc
@@ -18,8 +18,18 @@ limitations under the License.
 #include <fcntl.h>
 #include <stddef.h>
 #include <stdint.h>
+
+#ifdef ABSL_HAVE_MMAP
 #include <sys/mman.h>
+#endif
+
+#ifdef _WIN32
+#include <direct.h>
+#include <io.h>
+#include <windows.h>
+#else
 #include <unistd.h>
+#endif
 
 #include <memory>
 #include <string>
@@ -44,12 +54,17 @@ using ::absl::StatusCode;
 // file descriptor correctly, as according to mmap(2), the offset used in mmap
 // must be a multiple of sysconf(_SC_PAGE_SIZE).
 int64 GetPageSizeAlignedOffset(int64 offset) {
+#ifdef _WIN32
+  // mmap is not used on Windows.
+  return -1;
+#else
   int64 aligned_offset = offset;
   int64 page_size = sysconf(_SC_PAGE_SIZE);
   if (offset % page_size != 0) {
     aligned_offset = offset / page_size * page_size;
   }
   return aligned_offset;
+#endif
 }
 
 }  // namespace
@@ -69,6 +84,12 @@ ExternalFileHandler::CreateFromExternalFile(
 }
 
 absl::Status ExternalFileHandler::MapExternalFile() {
+// TODO: Add Windows support.
+#ifdef _WIN32
+  return CreateStatusWithPayload(StatusCode::kFailedPrecondition,
+                                 "File loading is not yet supported on Windows",
+                                 MediaPipeTasksStatus::kFileReadError);
+#else
  if (!external_file_.file_content().empty()) {
    return absl::OkStatus();
  }
@@ -169,6 +190,7 @@ absl::Status ExternalFileHandler::MapExternalFile() {
         MediaPipeTasksStatus::kFileMmapError);
   }
   return absl::OkStatus();
+#endif
 }
 
 absl::string_view ExternalFileHandler::GetFileContent() {
@@ -182,9 +204,11 @@ absl::string_view ExternalFileHandler::GetFileContent() {
 }
 
 ExternalFileHandler::~ExternalFileHandler() {
+#ifndef _WIN32
   if (buffer_ != MAP_FAILED) {
     munmap(buffer_, buffer_aligned_size_);
   }
+#endif
   if (owned_fd_ >= 0) {
     close(owned_fd_);
   }
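GetPageSizeAlignedOffset above rounds the requested offset down to the previous page boundary via integer division. A quick worked example (the page size and offset values are illustrative):

    // With a 4096-byte page size, an offset of 10000 rounds down to the start
    // of its page: 10000 / 4096 = 2 (integer division), and 2 * 4096 = 8192.
    // mmap() is then given offset 8192, and the caller skips the remaining
    // 10000 - 8192 = 1808 bytes into the mapping to reach the requested data.
    int64_t AlignToPage(int64_t offset, int64_t page_size) {
      return offset / page_size * page_size;
    }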
diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto
index 6393c822e..42f2bbc85 100644
--- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto
+++ b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
+// TODO: Refactor naming and class structure of hand-related Tasks.
 syntax = "proto2";
 
 package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
diff --git a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc
index f6bfbd1bf..c5677cd98 100644
--- a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_graph.cc
@@ -388,13 +388,13 @@ class HandLandmarkDetectorGraph : public core::ModelTaskGraph {
         hand_rect_transformation[Output<NormalizedRect>("")];
 
     return {{
-        .hand_landmarks = projected_landmarks,
-        .world_hand_landmarks = projected_world_landmarks,
-        .hand_rect_next_frame = hand_rect_next_frame,
-        .hand_presence = hand_presence,
-        .hand_presence_score = hand_presence_score,
-        .handedness = handedness,
-        .image_size = image_size,
+        /* hand_landmarks= */ projected_landmarks,
+        /* world_hand_landmarks= */ projected_world_landmarks,
+        /* hand_rect_next_frame= */ hand_rect_next_frame,
+        /* hand_presence= */ hand_presence,
+        /* hand_presence_score= */ hand_presence_score,
+        /* handedness= */ handedness,
+        /* image_size= */ image_size,
     }};
   }
 };
diff --git a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto b/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto
index 3de64e593..a2cfc7eaf 100644
--- a/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto
+++ b/mediapipe/tasks/cc/vision/hand_landmark/hand_landmark_detector_options.proto
@@ -24,7 +24,7 @@ message HandLandmarkDetectorOptions {
   extend mediapipe.CalculatorOptions {
     optional HandLandmarkDetectorOptions ext = 462713202;
   }
-  // Base options for configuring Task library, such as specifying the TfLite
+  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
   // model file with metadata, accelerator options, etc.
   optional core.proto.BaseOptions base_options = 1;
 
diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto b/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto
index 1fa221179..21fb3cd8c 100644
--- a/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto
+++ b/mediapipe/tasks/cc/vision/image_classification/image_classifier_options.proto
@@ -25,7 +25,7 @@ message ImageClassifierOptions {
   extend mediapipe.CalculatorOptions {
     optional ImageClassifierOptions ext = 456383383;
   }
-  // Base options for configuring Task library, such as specifying the TfLite
+  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;
 
diff --git a/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc
index 5b8cf8675..014f11352 100644
--- a/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc
+++ b/mediapipe/tasks/cc/vision/image_classification/image_classifier_test.cc
@@ -142,7 +142,7 @@ TEST_F(CreateTest, FailsWithSelectiveOpResolverMissingOps) {
   // interpreter errors (e.g., "Encountered unresolved custom op").
EXPECT_EQ(image_classifier_or.status().code(), absl::StatusCode::kInternal); EXPECT_THAT(image_classifier_or.status().message(), - HasSubstr("interpreter_builder(&interpreter_) == kTfLiteOk")); + HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); } TEST_F(CreateTest, FailsWithMissingModel) { auto image_classifier_or = diff --git a/mediapipe/tasks/cc/vision/segmentation/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD similarity index 81% rename from mediapipe/tasks/cc/vision/segmentation/BUILD rename to mediapipe/tasks/cc/vision/image_segmenter/BUILD index cc4d8236f..cb0482e42 100644 --- a/mediapipe/tasks/cc/vision/segmentation/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -12,34 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "image_segmenter_options_proto", - srcs = ["image_segmenter_options.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components:segmenter_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", - ], -) - cc_library( name = "image_segmenter", srcs = ["image_segmenter.cc"], hdrs = ["image_segmenter.h"], deps = [ ":image_segmenter_graph", - ":image_segmenter_options_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", - "//mediapipe/tasks/cc/core:base_task_api", - "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/components:segmenter_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", @@ -51,7 +42,6 @@ cc_library( name = "image_segmenter_graph", srcs = ["image_segmenter_graph.cc"], deps = [ - ":image_segmenter_options_cc_proto", "//mediapipe/calculators/core:merge_to_vector_calculator", "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/tensor:image_to_tensor_calculator", @@ -70,6 +60,7 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", @@ -82,9 +73,9 @@ cc_library( ) cc_library( - name = "custom_op_resolvers", - srcs = ["custom_op_resolvers.cc"], - hdrs = ["custom_op_resolvers.h"], + name = "image_segmenter_op_resolvers", + srcs = ["image_segmenter_op_resolvers.cc"], + hdrs = ["image_segmenter_op_resolvers.h"], deps = [ "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:max_pool_argmax", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc new file mode 100644 index 000000000..090149d92 
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc
@@ -0,0 +1,134 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
+
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/tasks/cc/components/segmenter_options.pb.h"
+#include "mediapipe/tasks/cc/core/utils.h"
+#include "mediapipe/tasks/cc/vision/core/running_mode.h"
+#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace {
+
+constexpr char kSegmentationStreamName[] = "segmented_mask_out";
+constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
+constexpr char kImageInStreamName[] = "image_in";
+constexpr char kImageOutStreamName[] = "image_out";
+constexpr char kImageTag[] = "IMAGE";
+constexpr char kSubgraphTypeName[] =
+    "mediapipe.tasks.vision.ImageSegmenterGraph";
+
+using ::mediapipe::CalculatorGraphConfig;
+using ::mediapipe::Image;
+using ImageSegmenterOptionsProto =
+    image_segmenter::proto::ImageSegmenterOptions;
+
+// Creates a MediaPipe graph config that contains only a single subgraph node
+// of type "mediapipe.tasks.vision.ImageSegmenterGraph".
+CalculatorGraphConfig CreateGraphConfig(
+    std::unique_ptr<ImageSegmenterOptionsProto> options,
+    bool enable_flow_limiting) {
+  api2::builder::Graph graph;
+  auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
+  task_subgraph.GetOptions<ImageSegmenterOptionsProto>().Swap(options.get());
+  graph.In(kImageTag).SetName(kImageInStreamName);
+  task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
+      graph.Out(kGroupedSegmentationTag);
+  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
+      graph.Out(kImageTag);
+  if (enable_flow_limiting) {
+    return tasks::core::AddFlowLimiterCalculator(
+        graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag);
+  }
+  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
+  return graph.GetConfig();
+}
+
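For image mode (no flow limiting), the config that CreateGraphConfig assembles is roughly equivalent to the following textproto, inferred from the stream-name and tag constants above rather than dumped from a real run:

    input_stream: "IMAGE:image_in"
    output_stream: "GROUPED_SEGMENTATION:segmented_mask_out"
    output_stream: "IMAGE:image_out"
    node {
      calculator: "mediapipe.tasks.vision.ImageSegmenterGraph"
      input_stream: "IMAGE:image_in"
      output_stream: "GROUPED_SEGMENTATION:segmented_mask_out"
      output_stream: "IMAGE:image_out"
      options {
        [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext] {
          # base_options and segmenter_options are populated from the
          # user-facing ImageSegmenterOptions struct (see below).
        }
      }
    }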
+// Converts the user-facing ImageSegmenterOptions struct to the internal
+// ImageSegmenterOptions proto.
+std::unique_ptr<ImageSegmenterOptionsProto>
+ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) {
+  auto options_proto = std::make_unique<ImageSegmenterOptionsProto>();
+  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
+      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
+  options_proto->mutable_base_options()->Swap(base_options_proto.get());
+  options_proto->mutable_base_options()->set_use_stream_mode(
+      options->running_mode != core::RunningMode::IMAGE);
+  options_proto->set_display_names_locale(options->display_names_locale);
+  switch (options->output_type) {
+    case ImageSegmenterOptions::OutputType::CATEGORY_MASK:
+      options_proto->mutable_segmenter_options()->set_output_type(
+          SegmenterOptions::CATEGORY_MASK);
+      break;
+    case ImageSegmenterOptions::OutputType::CONFIDENCE_MASK:
+      options_proto->mutable_segmenter_options()->set_output_type(
+          SegmenterOptions::CONFIDENCE_MASK);
+      break;
+  }
+  switch (options->activation) {
+    case ImageSegmenterOptions::Activation::NONE:
+      options_proto->mutable_segmenter_options()->set_activation(
+          SegmenterOptions::NONE);
+      break;
+    case ImageSegmenterOptions::Activation::SIGMOID:
+      options_proto->mutable_segmenter_options()->set_activation(
+          SegmenterOptions::SIGMOID);
+      break;
+    case ImageSegmenterOptions::Activation::SOFTMAX:
+      options_proto->mutable_segmenter_options()->set_activation(
+          SegmenterOptions::SOFTMAX);
+      break;
+  }
+  return options_proto;
+}
+
+}  // namespace
+
+absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
+    std::unique_ptr<ImageSegmenterOptions> options) {
+  auto options_proto = ConvertImageSegmenterOptionsToProto(options.get());
+  tasks::core::PacketsCallback packets_callback = nullptr;
+  return core::VisionTaskApiFactory::Create<ImageSegmenter,
+                                            ImageSegmenterOptionsProto>(
+      CreateGraphConfig(
+          std::move(options_proto),
+          options->running_mode == core::RunningMode::LIVE_STREAM),
+      std::move(options->base_options.op_resolver), options->running_mode,
+      std::move(packets_callback));
+}
+
+absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
+    mediapipe::Image image) {
+  if (image.UsesGpu()) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrCat("GPU input images are currently not supported."),
+        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
+  }
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      ProcessImageData({{kImageInStreamName,
+                         mediapipe::MakePacket<Image>(std::move(image))}}));
+  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
+}
+
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
new file mode 100644
index 000000000..00c63953a
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
@@ -0,0 +1,123 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_
+#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/formats/image.h"
+#include "mediapipe/tasks/cc/core/base_options.h"
+#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
+#include "tensorflow/lite/core/api/op_resolver.h"
+#include "tensorflow/lite/kernels/register.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+
+// The options for configuring a MediaPipe image segmenter task.
+struct ImageSegmenterOptions {
+  // Base options for configuring MediaPipe Tasks, such as specifying the model
+  // file with metadata, accelerator options, op resolver, etc.
+  tasks::core::BaseOptions base_options;
+
+  // The running mode of the task. Defaults to the image mode.
+  // The image segmenter has three running modes:
+  // 1) The image mode for segmenting single image inputs.
+  // 2) The video mode for segmenting the decoded frames of a video.
+  // 3) The live stream mode for segmenting a live stream of input data, such
+  //    as from a camera. In this mode, the "result_callback" below must be
+  //    specified to receive the segmentation results asynchronously.
+  core::RunningMode running_mode = core::RunningMode::IMAGE;
+
+  // The locale to use for display names specified through the TFLite Model
+  // Metadata, if any. Defaults to English.
+  std::string display_names_locale = "en";
+
+  // The output type of segmentation results.
+  enum OutputType {
+    // Gives a single output mask where each pixel represents the class which
+    // the pixel in the original image was predicted to belong to.
+    CATEGORY_MASK = 0,
+    // Gives a list of output masks where, for each mask, each pixel represents
+    // the prediction confidence, usually in the [0, 1] range.
+    CONFIDENCE_MASK = 1,
+  };
+
+  OutputType output_type = OutputType::CATEGORY_MASK;
+
+  // The activation function used on the raw segmentation model output.
+  enum Activation {
+    NONE = 0,     // No activation function is used.
+    SIGMOID = 1,  // Assumes 1-channel input tensor.
+    SOFTMAX = 2,  // Assumes multi-channel input tensor.
+  };
+
+  Activation activation = Activation::NONE;
+
+  // The user-defined result callback for processing live stream data.
+  // The result callback should only be specified when the running mode is set
+  // to RunningMode::LIVE_STREAM.
+  std::function<void(absl::StatusOr<std::vector<mediapipe::Image>>,
+                     const Image&, int64)>
+      result_callback = nullptr;
+};
+
+// Performs segmentation on images.
+//
+// The API expects a TFLite model with mandatory TFLite Model Metadata.
+//
+// Input tensor:
+//   (kTfLiteUInt8/kTfLiteFloat32)
+//   - image input of size `[batch x height x width x channels]`.
+//   - batch inference is not supported (`batch` is required to be 1).
+//   - RGB and greyscale inputs are supported (`channels` is required to be
+//     1 or 3).
+//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
+//     attached to the metadata for input normalization.
+// Output tensors:
+//   (kTfLiteUInt8/kTfLiteFloat32)
+//   - list of segmented masks.
+//   - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
+//   - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
+//     `channels`.
+//   - batch is always 1.
+// An example of such a model can be found at:
+// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
+class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
+ public:
+  using BaseVisionTaskApi::BaseVisionTaskApi;
+
+  // Creates an ImageSegmenter from the provided options. A non-default
+  // OpResolver can be specified in the BaseOptions of ImageSegmenterOptions,
+  // to support custom Ops of the segmentation model.
+  static absl::StatusOr<std::unique_ptr<ImageSegmenter>> Create(
+      std::unique_ptr<ImageSegmenterOptions> options);
+
+  // Runs the actual segmentation task.
+  absl::StatusOr<std::vector<mediapipe::Image>> Segment(
+      mediapipe::Image image);
+};
+
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_H_
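Putting the pieces together, a typical image-mode invocation looks roughly like the following sketch, based on the tests further down in this patch (the model path is illustrative):

    auto options = std::make_unique<ImageSegmenterOptions>();
    options->base_options.model_file_name = "deeplabv3.tflite";  // illustrative path
    options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
    ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> segmenter,
                     ImageSegmenter::Create(std::move(options)));
    ASSIGN_OR_RETURN(std::vector<Image> masks, segmenter->Segment(image));
    // With CATEGORY_MASK, masks holds a single uint8 Image whose pixel values
    // are category indices.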
diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
similarity index 84%
rename from mediapipe/tasks/cc/vision/segmentation/image_segmenter_graph.cc
rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
index b960fd930..d843689e2 100644
--- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
@@ -33,7 +33,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/model_task_graph.h"
 #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
-#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
 #include "mediapipe/util/label_map.pb.h"
 #include "mediapipe/util/label_map_util.h"
@@ -53,6 +53,7 @@ using ::mediapipe::api2::builder::MultiSource;
 using ::mediapipe::api2::builder::Source;
 using ::mediapipe::tasks::SegmenterOptions;
 using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
+using ::mediapipe::tasks::vision::image_segmenter::proto::ImageSegmenterOptions;
 using ::tflite::Tensor;
 using ::tflite::TensorMetadata;
 using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
@@ -63,6 +64,14 @@ constexpr char kImageTag[] = "IMAGE";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
 
+// Struct holding the different output streams produced by the image segmenter
+// subgraph.
+struct ImageSegmenterOutputs {
+  std::vector<Source<Image>> segmented_masks;
+  // The same as the input image, mainly used for live stream mode.
+  Source<Image> image;
+};
+
 }  // namespace
 
 absl::Status SanityCheckOptions(const ImageSegmenterOptions& options) {
@@ -140,6 +149,10 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
 
 // A "mediapipe.tasks.vision.ImageSegmenterGraph" performs semantic
 // segmentation.
+// Two kinds of outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION.
+// Users can retrieve the segmented mask of a particular category/channel from
+// SEGMENTATION, or all segmented masks at once from GROUPED_SEGMENTATION.
 // - Accepts CPU input images and outputs segmented masks on CPU.
 //
 // Inputs:
@@ -147,8 +160,13 @@ absl::StatusOr<const Tensor*> GetOutputTensor(
 //     Image to perform segmentation on.
 //
 // Outputs:
-//   SEGMENTATION - SEGMENTATION
-//     Segmented masks.
+//   SEGMENTATION - mediapipe::Image @Multiple
+//     Segmented masks for each individual category. The segmented mask of a
+//     single category can be accessed through the index-based output stream.
+//   GROUPED_SEGMENTATION - std::vector<mediapipe::Image>
+//     The output segmented masks grouped in a vector.
+//   IMAGE - mediapipe::Image
+//     The image that the image segmenter runs on.
 //
 // Example:
 // node {
 //   calculator: "mediapipe.tasks.vision.ImageSegmenterGraph"
 //   input_stream: "IMAGE:image"
 //   output_stream: "SEGMENTATION:segmented_masks"
 //   options {
-//     [mediapipe.tasks.ImageSegmenterOptions.ext] {
+//     [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterOptions.ext]
+//     {
 //       segmenter_options {
 //         output_type: CONFIDENCE_MASK
 //         activation: SOFTMAX
@@ -171,20 +190,22 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
     ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc));
     Graph graph;
-    ASSIGN_OR_RETURN(auto segmentations,
+    ASSIGN_OR_RETURN(auto output_streams,
                      BuildSegmentationTask(
                          sc->Options<ImageSegmenterOptions>(), *model_resources,
                          graph[Input<Image>(kImageTag)], graph));
 
     auto& merge_images_to_vector =
         graph.AddNode("MergeImagesToVectorCalculator");
-    for (int i = 0; i < segmentations.size(); ++i) {
-      segmentations[i] >> merge_images_to_vector[Input<Image>::Multiple("")][i];
-      segmentations[i] >> graph[Output<Image>::Multiple(kSegmentationTag)][i];
+    for (int i = 0; i < output_streams.segmented_masks.size(); ++i) {
+      output_streams.segmented_masks[i] >>
+          merge_images_to_vector[Input<Image>::Multiple("")][i];
+      output_streams.segmented_masks[i] >>
+          graph[Output<Image>::Multiple(kSegmentationTag)][i];
     }
     merge_images_to_vector.Out("") >>
         graph[Output<std::vector<Image>>(kGroupedSegmentationTag)];
-
+    output_streams.image >> graph[Output<Image>(kImageTag)];
     return graph.GetConfig();
   }
 
@@ -193,12 +214,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
   // builder::Graph instance. The segmentation pipeline takes images
   // (mediapipe::Image) as the input and returns segmented image mask as output.
   //
-  // task_options: the mediapipe tasks ImageSegmenterOptions.
+  // task_options: the mediapipe tasks ImageSegmenterOptions proto.
   // model_resources: the ModelSources object initialized from a segmentation
   //                  model file with model metadata.
   // image_in: (mediapipe::Image) stream to run segmentation on.
   // graph: the mediapipe builder::Graph instance to be updated.
-  absl::StatusOr<std::vector<Source<Image>>> BuildSegmentationTask(
+  absl::StatusOr<ImageSegmenterOutputs> BuildSegmentationTask(
       const ImageSegmenterOptions& task_options,
       const core::ModelResources& model_resources, Source<Image> image_in,
       Graph& graph) {
@@ -246,7 +267,10 @@ class ImageSegmenterGraph : public core::ModelTaskGraph {
               tensor_to_images[Output<Image>::Multiple(kSegmentationTag)][i]));
       }
     }
-    return segmented_masks;
+    return {{
+        .segmented_masks = segmented_masks,
+        .image = preprocessing[Output<Image>(kImageTag)],
+    }};
   }
 };
diff --git a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc
similarity index 96%
rename from mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.cc
rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc
index b24b426ad..cd3b5690f 100644
--- a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
 
 #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
 #include "mediapipe/util/tflite/operations/max_pool_argmax.h"
diff --git a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h
similarity index 81%
rename from mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h
rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h
index 2b185d792..a0538a674 100644
--- a/mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_
-#define MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_
+#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
 
 #include "tensorflow/lite/kernels/register.h"
 
@@ -34,4 +34,4 @@ class SelfieSegmentationModelOpResolver
 }  // namespace tasks
 }  // namespace mediapipe
 
-#endif  // MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_CUSTOM_OP_RESOLVERS_H_
+#endif  // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_
diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
similarity index 79%
rename from mediapipe/tasks/cc/vision/segmentation/image_segmenter_test.cc
rename to mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
index 7e77cece6..f43d28fca 100644
--- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_test.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter.h"
+#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"
 
 #include <memory>
 #include <utility>
 
@@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/segmentation/custom_op_resolvers.h" -#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" @@ -46,11 +46,8 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::ImageSegmenterOptions; -using ::mediapipe::tasks::SegmenterOptions; using ::testing::HasSubstr; using ::testing::Optional; -using ::tflite::ops::builtin::BuiltinOpResolver; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kDeeplabV3WithMetadata[] = "deeplabv3.tflite"; @@ -167,25 +164,25 @@ class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); - MP_ASSERT_OK(ImageSegmenter::Create(std::move(options), - absl::make_unique())); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->base_options.op_resolver = absl::make_unique(); + MP_ASSERT_OK(ImageSegmenter::Create(std::move(options))); } TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { auto options = std::make_unique(); - options->mutable_base_options()->mutable_model_file()->set_file_name( - JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata)); - - auto segmenter_or = ImageSegmenter::Create( - std::move(options), absl::make_unique()); + options->base_options.model_file_name = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->base_options.op_resolver = + absl::make_unique(); + auto segmenter_or = ImageSegmenter::Create(std::move(options)); // TODO: Make MediaPipe InferenceCalculator report the detailed // interpreter errors (e.g., "Encountered unresolved custom op"). 

 TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
   auto options = std::make_unique<ImageSegmenterOptions>();
-  options->mutable_base_options()->mutable_model_file()->set_file_name(
-      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
-
-  auto segmenter_or = ImageSegmenter::Create(
-      std::move(options), absl::make_unique<DeepLabOpResolverMissingOps>());
+  options->base_options.model_file_name =
+      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
+  options->base_options.op_resolver =
+      absl::make_unique<DeepLabOpResolverMissingOps>();
+  auto segmenter_or = ImageSegmenter::Create(std::move(options));
   // TODO: Make MediaPipe InferenceCalculator report the detailed
   // interpreter errors (e.g., "Encountered unresolved custom op").
   EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal);
   EXPECT_THAT(
       segmenter_or.status().message(),
-      testing::HasSubstr("interpreter_builder(&interpreter_) == kTfLiteOk"));
+      testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk"));
 }

 TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
@@ -202,24 +199,6 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
           MediaPipeTasksStatus::kRunnerInitializationError))));
 }

-TEST_F(CreateFromOptionsTest, FailsWithUnspecifiedOutputType) {
-  auto options = std::make_unique<ImageSegmenterOptions>();
-  options->mutable_base_options()->mutable_model_file()->set_file_name(
-      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
-  options->mutable_segmenter_options()->set_output_type(
-      SegmenterOptions::UNSPECIFIED);
-
-  auto segmenter_or = ImageSegmenter::Create(
-      std::move(options), absl::make_unique<BuiltinOpResolver>());
-
-  EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument);
-  EXPECT_THAT(segmenter_or.status().message(),
-              HasSubstr("`output_type` must not be UNSPECIFIED"));
-  EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
-              Optional(absl::Cord(absl::StrCat(
-                  MediaPipeTasksStatus::kRunnerInitializationError))));
-}
-
 class SegmentationTest : public tflite_shims::testing::Test {};

 TEST_F(SegmentationTest, SucceedsWithCategoryMask) {
@@ -228,10 +207,10 @@ TEST_F(SegmentationTest, SucceedsWithCategoryMask) {
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                    "segmentation_input_rotation0.jpg")));
   auto options = std::make_unique<ImageSegmenterOptions>();
-  options->mutable_segmenter_options()->set_output_type(
-      SegmenterOptions::CATEGORY_MASK);
-  options->mutable_base_options()->mutable_model_file()->set_file_name(
-      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
+  options->base_options.model_file_name =
+      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;
+
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto category_masks, segmenter->Segment(image));
@@ -253,12 +232,11 @@ TEST_F(SegmentationTest, SucceedsWithConfidenceMask) {
       Image image,
       DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
   auto options = std::make_unique<ImageSegmenterOptions>();
-  options->mutable_segmenter_options()->set_output_type(
-      SegmenterOptions::CONFIDENCE_MASK);
-  options->mutable_segmenter_options()->set_activation(
-      SegmenterOptions::SOFTMAX);
-  options->mutable_base_options()->mutable_model_file()->set_file_name(
-      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata));
+  options->base_options.model_file_name =
+      JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
                           ImageSegmenter::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image));
@@ -281,17 +259,15 @@ TEST_F(SegmentationTest, SucceedsSelfie128x128Segmentation) {
   Image image =
       GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
   auto options = std::make_unique<ImageSegmenterOptions>();
-  options->mutable_segmenter_options()->set_output_type(
-      SegmenterOptions::CONFIDENCE_MASK);
-  options->mutable_segmenter_options()->set_activation(
-      SegmenterOptions::SOFTMAX);
-  options->mutable_base_options()->mutable_model_file()->set_file_name(
-      JoinPath("./", kTestDataDirectory,
-               kSelfie128x128WithMetadata));
-  MP_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<ImageSegmenter> segmenter,
-      ImageSegmenter::Create(
-          std::move(options),
-          absl::make_unique<SelfieSegmentationModelOpResolver>()));
+  options->base_options.model_file_name =
+      JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
+  options->base_options.op_resolver =
+      absl::make_unique<SelfieSegmentationModelOpResolver>();
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
+
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
   EXPECT_EQ(confidence_masks.size(), 2);
@@ -313,15 +289,14 @@ TEST_F(SegmentationTest, SucceedsSelfie144x256Segmentations) {
   Image image =
       GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg"));
   auto options = std::make_unique<ImageSegmenterOptions>();
-  options->mutable_segmenter_options()->set_output_type(
-      SegmenterOptions::CONFIDENCE_MASK);
-  options->mutable_base_options()->mutable_model_file()->set_file_name(
-      JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata));
-  MP_ASSERT_OK_AND_ASSIGN(
-      std::unique_ptr<ImageSegmenter> segmenter,
-      ImageSegmenter::Create(
-          std::move(options),
-          absl::make_unique<SelfieSegmentationModelOpResolver>()));
+  options->base_options.model_file_name =
+      JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
+  options->base_options.op_resolver =
+      absl::make_unique<SelfieSegmentationModelOpResolver>();
+  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
+  options->activation = ImageSegmenterOptions::Activation::NONE;
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,
+                          ImageSegmenter::Create(std::move(options)));
   MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
   EXPECT_EQ(confidence_masks.size(), 1);
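Taken together, the updated tests exercise one end-to-end call pattern. The sketch below restates it outside the test fixture (the model path is hypothetical and namespace qualifiers are abbreviated; error handling is via the usual ASSIGN_OR_RETURN macro):

// End-to-end usage sketch of the migrated ImageSegmenter API.
absl::Status RunSegmenter(mediapipe::Image image) {
  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_file_name = "deeplabv3.tflite";  // hypothetical
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> segmenter,
                   ImageSegmenter::Create(std::move(options)));
  // One float32 confidence mask per model output channel.
  ASSIGN_OR_RETURN(std::vector<mediapipe::Image> masks,
                   segmenter->Segment(std::move(image)));
  return absl::OkStatus();
}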
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD
new file mode 100644
index 000000000..b9b8ea436
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/BUILD
@@ -0,0 +1,30 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+mediapipe_proto_library(
+    name = "image_segmenter_options_proto",
+    srcs = ["image_segmenter_options.proto"],
+    deps = [
+        "//mediapipe/framework:calculator_options_proto",
+        "//mediapipe/framework:calculator_proto",
+        "//mediapipe/tasks/cc/components:segmenter_options_proto",
+        "//mediapipe/tasks/cc/core/proto:base_options_proto",
+    ],
+)
diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.proto b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto
similarity index 91%
rename from mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.proto
rename to mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto
index ab8ff7c83..fcb2914cf 100644
--- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.proto
+++ b/mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.proto
@@ -15,7 +15,7 @@ limitations under the License.

 syntax = "proto2";

-package mediapipe.tasks;
+package mediapipe.tasks.vision.image_segmenter.proto;

 import "mediapipe/framework/calculator.proto";
 import "mediapipe/tasks/cc/components/segmenter_options.proto";
@@ -25,7 +25,7 @@ message ImageSegmenterOptions {
   extend mediapipe.CalculatorOptions {
     optional ImageSegmenterOptions ext = 458105758;
   }
-  // Base options for configuring Task library, such as specifying the TfLite
+  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
   // model file with metadata, accelerator options, etc.
   optional core.proto.BaseOptions base_options = 1;

diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h
index 6f23e9b52..e98013223 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.h
@@ -36,10 +36,19 @@ namespace vision {

 // The options for configuring a mediapipe object detector task.
 struct ObjectDetectorOptions {
-  // Base options for configuring Task library, such as specifying the TfLite
+  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
   // model file with metadata, accelerator options, op resolver, etc.
   tasks::core::BaseOptions base_options;

+  // The running mode of the task. Default to the image mode.
+  // Object detector has three running modes:
+  // 1) The image mode for detecting objects on single image inputs.
+  // 2) The video mode for detecting objects on the decoded frames of a video.
+  // 3) The live stream mode for detecting objects on the live stream of input
+  // data, such as from camera. In this mode, the "result_callback" below must
+  // be specified to receive the detection results asynchronously.
+  core::RunningMode running_mode = core::RunningMode::IMAGE;
+
   // The locale to use for display names specified through the TFLite Model
   // Metadata, if any. Defaults to English.
   std::string display_names_locale = "en";
@@ -65,15 +74,6 @@ struct ObjectDetectorOptions {
   // category names are ignored. Mutually exclusive with category_allowlist.
   std::vector<std::string> category_denylist = {};

-  // The running mode of the task. Default to the image mode.
-  // Object detector has three running modes:
-  // 1) The image mode for detecting objects on single image inputs.
-  // 2) The video mode for detecting objects on the decoded frames of a video.
-  // 3) The live stream mode for detecting objects on the live stream of input
-  // data, such as from camera. In this mode, the "result_callback" below must
-  // be specified to receive the detection results asynchronously.
-  core::RunningMode running_mode = core::RunningMode::IMAGE;
-
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
index e5f441731..94f217378 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc
@@ -531,9 +531,9 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
     // Outputs the labeled detections and the processed image as the subgraph
     // output streams.
     return {{
-        .detections =
-            detection_label_id_to_text[Output<std::vector<Detection>>("")],
-        .image = preprocessing[Output<Image>(kImageTag)],
+        /* detections= */
+        detection_label_id_to_text[Output<std::vector<Detection>>("")],
+        /* image= */ preprocessing[Output<Image>(kImageTag)],
     }};
   }
 };
diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
index 9825b2c3d..faca6ef24 100644
--- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
+++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc
@@ -194,7 +194,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) {
   // interpreter errors (e.g., "Encountered unresolved custom op").
   EXPECT_EQ(object_detector.status().code(), absl::StatusCode::kInternal);
   EXPECT_THAT(object_detector.status().message(),
-              HasSubstr("interpreter_->AllocateTensors() == kTfLiteOk"));
+              HasSubstr("interpreter->AllocateTensors() == kTfLiteOk"));
 }

 TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto
index 5e2955a9f..37edab1d9 100644
--- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto
+++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto
@@ -27,7 +27,7 @@ message ObjectDetectorOptions {
   extend mediapipe.CalculatorOptions {
     optional ObjectDetectorOptions ext = 443442058;
   }
-  // Base options for configuring Task library, such as specifying the TfLite
+  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
   // model file with metadata, accelerator options, etc.
   optional core.proto.BaseOptions base_options = 1;

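Moving `running_mode` up next to `base_options` puts the mode front and center in the struct. A hedged configuration sketch for the live-stream case (the exact `result_callback` signature is not shown in this excerpt and is assumed here; the model path is hypothetical):

// Sketch: configuring an object detector for live-stream input.
ObjectDetectorOptions options;
options.base_options.model_file_name = "object_detector.tflite";  // hypothetical
options.running_mode = core::RunningMode::LIVE_STREAM;
// Required in LIVE_STREAM mode: detections arrive asynchronously.
options.result_callback = [](absl::StatusOr<std::vector<Detection>> detections,
                             const Image& image, int64_t timestamp_ms) {
  // Consume the detections for the frame at timestamp_ms.
};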
diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.cc b/mediapipe/tasks/cc/vision/segmentation/image_segmenter.cc
deleted file mode 100644
index efed5685f..000000000
--- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.cc
+++ /dev/null
@@ -1,75 +0,0 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter.h"
-
-#include "mediapipe/framework/api2/builder.h"
-#include "mediapipe/tasks/cc/core/task_api_factory.h"
-
-namespace mediapipe {
-namespace tasks {
-namespace vision {
-namespace {
-
-constexpr char kSegmentationStreamName[] = "segmented_mask_out";
-constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION";
-constexpr char kImageStreamName[] = "image_in";
-constexpr char kImageTag[] = "IMAGE";
-constexpr char kSubgraphTypeName[] =
-    "mediapipe.tasks.vision.ImageSegmenterGraph";
-
-using ::mediapipe::CalculatorGraphConfig;
-using ::mediapipe::Image;
-
-// Creates a MediaPipe graph config that only contains a single subgraph node of
-// "mediapipe.tasks.vision.SegmenterGraph".
-CalculatorGraphConfig CreateGraphConfig(
-    std::unique_ptr<ImageSegmenterOptions> options) {
-  api2::builder::Graph graph;
-  auto& subgraph = graph.AddNode(kSubgraphTypeName);
-  subgraph.GetOptions<ImageSegmenterOptions>().Swap(options.get());
-  graph.In(kImageTag).SetName(kImageStreamName) >> subgraph.In(kImageTag);
-  subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >>
-      graph.Out(kGroupedSegmentationTag);
-  return graph.GetConfig();
-}
-
-}  // namespace
-
-absl::StatusOr<std::unique_ptr<ImageSegmenter>> ImageSegmenter::Create(
-    std::unique_ptr<ImageSegmenterOptions> options,
-    std::unique_ptr<tflite::OpResolver> resolver) {
-  return core::TaskApiFactory::Create<ImageSegmenter, ImageSegmenterOptions>(
-      CreateGraphConfig(std::move(options)), std::move(resolver));
-}
-
-absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
-    mediapipe::Image image) {
-  if (image.UsesGpu()) {
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        absl::StrCat("GPU input images are currently not supported."),
-        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
-  }
-  ASSIGN_OR_RETURN(
-      auto output_packets,
-      runner_->Process({{kImageStreamName,
-                         mediapipe::MakePacket<Image>(std::move(image))}}));
-  return output_packets[kSegmentationStreamName].Get<std::vector<Image>>();
-}
-
-}  // namespace vision
-}  // namespace tasks
-}  // namespace mediapipe
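The deleted implementation above documents the old calling convention; per the updated tests earlier in this diff, the explicit resolver argument is gone and the resolver now rides along in the options struct. Before/after, sketched (`MyOpResolver` is hypothetical):

// Before: proto options plus an explicit resolver argument.
auto old_segmenter = ImageSegmenter::Create(std::move(proto_options),
                                            absl::make_unique<MyOpResolver>());

// After: a single argument; the resolver lives in base_options.
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.op_resolver = absl::make_unique<MyOpResolver>();
auto new_segmenter = ImageSegmenter::Create(std::move(options));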
diff --git a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.h b/mediapipe/tasks/cc/vision/segmentation/image_segmenter.h
deleted file mode 100644
index 58da9feaf..000000000
--- a/mediapipe/tasks/cc/vision/segmentation/image_segmenter.h
+++ /dev/null
@@ -1,76 +0,0 @@
-/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_
-#define MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_
-
-#include <memory>
-#include <utility>
-#include <vector>
-
-#include "absl/memory/memory.h"
-#include "absl/status/statusor.h"
-#include "mediapipe/framework/formats/image.h"
-#include "mediapipe/tasks/cc/core/base_task_api.h"
-#include "mediapipe/tasks/cc/vision/segmentation/image_segmenter_options.pb.h"
-#include "tensorflow/lite/core/api/op_resolver.h"
-#include "tensorflow/lite/kernels/register.h"
-
-namespace mediapipe {
-namespace tasks {
-namespace vision {
-
-// Performs segmentation on images.
-//
-// The API expects a TFLite model with mandatory TFLite Model Metadata.
-//
-// Input tensor:
-//   (kTfLiteUInt8/kTfLiteFloat32)
-//   - image input of size `[batch x height x width x channels]`.
-//   - batch inference is not supported (`batch` is required to be 1).
-//   - RGB and greyscale inputs are supported (`channels` is required to be
-//     1 or 3).
-//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
-//     attached to the metadata for input normalization.
-// Output tensors:
-//   (kTfLiteUInt8/kTfLiteFloat32)
-//   - list of segmented masks.
-//   - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1.
-//   - if `output_type` is CONFIDENCE_MASK, float32 Image list of size
-//     `cahnnels`.
-//   - batch is always 1
-// An example of such model can be found at:
-// https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2
-class ImageSegmenter : core::BaseTaskApi {
- public:
-  using BaseTaskApi::BaseTaskApi;
-
-  // Creates a Segmenter from the provided options. A non-default
-  // OpResolver can be specified in order to support custom Ops or specify a
-  // subset of built-in Ops.
-  static absl::StatusOr<std::unique_ptr<ImageSegmenter>> Create(
-      std::unique_ptr<ImageSegmenterOptions> options,
-      std::unique_ptr<tflite::OpResolver> resolver =
-          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
-
-  // Runs the actual segmentation task.
-  absl::StatusOr<std::vector<Image>> Segment(mediapipe::Image image);
-};
-
-}  // namespace vision
-}  // namespace tasks
-}  // namespace mediapipe
-
-#endif  // MEDIAPIPE_TASKS_CC_VISION_SEGMENTATION_IMAGE_SEGMENTER_H_
diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD
index 2d13eab9c..b52604c2b 100644
--- a/mediapipe/tasks/testdata/vision/BUILD
+++ b/mediapipe/tasks/testdata/vision/BUILD
@@ -80,6 +80,7 @@ filegroup(
     ],
 )

+# TODO Create individual filegroup for models required for each Tasks.
 filegroup(
     name = "test_models",
     srcs = [