diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD
index 93f2dbd06..4fce8bf8e 100644
--- a/mediapipe/calculators/tensor/BUILD
+++ b/mediapipe/calculators/tensor/BUILD
@@ -320,6 +320,8 @@ cc_library(
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite:string_util",
+        "@org_tensorflow//tensorflow/lite/c:c_api_types",
         "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
     ],
 )
diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc
index 81edb34e0..1d216daf3 100644
--- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc
+++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc
@@ -21,8 +21,10 @@
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/framework/port/ret_check.h"
+#include "tensorflow/lite/c/c_api_types.h"
 #include "tensorflow/lite/interpreter.h"
 #include "tensorflow/lite/interpreter_builder.h"
+#include "tensorflow/lite/string_util.h"
 
 namespace mediapipe {
 
@@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
   std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
 }
 
+template <>
+void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
+                                         tflite::Interpreter* interpreter,
+                                         int input_tensor_index) {
+  const char* input_tensor_buffer =
+      input_tensor.GetCpuReadView().buffer<char>();
+  tflite::DynamicBuffer dynamic_buffer;
+  dynamic_buffer.AddString(input_tensor_buffer,
+                           input_tensor.shape().num_elements());
+  dynamic_buffer.WriteToTensorAsVector(
+      interpreter->tensor(interpreter->inputs()[input_tensor_index]));
+}
+
 template <typename T>
 void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                      int output_tensor_index,
@@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
         break;
       }
       case TfLiteType::kTfLiteUInt8: {
-        CopyTensorBufferToInterpreter(input_tensors[i],
-                                      interpreter_.get(), i);
+        CopyTensorBufferToInterpreter(input_tensors[i],
+                                      interpreter_.get(), i);
         break;
       }
       case TfLiteType::kTfLiteInt8: {
-        CopyTensorBufferToInterpreter(input_tensors[i],
-                                      interpreter_.get(), i);
+        CopyTensorBufferToInterpreter(input_tensors[i],
+                                      interpreter_.get(), i);
         break;
       }
       case TfLiteType::kTfLiteInt32: {
@@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
                                       interpreter_.get(), i);
         break;
       }
+      case TfLiteType::kTfLiteString: {
+        CopyTensorBufferToInterpreter<char>(input_tensors[i],
+                                            interpreter_.get(), i);
+        break;
+      }
+      case TfLiteType::kTfLiteBool:
+        // No current use-case for copying MediaPipe Tensors with bool type to
+        // TfLiteTensors.
       default:
         return absl::InvalidArgumentError(
             absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
         CopyTensorBufferFromInterpreter(interpreter_.get(), i,
                                         &output_tensors.back());
         break;
+      case TfLiteType::kTfLiteBool:
+        output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
+                                    Tensor::QuantizationParameters{1.0f, 0});
+        CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
+                                              &output_tensors.back());
+        break;
+      case TfLiteType::kTfLiteString:
+        // No current use-case for copying TfLiteTensors with string type to
+        // MediaPipe Tensors.
       default:
         return absl::InvalidArgumentError(
             absl::StrCat("Unsupported output tensor type:",
diff --git a/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc b/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc
index 0b7e6f082..3d364a53c 100644
--- a/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc
+++ b/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc
@@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
     case Tensor::ElementType::kInt8:
       Dequantize<int8_t>(input_tensor, &output_tensors->back());
       break;
+    case Tensor::ElementType::kBool:
+      Dequantize<bool>(input_tensor, &output_tensors->back());
+      break;
     default:
       return absl::InvalidArgumentError(absl::StrCat(
           "Unsupported input tensor type: ", input_tensor.element_type()));
diff --git a/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc b/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc
index fd41cc763..e0d549123 100644
--- a/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc
+++ b/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc
@@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
   ValidateResult(GetOutput(), {-1.007874, 0, 1});
 }
 
+TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
+  std::vector<bool> tensor = {true, false, true};
+  PushTensor(Tensor::ElementType::kBool, tensor,
+             Tensor::QuantizationParameters{1.0f, 0});
+
+  MP_ASSERT_OK(runner_.Run());
+
+  ValidateResult(GetOutput(), {1, 0, 1});
+}
+
 }  // namespace
 }  // namespace mediapipe
diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD
index f6487a17a..19c51853c 100644
--- a/mediapipe/framework/BUILD
+++ b/mediapipe/framework/BUILD
@@ -1685,10 +1685,3 @@ cc_test(
         "@com_google_absl//absl/strings:str_format",
     ],
 )
-
-# Expose the proto source files for building mediapipe AAR.
-filegroup(
-    name = "protos_src",
-    srcs = glob(["*.proto"]),
-    visibility = ["//mediapipe:__subpackages__"],
-)
diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h
index b78014155..11bcd21c6 100644
--- a/mediapipe/framework/api2/builder.h
+++ b/mediapipe/framework/api2/builder.h
@@ -112,6 +112,14 @@ class MultiPort : public Single {
   std::vector<std::unique_ptr<Base>>& vec_;
 };
 
+namespace internal_builder {
+
+template <typename T, typename U>
+using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
+                                                   !std::is_same_v<T, U>>;
+
+}  // namespace internal_builder
+
 // These classes wrap references to the underlying source/destination
 // endpoints, adding type information and the user-visible API.
template @@ -122,13 +130,14 @@ class DestinationImpl { explicit DestinationImpl(std::vector>* vec) : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {} - DestinationBase& base_; -}; -template -class MultiDestinationImpl : public MultiPort> { - public: - using MultiPort>::MultiPort; + template {}, int> = 0> + DestinationImpl Cast() { + return DestinationImpl(&base_); + } + + DestinationBase& base_; }; template @@ -171,12 +180,8 @@ class SourceImpl { return AddTarget(dest); } - template - struct AllowCast - : public std::integral_constant && - !std::is_same_v> {}; - - template {}, int> = 0> + template {}, int> = 0> SourceImpl Cast() { return SourceImpl(base_); } @@ -186,12 +191,6 @@ class SourceImpl { SourceBase* base_; }; -template -class MultiSourceImpl : public MultiPort> { - public: - using MultiPort>::MultiPort; -}; - // A source and a destination correspond to an output/input stream on a node, // and a side source and side destination correspond to an output/input side // packet. @@ -201,20 +200,20 @@ class MultiSourceImpl : public MultiPort> { template using Source = SourceImpl; template -using MultiSource = MultiSourceImpl; +using MultiSource = MultiPort>; template using SideSource = SourceImpl; template -using MultiSideSource = MultiSourceImpl; +using MultiSideSource = MultiPort>; template using Destination = DestinationImpl; template using SideDestination = DestinationImpl; template -using MultiDestination = MultiDestinationImpl; +using MultiDestination = MultiPort>; template -using MultiSideDestination = MultiDestinationImpl; +using MultiSideDestination = MultiPort>; class NodeBase { public: diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3244e092d..810c52527 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) { node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); any_type_output.SetName("any_type_output"); + any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); + CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( node { @@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) { output_stream: "ANY_OUTPUT:any_type_output" } input_stream: "GRAPH_ANY_INPUT:__stream_0" + output_stream: "GRAPH_ANY_OUTPUT:any_type_output" )pb"); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 4a509ab69..44d9bda0b 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -334,13 +334,6 @@ mediapipe_register_type( deps = [":landmark_cc_proto"], ) -# Expose the proto source files for building mediapipe AAR. -filegroup( - name = "protos_src", - srcs = glob(["*.proto"]), - visibility = ["//mediapipe:__subpackages__"], -) - cc_library( name = "image", srcs = ["image.cc"], diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 2e33f7668..328001e85 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -33,10 +33,3 @@ mediapipe_proto_library( srcs = ["rasterization.proto"], visibility = ["//visibility:public"], ) - -# Expose the proto source files for building mediapipe AAR. 
-filegroup( - name = "protos_src", - srcs = glob(["*.proto"]), - visibility = ["//mediapipe:__subpackages__"], -) diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 2f2859837..2c535462b 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -97,8 +97,8 @@ class Tensor { kUInt8, kInt8, kInt32, - // TODO: Update the inference runner to handle kTfLiteString. - kChar + kChar, + kBool }; struct Shape { Shape() = default; @@ -330,6 +330,8 @@ class Tensor { return sizeof(int32_t); case ElementType::kChar: return sizeof(char); + case ElementType::kBool: + return sizeof(bool); } } int bytes() const { return shape_.num_elements() * element_size(); } diff --git a/mediapipe/framework/formats/tensor_test.cc b/mediapipe/framework/formats/tensor_test.cc index fe702f66b..44468cb8f 100644 --- a/mediapipe/framework/formats/tensor_test.cc +++ b/mediapipe/framework/formats/tensor_test.cc @@ -29,6 +29,9 @@ TEST(General, TestDataTypes) { Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); + + Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3}); + EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool)); } TEST(Cpu, TestMemoryAllocation) { diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 106738a49..e54fb2177 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -150,7 +150,7 @@ cc_library( name = "executor_util", srcs = ["executor_util.cc"], hdrs = ["executor_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9b5de0235..e3d36611c 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -1050,7 +1050,7 @@ objc_library( alwayslink = 1, ) -MIN_IOS_VERSION = "9.0" # For thread_local. +MIN_IOS_VERSION = "11.0" test_suite( name = "ios", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 4b64d2231..591d5e4f7 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -184,7 +184,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { EXPECT_THAT( audio_classifier_or.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); diff --git a/mediapipe/tasks/cc/common.h b/mediapipe/tasks/cc/common.h index 62656b7b3..1295177df 100644 --- a/mediapipe/tasks/cc/common.h +++ b/mediapipe/tasks/cc/common.h @@ -65,6 +65,8 @@ enum class MediaPipeTasksStatus { kFileReadError, // I/O error when mmap-ing file. kFileMmapError, + // ZIP I/O error when unpacking the zip file. + kFileZipError, // TensorFlow Lite metadata error codes. 
diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD
new file mode 100644
index 000000000..33d3e4457
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/BUILD
@@ -0,0 +1,31 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "landmarks_detection",
+    hdrs = ["landmarks_detection.h"],
+)
+
+cc_library(
+    name = "gesture_recognition_result",
+    hdrs = ["gesture_recognition_result.h"],
+    deps = [
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+    ],
+)
diff --git a/mediapipe/tasks/cc/components/containers/gesture_recognition_result.h b/mediapipe/tasks/cc/components/containers/gesture_recognition_result.h
new file mode 100644
index 000000000..4e2e8d775
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/gesture_recognition_result.h
@@ -0,0 +1,46 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
+
+#include "mediapipe/framework/formats/classification.pb.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace components {
+namespace containers {
+
+// The gesture recognition result from GestureRecognizer, where each vector
+// element represents a single hand detected in the image.
+struct GestureRecognitionResult {
+  // Recognized hand gestures with sorted order such that the winning label is
+  // the first item in the list.
+  std::vector<mediapipe::ClassificationList> gestures;
+  // Classification of handedness.
+  std::vector<mediapipe::ClassificationList> handedness;
+  // Detected hand landmarks in normalized image coordinates.
+  std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
+  // Detected hand landmarks in world coordinates.
+  std::vector<mediapipe::LandmarkList> hand_world_landmarks;
+};
+
+}  // namespace containers
+}  // namespace components
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
diff --git a/mediapipe/tasks/cc/components/containers/landmarks_detection.h b/mediapipe/tasks/cc/components/containers/landmarks_detection.h
new file mode 100644
index 000000000..7339954d8
--- /dev/null
+++ b/mediapipe/tasks/cc/components/containers/landmarks_detection.h
@@ -0,0 +1,43 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
+#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
+
+#include <vector>
+
+// Structs holding landmarks-related data structures for hand landmarker, pose
+// detector, face mesher, etc.
+namespace mediapipe::tasks::components::containers {
+
+// x and y are in [0,1] range with origin in top left in input image space.
+// If the model provides z, it is in the same scale as x. The origin is in the
+// center of the face.
+struct Landmark {
+  float x;
+  float y;
+  float z;
+};
+
+// [0, 1] range in input image space
+struct Bound {
+  float left;
+  float top;
+  float right;
+  float bottom;
+};
+
+}  // namespace mediapipe::tasks::components::containers
+#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto
index a44fb5b15..a154e5f4e 100644
--- a/mediapipe/tasks/cc/components/containers/proto/category.proto
+++ b/mediapipe/tasks/cc/components/containers/proto/category.proto
@@ -17,6 +17,9 @@ syntax = "proto2";
 
 package mediapipe.tasks.components.containers.proto;
 
+option java_package = "com.google.mediapipe.tasks.components.container.proto";
+option java_outer_classname = "CategoryProto";
+
 // A single classification result.
 message Category {
   // The index of the category in the corresponding label map, usually packed in
diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto
index e0ccad7a1..0f5086b95 100644
--- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto
+++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto
@@ -19,6 +19,9 @@ package mediapipe.tasks.components.containers.proto;
 
 import "mediapipe/tasks/cc/components/containers/proto/category.proto";
 
+option java_package = "com.google.mediapipe.tasks.components.container.proto";
+option java_outer_classname = "ClassificationsProto";
+
 // List of predicted categories with an optional timestamp.
message ClassificationEntry { // The array of predicted categories, usually sorted by descending scores, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index cd5933ee6..b4fbf9669 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -123,15 +123,17 @@ absl::StatusOr GetClassificationHeadsProperties( const auto* tensor = primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i)); if (tensor->type() != tflite::TensorType_FLOAT32 && - tensor->type() != tflite::TensorType_UINT8) { + tensor->type() != tflite::TensorType_UINT8 && + tensor->type() != tflite::TensorType_BOOL) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected output tensor at index %d to have type " - "UINT8 or FLOAT32, found %s instead.", + "UINT8 or FLOAT32 or BOOL, found %s instead.", i, tflite::EnumNameTensorType(tensor->type())), MediaPipeTasksStatus::kInvalidOutputTensorTypeError); } - if (tensor->type() == tflite::TensorType_UINT8) { + if (tensor->type() == tflite::TensorType_UINT8 || + tensor->type() == tflite::TensorType_BOOL) { num_quantized_tensors++; } } @@ -282,6 +284,20 @@ absl::Status ConfigureScoreCalibrationIfAny( return absl::OkStatus(); } +void ConfigureClassificationAggregationCalculator( + const ModelMetadataExtractor& metadata_extractor, + ClassificationAggregationCalculatorOptions* options) { + auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); + if (output_tensors_metadata == nullptr) { + return; + } + for (const auto& metadata : *output_tensors_metadata) { + options->add_head_names(metadata->name()->str()); + } +} + +} // namespace + // Fills in the TensorsToClassificationCalculatorOptions based on the // classifier options and the (optional) output tensor metadata. absl::Status ConfigureTensorsToClassificationCalculator( @@ -333,20 +349,6 @@ absl::Status ConfigureTensorsToClassificationCalculator( return absl::OkStatus(); } -void ConfigureClassificationAggregationCalculator( - const ModelMetadataExtractor& metadata_extractor, - ClassificationAggregationCalculatorOptions* options) { - auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); - if (output_tensors_metadata == nullptr) { - return; - } - for (const auto& metadata : *output_tensors_metadata) { - options->add_head_names(metadata->name()->str()); - } -} - -} // namespace - absl::Status ConfigureClassificationPostprocessingGraph( const ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 8aedad46d..be166982d 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" namespace mediapipe { namespace tasks { @@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph( const proto::ClassifierOptions& classifier_options, proto::ClassificationPostprocessingGraphOptions* options); +// Utility function to fill in the TensorsToClassificationCalculatorOptions +// based on the classifier options and the (optional) output tensor metadata. +// This is meant to be used by other graphs that may also rely on this +// calculator. +absl::Status ConfigureTensorsToClassificationCalculator( + const proto::ClassifierOptions& options, + const metadata::ModelMetadataExtractor& metadata_extractor, + int tensor_index, + mediapipe::TensorsToClassificationCalculatorOptions* calculator_options); + } // namespace processors } // namespace components } // namespace tasks diff --git a/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto index 7afbfc14e..12ece7249 100644 --- a/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto @@ -17,6 +17,9 @@ syntax = "proto2"; package mediapipe.tasks.components.processors.proto; +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "ClassifierOptionsProto"; + // Shared options used by all classification tasks. message ClassifierOptions { // The locale to use for display names specified through the TFLite Model diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto index c0c207543..926e3d7fb 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto @@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions { BERT_PREPROCESSOR = 1; // Used for the RegexPreprocessorCalculator. REGEX_PREPROCESSOR = 2; + // Used for the TextToTensorCalculator. + STRING_PREPROCESSOR = 3; } optional PreprocessorType preprocessor_type = 1; diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/text_preprocessing_graph.cc index 2c4c1b866..6aad8fdd5 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/text_preprocessing_graph.cc @@ -65,6 +65,8 @@ absl::StatusOr GetCalculatorNameFromPreprocessorType( return "BertPreprocessorCalculator"; case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: return "RegexPreprocessorCalculator"; + case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + return "TextToTensorCalculator"; } } @@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) { MediaPipeTasksStatus::kInvalidInputTensorTypeError); } if (all_string_tensors) { - // TODO: Support a TextToTensor calculator for string tensors. 
- return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "String tensors are not supported yet", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); + return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; } // Otherwise, all tensors should have type int32 @@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph( TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, GetPreprocessorType(model_resources)); options.set_preprocessor_type(preprocessor_type); - ASSIGN_OR_RETURN( - int max_seq_len, - GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); - options.set_max_seq_len(max_seq_len); + switch (preprocessor_type) { + case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: + case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + break; + } + case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + ASSIGN_OR_RETURN( + int max_seq_len, + GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); + options.set_max_seq_len(max_seq_len); + } + } return absl::OkStatus(); } @@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); auto& text_preprocessor = graph.AddNode(preprocessor_name); switch (options.preprocessor_type()) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: { + case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: + case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { break; } case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index 8a219bb80..33dfeca0b 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -92,13 +92,26 @@ absl::Status ExternalFileHandler::MapExternalFile() { #else if (!external_file_.file_content().empty()) { return absl::OkStatus(); + } else if (external_file_.has_file_pointer_meta()) { + if (external_file_.file_pointer_meta().pointer() == 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Need to set the file pointer in external_file.file_pointer_meta."); + } + if (external_file_.file_pointer_meta().length() <= 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "The length of the file in external_file.file_pointer_meta should be " + "positive."); + } + return absl::OkStatus(); } if (external_file_.file_name().empty() && !external_file_.has_file_descriptor_meta()) { return CreateStatusWithPayload( StatusCode::kInvalidArgument, - "ExternalFile must specify at least one of 'file_content', 'file_name' " - "or 'file_descriptor_meta'.", + "ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.", MediaPipeTasksStatus::kInvalidArgumentError); } // Obtain file descriptor, offset and size. 
@@ -196,6 +209,11 @@ absl::Status ExternalFileHandler::MapExternalFile() {
 absl::string_view ExternalFileHandler::GetFileContent() {
   if (!external_file_.file_content().empty()) {
     return external_file_.file_content();
+  } else if (external_file_.has_file_pointer_meta()) {
+    void* ptr =
+        reinterpret_cast<void*>(external_file_.file_pointer_meta().pointer());
+    return absl::string_view(static_cast<const char*>(ptr),
+                             external_file_.file_pointer_meta().length());
   } else {
     return absl::string_view(static_cast<const char*>(buffer_) +
                                  buffer_offset_ - buffer_aligned_offset_,
diff --git a/mediapipe/tasks/cc/core/proto/external_file.proto b/mediapipe/tasks/cc/core/proto/external_file.proto
index af4a11697..3147a2224 100644
--- a/mediapipe/tasks/cc/core/proto/external_file.proto
+++ b/mediapipe/tasks/cc/core/proto/external_file.proto
@@ -26,10 +26,11 @@ option java_outer_classname = "ExternalFileProto";
 // (1) file contents loaded in `file_content`.
 // (2) file path in `file_name`.
 // (3) file descriptor through `file_descriptor_meta` as returned by open(2).
+// (4) file pointer and length in memory through `file_pointer_meta`.
 //
 // If more than one field of these fields is provided, they are used in this
 // precedence order.
-// Next id: 4
+// Next id: 5
 message ExternalFile {
   // The file contents as a byte array.
   optional bytes file_content = 1;
@@ -40,6 +41,13 @@ message ExternalFile {
   // The file descriptor to a file opened with open(2), with optional additional
   // offset and length information.
   optional FileDescriptorMeta file_descriptor_meta = 3;
+
+  // A pointer to the location of a file in memory. Use the util method,
+  // `SetExternalFile` in [1], to configure `file_pointer_meta` from a
+  // `std::string_view` object.
+  //
+  // [1]: mediapipe/tasks/cc/metadata/utils/zip_utils.h
+  optional FilePointerMeta file_pointer_meta = 4;
 }
 
 // A proto defining file descriptor metadata for mapping file into memory using
@@ -62,3 +70,14 @@ message FileDescriptorMeta {
   // offset of a given asset obtained from AssetFileDescriptor#getStartOffset().
   optional int64 offset = 3;
 }
+
+// A pointer to the location of a file in memory. Make sure the memory it
+// points to is on the same machine and that it outlives this
+// FilePointerMeta object.
+message FilePointerMeta {
+  // Memory address of the file in decimal.
+  optional uint64 pointer = 1;
+
+  // File length.
+  optional int64 length = 2;
+}
diff --git a/mediapipe/tasks/cc/metadata/BUILD b/mediapipe/tasks/cc/metadata/BUILD
index c3555e4a0..ef32dd184 100644
--- a/mediapipe/tasks/cc/metadata/BUILD
+++ b/mediapipe/tasks/cc/metadata/BUILD
@@ -19,8 +19,9 @@ cc_library(
     deps = [
         "//mediapipe/framework/port:status",
         "//mediapipe/tasks/cc:common",
-        "//mediapipe/tasks/cc/metadata/utils:zip_readonly_mem_file",
+        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
@@ -29,7 +30,6 @@ cc_library(
         "@com_google_absl//absl/strings:str_format",
         "@flatbuffers//:runtime_cc",
         "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
-        "@zlib//:zlib_minizip",
     ],
 )
diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.cc b/mediapipe/tasks/cc/metadata/metadata_extractor.cc
index fcec49083..4bc3e8ba0 100644
--- a/mediapipe/tasks/cc/metadata/metadata_extractor.cc
+++ b/mediapipe/tasks/cc/metadata/metadata_extractor.cc
@@ -17,16 +17,16 @@ limitations under the License.
#include +#include "absl/cleanup/cleanup.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" -#include "contrib/minizip/ioapi.h" -#include "contrib/minizip/unzip.h" #include "flatbuffers/flatbuffers.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -53,72 +53,6 @@ const T* GetItemFromVector( } return src_vector->Get(index); } - -// Wrapper function around calls to unzip to avoid repeating conversion logic -// from error code to Status. -absl::Status UnzipErrorToStatus(int error) { - if (error != UNZ_OK) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to read associated file in zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - return absl::OkStatus(); -} - -// Stores a file name, position in zip buffer and size. -struct ZipFileInfo { - std::string name; - ZPOS64_T position; - ZPOS64_T size; -}; - -// Returns the ZipFileInfo corresponding to the current file in the provided -// unzFile object. -absl::StatusOr GetCurrentZipFileInfo(const unzFile& zf) { - // Open file in raw mode, as data is expected to be uncompressed. - int method; - MP_RETURN_IF_ERROR(UnzipErrorToStatus( - unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1))); - if (method != Z_NO_COMPRESSION) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Expected uncompressed zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - - // Get file info a first time to get filename size. - unz_file_info64 file_info; - MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( - zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0, - /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, - /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); - - // Second call to get file name. - auto file_name_size = file_info.size_filename; - char* c_file_name = (char*)malloc(file_name_size); - MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( - zf, &file_info, c_file_name, file_name_size, - /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, - /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); - std::string file_name = std::string(c_file_name, file_name_size); - free(c_file_name); - - // Get position in file. - auto position = unzGetCurrentFileZStreamPos64(zf); - if (position == 0) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to read file in zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - - // Close file and return. - MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf))); - - ZipFileInfo result{}; - result.name = file_name; - result.position = position; - result.size = file_info.uncompressed_size; - return result; -} } // namespace /* static */ @@ -238,47 +172,15 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( absl::Status ModelMetadataExtractor::ExtractAssociatedFiles( const char* buffer_data, size_t buffer_size) { - // Create in-memory read-only zip file. - ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size); - // Open zip. 
- unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def()); - if (zf == nullptr) { + auto status = + ExtractFilesfromZipFile(buffer_data, buffer_size, &associated_files_); + if (!status.ok() && + absl::StrContains(status.message(), "Unable to open zip archive.")) { // It's OK if it fails: this means there are no associated files with this // model. return absl::OkStatus(); } - // Get number of files. - unz_global_info global_info; - if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to get zip archive info.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - - // Browse through files in archive. - if (global_info.number_entry > 0) { - int error = unzGoToFirstFile(zf); - while (error == UNZ_OK) { - ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf)); - // Store result in map. - associated_files_[zip_file_info.name] = absl::string_view( - buffer_data + zip_file_info.position, zip_file_info.size); - error = unzGoToNextFile(zf); - } - if (error != UNZ_END_OF_LIST_OF_FILE) { - return CreateStatusWithPayload( - StatusCode::kUnknown, - "Unable to read associated file in zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - } - // Close zip. - if (unzClose(zf) != UNZ_OK) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to close zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - return absl::OkStatus(); + return status; } absl::StatusOr ModelMetadataExtractor::GetAssociatedFile( diff --git a/mediapipe/tasks/cc/metadata/utils/BUILD b/mediapipe/tasks/cc/metadata/utils/BUILD index b595eb10f..881b88962 100644 --- a/mediapipe/tasks/cc/metadata/utils/BUILD +++ b/mediapipe/tasks/cc/metadata/utils/BUILD @@ -24,3 +24,20 @@ cc_library( "@zlib//:zlib_minizip", ], ) + +cc_library( + name = "zip_utils", + srcs = ["zip_utils.cc"], + hdrs = ["zip_utils.h"], + deps = [ + ":zip_readonly_mem_file", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@zlib//:zlib_minizip", + ], +) diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.cc b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc new file mode 100644 index 000000000..41d710e14 --- /dev/null +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc @@ -0,0 +1,175 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "contrib/minizip/ioapi.h" +#include "contrib/minizip/unzip.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h" + +namespace mediapipe { +namespace tasks { +namespace metadata { + +namespace { + +using ::absl::StatusCode; + +// Wrapper function around calls to unzip to avoid repeating conversion logic +// from error code to Status. +absl::Status UnzipErrorToStatus(int error) { + if (error != UNZ_OK) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to read the file in zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + return absl::OkStatus(); +} + +// Stores a file name, position in zip buffer and size. +struct ZipFileInfo { + std::string name; + ZPOS64_T position; + ZPOS64_T size; +}; + +// Returns the ZipFileInfo corresponding to the current file in the provided +// unzFile object. +absl::StatusOr GetCurrentZipFileInfo(const unzFile& zf) { + // Open file in raw mode, as data is expected to be uncompressed. + int method; + MP_RETURN_IF_ERROR(UnzipErrorToStatus( + unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1))); + absl::Cleanup unzipper_closer = [zf]() { + auto status = UnzipErrorToStatus(unzCloseCurrentFile(zf)); + if (!status.ok()) { + LOG(ERROR) << "Failed to close the current zip file: " << status; + } + }; + if (method != Z_NO_COMPRESSION) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Expected uncompressed zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + + // Get file info a first time to get filename size. + unz_file_info64 file_info; + MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( + zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0, + /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, + /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); + + // Second call to get file name. + auto file_name_size = file_info.size_filename; + char* c_file_name = (char*)malloc(file_name_size); + MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( + zf, &file_info, c_file_name, file_name_size, + /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, + /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); + std::string file_name = std::string(c_file_name, file_name_size); + free(c_file_name); + + // Get position in file. + auto position = unzGetCurrentFileZStreamPos64(zf); + if (position == 0) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to read file in zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + + // Perform the cleanup manually for error propagation. + std::move(unzipper_closer).Cancel(); + // Close file and return. + MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf))); + + ZipFileInfo result{}; + result.name = file_name; + result.position = position; + result.size = file_info.uncompressed_size; + return result; +} + +} // namespace + +absl::Status ExtractFilesfromZipFile( + const char* buffer_data, const size_t buffer_size, + absl::flat_hash_map* files) { + // Create in-memory read-only zip file. + ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size); + // Open zip. 
+ unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def()); + if (zf == nullptr) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to open zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + absl::Cleanup unzipper_closer = [zf]() { + if (unzClose(zf) != UNZ_OK) { + LOG(ERROR) << "Unable to close zip archive."; + } + }; + // Get number of files. + unz_global_info global_info; + if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to get zip archive info.", + MediaPipeTasksStatus::kFileZipError); + } + + // Browse through files in archive. + if (global_info.number_entry > 0) { + int error = unzGoToFirstFile(zf); + while (error == UNZ_OK) { + ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf)); + // Store result in map. + (*files)[zip_file_info.name] = absl::string_view( + buffer_data + zip_file_info.position, zip_file_info.size); + error = unzGoToNextFile(zf); + } + if (error != UNZ_END_OF_LIST_OF_FILE) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + "Unable to read associated file in zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + } + // Perform the cleanup manually for error propagation. + std::move(unzipper_closer).Cancel(); + // Close zip. + if (unzClose(zf) != UNZ_OK) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to close zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + return absl::OkStatus(); +} + +void SetExternalFile(const std::string_view& file_content, + core::proto::ExternalFile* model_file) { + auto pointer = reinterpret_cast(file_content.data()); + + model_file->mutable_file_pointer_meta()->set_pointer(pointer); + model_file->mutable_file_pointer_meta()->set_length(file_content.length()); +} + +} // namespace metadata +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.h b/mediapipe/tasks/cc/metadata/utils/zip_utils.h new file mode 100644 index 000000000..28708ba6a --- /dev/null +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" + +namespace mediapipe { +namespace tasks { +namespace metadata { + +// Extract files from the zip file. +// Input: Pointer and length of the zip file in memory. +// Outputs: A map with the filename as key and a pointer to the file contents +// as value. The file contents returned by this function are only guaranteed to +// stay valid while buffer_data is alive. 
+absl::Status ExtractFilesfromZipFile( + const char* buffer_data, const size_t buffer_size, + absl::flat_hash_map* files); + +// Set file_pointer_meta in ExternalFile which is the pointer points to location +// of a file in memory by file_content. +void SetExternalFile(const std::string_view& file_content, + core::proto::ExternalFile* model_file); + +} // namespace metadata +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index c9319e946..985c25cfb 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -44,8 +44,12 @@ cc_library( name = "hand_gesture_recognizer_graph", srcs = ["hand_gesture_recognizer_graph.cc"], deps = [ + "//mediapipe/calculators/core:begin_loop_calculator", "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/core:get_vector_item_calculator", "//mediapipe/calculators/tensor:tensor_converter_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator", "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", @@ -55,7 +59,6 @@ cc_library( "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", @@ -67,10 +70,81 @@ cc_library( "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/metadata:metadata_schema_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", ], alwayslink = 1, ) + +cc_library( + name = "gesture_recognizer_graph", + srcs = ["gesture_recognizer_graph.cc"], + deps = [ + ":hand_gesture_recognizer_graph", + "//mediapipe/calculators/core:vector_indices_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", + 
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + deps = [ + ":gesture_recognizer_graph", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/containers:gesture_recognition_result", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc new file mode 100644 index 000000000..ca5deee7f --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -0,0 +1,282 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h" + +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +namespace { + +using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: + gesture_recognizer::proto::GestureRecognizerGraphOptions; + +using ::mediapipe::tasks::components::containers::GestureRecognitionResult; + +constexpr char kHandGestureSubgraphTypeName[] = + "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kHandGesturesTag[] = "HAND_GESTURES"; +constexpr char kHandGesturesStreamName[] = "hand_gestures"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessStreamName[] = "handedness"; +constexpr char kHandLandmarksTag[] = "LANDMARKS"; +constexpr char kHandLandmarksStreamName[] = "landmarks"; +constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running +// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the +// number of frames in flight. 
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options, + bool enable_flow_limiting) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >> + graph.Out(kHandGesturesTag); + subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >> + graph.Out(kHandednessTag); + subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >> + graph.Out(kHandLandmarksTag); + subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >> + graph.Out(kHandWorldLandmarksTag); + subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator(graph, subgraph, {kImageTag}, + kHandGesturesTag); + } + graph.In(kImageTag) >> subgraph.In(kImageTag); + return graph.GetConfig(); +} + +// Converts the user-facing GestureRecognizerOptions struct to the internal +// GestureRecognizerGraphOptions proto. +std::unique_ptr +ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { + auto options_proto = std::make_unique(); + + bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE; + + // TODO remove these workarounds for base options of subgraphs. + // Configure hand detector options. + auto base_options_proto_for_hand_detector = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(options->base_options_for_hand_detector))); + base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode); + auto* hand_detector_graph_options = + options_proto->mutable_hand_landmarker_graph_options() + ->mutable_hand_detector_graph_options(); + hand_detector_graph_options->mutable_base_options()->Swap( + base_options_proto_for_hand_detector.get()); + hand_detector_graph_options->set_num_hands(options->num_hands); + hand_detector_graph_options->set_min_detection_confidence( + options->min_hand_detection_confidence); + + // Configure hand landmark detector options. + auto base_options_proto_for_hand_landmarker = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(options->base_options_for_hand_landmarker))); + base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode); + auto* hand_landmarks_detector_graph_options = + options_proto->mutable_hand_landmarker_graph_options() + ->mutable_hand_landmarks_detector_graph_options(); + hand_landmarks_detector_graph_options->mutable_base_options()->Swap( + base_options_proto_for_hand_landmarker.get()); + hand_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_hand_presence_confidence); + + auto* hand_landmarker_graph_options = + options_proto->mutable_hand_landmarker_graph_options(); + hand_landmarker_graph_options->set_min_tracking_confidence( + options->min_tracking_confidence); + + // Configure hand gesture recognizer options. 
+ auto base_options_proto_for_gesture_recognizer = + std::make_unique( + tasks::core::ConvertBaseOptionsToProto( + &(options->base_options_for_gesture_recognizer))); + base_options_proto_for_gesture_recognizer->set_use_stream_mode( + use_stream_mode); + auto* hand_gesture_recognizer_graph_options = + options_proto->mutable_hand_gesture_recognizer_graph_options(); + hand_gesture_recognizer_graph_options->mutable_base_options()->Swap( + base_options_proto_for_gesture_recognizer.get()); + if (options->min_gesture_confidence >= 0) { + hand_gesture_recognizer_graph_options->mutable_classifier_options() + ->set_score_threshold(options->min_gesture_confidence); + } + return options_proto; +} + +} // namespace + +absl::StatusOr> GestureRecognizer::Create( + std::unique_ptr options) { + auto options_proto = ConvertGestureRecognizerGraphOptionsProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = [=](absl::StatusOr + status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + return; + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet gesture_packet = + status_or_packets.value()[kHandGesturesStreamName]; + Packet handedness_packet = + status_or_packets.value()[kHandednessStreamName]; + Packet hand_landmarks_packet = + status_or_packets.value()[kHandLandmarksStreamName]; + Packet hand_world_landmarks_packet = + status_or_packets.value()[kHandWorldLandmarksStreamName]; + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + result_callback( + {{gesture_packet.Get>(), + handedness_packet.Get>(), + hand_landmarks_packet.Get>(), + hand_world_landmarks_packet.Get>()}}, + image_packet.Get(), + gesture_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + }; + } + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr GestureRecognizer::Recognize( + mediapipe::Image image) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN(auto output_packets, + ProcessImageData({{kImageInStreamName, + MakePacket(std::move(image))}})); + return { + {/* gestures= */ {output_packets[kHandGesturesStreamName] + .Get>()}, + /* handedness= */ + {output_packets[kHandednessStreamName] + .Get>()}, + /* hand_landmarks= */ + {output_packets[kHandLandmarksStreamName] + .Get>()}, + /* hand_world_landmarks */ + {output_packets[kHandWorldLandmarksStreamName] + .Get>()}}, + }; +} + +absl::StatusOr GestureRecognizer::RecognizeForVideo( + mediapipe::Image image, int64 timestamp_ms) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + return { + {/* gestures= */ {output_packets[kHandGesturesStreamName] + .Get>()}, + 
/* handedness= */ + {output_packets[kHandednessStreamName] + .Get>()}, + /* hand_landmarks= */ + {output_packets[kHandLandmarksStreamName] + .Get>()}, + /* hand_world_landmarks */ + {output_packets[kHandWorldLandmarksStreamName] + .Get>()}}, + }; +} + +absl::Status GestureRecognizer::RecognizeAsync(mediapipe::Image image, + int64 timestamp_ms) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +} // namespace gesture_recognizer +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h new file mode 100644 index 000000000..17c9cc921 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -0,0 +1,172 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +struct GestureRecognizerOptions { + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // TODO: remove these. Temporary solutions before bundle asset is + // ready. + tasks::core::BaseOptions base_options_for_hand_landmarker; + tasks::core::BaseOptions base_options_for_hand_detector; + tasks::core::BaseOptions base_options_for_gesture_recognizer; + + // The running mode of the task. Default to the image mode. + // GestureRecognizer has three running modes: + // 1) The image mode for recognizing hand gestures on single image inputs. + // 2) The video mode for recognizing hand gestures on the decoded frames of a + // video. + // 3) The live stream mode for recognizing hand gestures on the live stream of + // input data, such as from camera. In this mode, the "result_callback" + // below must be specified to receive the detection results asynchronously. 
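  //
  // For orientation, a minimal live-stream setup might look like the sketch
  // below (illustrative only; the model-asset wiring is elided because the
  // exact BaseOptions fields are not shown in this change):
  //
  //   auto options = std::make_unique<GestureRecognizerOptions>();
  //   // ... point options->base_options (and the temporary per-subgraph base
  //   // options above) at the corresponding model assets ...
  //   options->running_mode = core::RunningMode::LIVE_STREAM;
  //   options->result_callback =
  //       [](absl::StatusOr<components::containers::GestureRecognitionResult>
  //              result,
  //          const Image& image, int64 timestamp_ms) {
  //         // Consume or log the recognized gestures for this frame.
  //       };
  //   // Frames are then pushed with RecognizeAsync(image, timestamp_ms).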
+  core::RunningMode running_mode = core::RunningMode::IMAGE;
+
+  // The maximum number of hands that can be detected by the GestureRecognizer.
+  int num_hands = 1;
+
+  // The minimum confidence score for the hand detection to be considered
+  // successful.
+  float min_hand_detection_confidence = 0.5;
+
+  // The minimum confidence score of hand presence in the hand landmark
+  // detection.
+  float min_hand_presence_confidence = 0.5;
+
+  // The minimum confidence score for the hand tracking to be considered
+  // successful.
+  float min_tracking_confidence = 0.5;
+
+  // The minimum confidence score for the gestures to be considered
+  // successful. If < 0, the gesture confidence thresholds in the model
+  // metadata are used.
+  // TODO Note this option is subject to change after the score merging
+  // calculator is implemented.
+  float min_gesture_confidence = -1;
+
+  // The user-defined result callback for processing live stream data.
+  // The result callback should only be specified when the running mode is set
+  // to RunningMode::LIVE_STREAM.
+  std::function<void(
+      absl::StatusOr<components::containers::GestureRecognitionResult>,
+      const Image&, int64)>
+      result_callback = nullptr;
+};
+
+// Performs hand gesture recognition on the given image.
+//
+// TODO add the link to DevSite.
+// This API expects a pre-trained hand gesture model asset bundle, or a
+// custom one created using Model Maker. See .
+//
+// Inputs:
+//   Image
+//     - The image that gesture recognition runs on.
+// Outputs:
+//   GestureRecognitionResult
+//     - The hand gesture recognition results.
+class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
+ public:
+  using BaseVisionTaskApi::BaseVisionTaskApi;
+
+  // Creates a GestureRecognizer from a GestureRecognizerOptions to process
+  // image data or streaming data. Gesture recognizer can be created with one
+  // of the following three running modes:
+  // 1) Image mode for recognizing gestures on single image inputs.
+  //    Users provide mediapipe::Image to the `Recognize` method, and will
+  //    receive the recognized hand gesture results as the return value.
+  // 2) Video mode for recognizing gestures on the decoded frames of a video.
+  // 3) Live stream mode for recognizing gestures on the live stream of the
+  //    input data, such as from camera. Users call `RecognizeAsync` to push
+  //    the image data into the GestureRecognizer; the recognized results,
+  //    along with the input timestamp and the image that gesture recognizer
+  //    runs on, will be available in the result callback when the gesture
+  //    recognizer finishes the work.
+  static absl::StatusOr<std::unique_ptr<GestureRecognizer>> Create(
+      std::unique_ptr<GestureRecognizerOptions> options);
+
+  // Performs hand gesture recognition on the given image.
+  // Only use this method when the GestureRecognizer is created with the image
+  // running mode.
+  //
+  // image - mediapipe::Image
+  //   Image to perform hand gesture recognition on.
+  //
+  // The image can be of any size with format RGB or RGBA.
+  // TODO: Describes how the input image will be preprocessed
+  // after the yuv support is implemented.
+  absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
+      Image image);
+
+  // Performs gesture recognition on the provided video frame.
+  // Only use this method when the GestureRecognizer is created with the video
+  // running mode.
+  //
+  // The image can be of any size with format RGB or RGBA. It's required to
+  // provide the video frame's timestamp (in milliseconds). The input
+  // timestamps must be monotonically increasing.
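  //
  // A rough calling pattern for the video mode might look like the sketch
  // below (illustrative; the frame source and the 30 ms frame period are
  // assumptions, not values from this change):
  //
  //   ASSIGN_OR_RETURN(std::unique_ptr<GestureRecognizer> recognizer,
  //                    GestureRecognizer::Create(std::move(options)));
  //   for (int i = 0; i < num_frames; ++i) {
  //     ASSIGN_OR_RETURN(auto result, recognizer->RecognizeForVideo(
  //                                       frames[i], /*timestamp_ms=*/i * 30));
  //   }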
+ absl::StatusOr + RecognizeForVideo(Image image, int64 timestamp_ms); + + // Sends live image data to perform gesture recognition, and the results will + // be available via the "result_callback" provided in the + // GestureRecognizerOptions. Only use this method when the GestureRecognizer + // is created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the gesture recognizer. The input timestamps must be monotonically + // increasing. + // + // The "result_callback" provides + // - A vector of GestureRecognitionResult, each is the recognized results + // for a input frame. + // - The const reference to the corresponding input image that the gesture + // recognizer runs on. Note that the const reference to the image will no + // longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status RecognizeAsync(Image image, int64 timestamp_ms); + + // Shuts down the GestureRecognizer when all works are done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace gesture_recognizer +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc new file mode 100644 index 000000000..b4f2af4d8 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -0,0 +1,211 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + GestureRecognizerGraphOptions; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + HandGestureRecognizerGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kHandGesturesTag[] = "HAND_GESTURES"; +constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; + +struct GestureRecognizerOutputs { + Source> gesture; + Source> handedness; + Source> hand_landmarks; + Source> hand_world_landmarks; + Source image; +}; + +} // namespace + +// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs +// hand gesture recognition. +// +// Inputs: +// IMAGE - Image +// Image to perform hand gesture recognition on. +// +// Outputs: +// HAND_GESTURES - std::vector +// Recognized hand gestures with sorted order such that the winning label is +// the first item in the list. +// LANDMARKS: - std::vector +// Detected hand landmarks. +// WORLD_LANDMARKS - std::vector +// Detected hand landmarks in world coordinates. +// HAND_RECT_NEXT_FRAME - std::vector +// The predicted Rect enclosing the hand RoI for landmark detection on the +// next frame. +// HANDEDNESS - std::vector +// Classification of handedness. +// IMAGE - mediapipe::Image +// The image that gesture recognizer runs on and has the pixel data stored +// on the target storage (CPU vs GPU). 
+// +// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" +// input_stream: "IMAGE:image_in" +// output_stream: "HAND_GESTURES:hand_gestures" +// output_stream: "LANDMARKS:hand_landmarks" +// output_stream: "WORLD_LANDMARKS:world_hand_landmarks" +// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" +// output_stream: "HANDEDNESS:handedness" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.gesture_recognizer.proto.GestureRecognizerGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "hand_gesture.tflite" +// } +// } +// hand_landmark_detector_options { +// base_options { +// model_asset { +// file_name: "hand_landmark.tflite" +// } +// } +// } +// } +// } +// } +class GestureRecognizerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, + BuildGestureRecognizerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], graph)); + hand_gesture_recognition_output.gesture >> + graph[Output>(kHandGesturesTag)]; + hand_gesture_recognition_output.handedness >> + graph[Output>(kHandednessTag)]; + hand_gesture_recognition_output.hand_landmarks >> + graph[Output>(kLandmarksTag)]; + hand_gesture_recognition_output.hand_world_landmarks >> + graph[Output>(kWorldLandmarksTag)]; + hand_gesture_recognition_output.image >> graph[Output(kImageTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr BuildGestureRecognizerGraph( + GestureRecognizerGraphOptions& graph_options, Source image_in, + Graph& graph) { + auto& image_property = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_property.In("IMAGE"); + auto image_size = image_property.Out("SIZE"); + + // Hand landmarker graph. + auto& hand_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"); + auto& hand_landmarker_graph_options = + hand_landmarker_graph.GetOptions(); + hand_landmarker_graph_options.Swap( + graph_options.mutable_hand_landmarker_graph_options()); + + image_in >> hand_landmarker_graph.In(kImageTag); + auto hand_landmarks = + hand_landmarker_graph[Output>( + kLandmarksTag)]; + auto hand_world_landmarks = + hand_landmarker_graph[Output>( + kWorldLandmarksTag)]; + auto handedness = + hand_landmarker_graph[Output>( + kHandednessTag)]; + + auto& vector_indices = + graph.AddNode("NormalizedLandmarkListVectorIndicesCalculator"); + hand_landmarks >> vector_indices.In("VECTOR"); + auto hand_landmarks_id = vector_indices.Out("INDICES"); + + // Hand gesture recognizer subgraph. + auto& hand_gesture_subgraph = graph.AddNode( + "mediapipe.tasks.vision.gesture_recognizer." 
+        "MultipleHandGestureRecognizerGraph");
+    hand_gesture_subgraph.GetOptions<HandGestureRecognizerGraphOptions>().Swap(
+        graph_options.mutable_hand_gesture_recognizer_graph_options());
+    hand_landmarks >> hand_gesture_subgraph.In(kLandmarksTag);
+    hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag);
+    handedness >> hand_gesture_subgraph.In(kHandednessTag);
+    image_size >> hand_gesture_subgraph.In(kImageSizeTag);
+    hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag);
+    auto hand_gestures =
+        hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
+            kHandGesturesTag)];
+
+    return {{.gesture = hand_gestures,
+             .handedness = handedness,
+             .hand_landmarks = hand_landmarks,
+             .hand_world_landmarks = hand_world_landmarks,
+             .image = hand_landmarker_graph[Output<Image>(kImageTag)]}};
+  }
+};
+
+// clang-format off
+REGISTER_MEDIAPIPE_GRAPH(
+  ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerGraph); // NOLINT
+// clang-format on
+
+}  // namespace gesture_recognizer
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
index 05bc607ae..8d7e0bc07 100644
--- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc
@@ -27,7 +27,6 @@ limitations under the License.
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/framework/formats/tensor.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
@@ -36,7 +35,6 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/utils.h"
 #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
 #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
-#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
 
 namespace mediapipe {
@@ -50,7 +48,8 @@ using ::mediapipe::api2::Input;
 using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
-using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::tasks::components::processors::
+    ConfigureTensorsToClassificationCalculator;
 using ::mediapipe::tasks::vision::gesture_recognizer::proto::
     HandGestureRecognizerGraphOptions;
 
@@ -95,15 +94,14 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
 //     The size of image from which the landmarks detected from.
 //
 // Outputs:
-//   HAND_GESTURES - ClassificationResult
+//   HAND_GESTURES - ClassificationList
 //     Recognized hand gestures with sorted order such that the winning label is
 //     the first item in the list.
// // // Example: // node { -// calculator: -// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph" +// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerGraph" // input_stream: "HANDEDNESS:handedness" // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" @@ -136,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { graph[Input(kLandmarksTag)], graph[Input(kWorldLandmarksTag)], graph[Input>(kImageSizeTag)], graph)); - hand_gestures >> graph[Output(kHandGesturesTag)]; + hand_gestures >> graph[Output(kHandGesturesTag)]; return graph.GetConfig(); } private: - absl::StatusOr> BuildGestureRecognizerGraph( + absl::StatusOr> BuildGestureRecognizerGraph( const HandGestureRecognizerGraphOptions& graph_options, const core::ModelResources& model_resources, Source handedness, @@ -201,25 +199,24 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { auto concatenated_tensors = concatenate_tensor_vector.Out(""); // Inference for static hand gesture recognition. + // TODO add embedding step. auto& inference = AddInference( model_resources, graph_options.base_options().acceleration(), graph); concatenated_tensors >> inference.In(kTensorsTag); auto inference_output_tensors = inference.Out(kTensorsTag); - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors." - "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR( - components::processors::ConfigureClassificationPostprocessingGraph( - model_resources, graph_options.classifier_options(), - &postprocessing - .GetOptions())); - inference_output_tensors >> postprocessing.In(kTensorsTag); - auto classification_result = - postprocessing[Output("CLASSIFICATION_RESULT")]; - - return classification_result; + auto& tensors_to_classification = + graph.AddNode("TensorsToClassificationCalculator"); + MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( + graph_options.classifier_options(), + *model_resources.GetMetadataExtractor(), 0, + &tensors_to_classification.GetOptions< + mediapipe::TensorsToClassificationCalculatorOptions>())); + inference_output_tensors >> tensors_to_classification.In(kTensorsTag); + auto classification_list = + tensors_to_classification[Output( + "CLASSIFICATIONS")]; + return classification_list; } }; @@ -247,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH( // index corresponding to the same hand if the graph runs multiple times. // // Outputs: -// HAND_GESTURES - std::vector +// HAND_GESTURES - std::vector // A vector of recognized hand gestures. Each vector element is the -// ClassificationResult of the hand in input vector. +// ClassificationList of the hand in input vector. 
// // // Example: @@ -288,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { graph[Input>(kImageSizeTag)], graph[Input>(kHandTrackingIdsTag)], graph)); multi_hand_gestures >> - graph[Output>(kHandGesturesTag)]; + graph[Output>(kHandGesturesTag)]; return graph.GetConfig(); } private: - absl::StatusOr>> + absl::StatusOr>> BuildMultiGestureRecognizerSubraph( const HandGestureRecognizerGraphOptions& graph_options, Source> multi_handedness, @@ -346,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag); auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag); - auto& end_loop_classification_results = - graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator"); - batch_end >> end_loop_classification_results.In(kBatchEndTag); - hand_gestures >> end_loop_classification_results.In(kItemTag); - auto multi_hand_gestures = end_loop_classification_results - [Output>(kIterableTag)]; + auto& end_loop_classification_lists = + graph.AddNode("EndLoopClassificationListCalculator"); + batch_end >> end_loop_classification_lists.In(kBatchEndTag); + hand_gestures >> end_loop_classification_lists.In(kItemTag); + auto multi_hand_gestures = + end_loop_classification_lists[Output>( + kIterableTag)]; return multi_hand_gestures; } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD index 7b5c65eab..3b73bf2b0 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -31,8 +31,8 @@ mediapipe_proto_library( ) mediapipe_proto_library( - name = "hand_gesture_recognizer_graph_options_proto", - srcs = ["hand_gesture_recognizer_graph_options.proto"], + name = "gesture_classifier_graph_options_proto", + srcs = ["gesture_classifier_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -40,3 +40,28 @@ mediapipe_proto_library( "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) + +mediapipe_proto_library( + name = "hand_gesture_recognizer_graph_options_proto", + srcs = ["hand_gesture_recognizer_graph_options.proto"], + deps = [ + ":gesture_classifier_graph_options_proto", + ":gesture_embedder_graph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "gesture_recognizer_graph_options_proto", + srcs = ["gesture_recognizer_graph_options.proto"], + deps = [ + ":hand_gesture_recognizer_graph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto new file mode 100644 index 000000000..7730f005f --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.gesture_recognizer.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +message GestureClassifierGraphOptions { + extend mediapipe.CalculatorOptions { + optional GestureClassifierGraphOptions ext = 478825465; + } + // Base options for configuring hand gesture recognition subgraph, such as + // specifying the TfLite model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + optional components.processors.proto.ClassifierOptions classifier_options = 2; +} diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto new file mode 100644 index 000000000..2afbd507b --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto @@ -0,0 +1,43 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.gesture_recognizer.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer"; +option java_outer_classname = "GestureRecognizerGraphOptionsProto"; + +message GestureRecognizerGraphOptions { + extend mediapipe.CalculatorOptions { + optional GestureRecognizerGraphOptions ext = 479097054; + } + // Base options for configuring gesture recognizer graph, such as specifying + // the TfLite model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring hand landmarker graph. + optional hand_landmarker.proto.HandLandmarkerGraphOptions + hand_landmarker_graph_options = 2; + + // Options for configuring hand gesture recognizer graph. 
+ optional HandGestureRecognizerGraphOptions + hand_gesture_recognizer_graph_options = 3; +} diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto index ac8cda15c..f71a6b22f 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto @@ -20,6 +20,8 @@ package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto"; +import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto"; message HandGestureRecognizerGraphOptions { extend mediapipe.CalculatorOptions { @@ -29,11 +31,18 @@ message HandGestureRecognizerGraphOptions { // specifying the TfLite model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // Options for configuring the gesture classifier behavior, such as score - // threshold, number of results, etc. - optional components.processors.proto.ClassifierOptions classifier_options = 2; + // Options for GestureEmbedder. + optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2; - // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be - // considered tracked successfully - optional float min_tracking_confidence = 3 [default = 0.0]; + // Options for GestureClassifier of default gestures. + optional GestureClassifierGraphOptions + canned_gesture_classifier_graph_options = 3; + + // Options for GestureClassifier of custom gestures. + optional GestureClassifierGraphOptions + custom_gesture_classifier_graph_options = 4; + + // TODO: remove these. Temporary solutions before bundle asset is + // ready. 
+ optional components.processors.proto.ClassifierOptions classifier_options = 5; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index a2bb458db..e8a832bbc 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -80,6 +80,7 @@ cc_library( "//mediapipe/calculators/core:gate_calculator_cc_proto", "//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", "//mediapipe/framework/api2:builder", @@ -98,6 +99,7 @@ cc_library( "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", ], diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD index dea81bae3..3b82153eb 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -15,7 +15,6 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") package(default_visibility = [ - "//mediapipe/app/xeno:__subpackages__", "//mediapipe/tasks:internal", ]) @@ -46,4 +45,26 @@ cc_library( alwayslink = 1, ) -# TODO: Enable this test +cc_library( + name = "hand_landmarks_deduplication_calculator", + srcs = ["hand_landmarks_deduplication_calculator.cc"], + hdrs = ["hand_landmarks_deduplication_calculator.h"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:landmarks_detection", + "//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder", + "//mediapipe/tasks/cc/vision/utils:landmarks_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc new file mode 100644 index 000000000..8920ea0cb --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -0,0 +1,310 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h" +#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h" +#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" + +namespace mediapipe::api2 { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::Bound; +using ::mediapipe::tasks::vision::utils::CalculateIOU; +using ::mediapipe::tasks::vision::utils::DuplicatesFinder; + +float Distance(const NormalizedLandmark& lm_a, const NormalizedLandmark& lm_b, + int width, int height) { + return std::sqrt(std::pow((lm_a.x() - lm_b.x()) * width, 2) + + std::pow((lm_a.y() - lm_b.y()) * height, 2)); +} + +absl::StatusOr> Distances(const NormalizedLandmarkList& a, + const NormalizedLandmarkList& b, + int width, int height) { + const int num = a.landmark_size(); + RET_CHECK_EQ(b.landmark_size(), num); + std::vector distances; + distances.reserve(num); + for (int i = 0; i < num; ++i) { + const NormalizedLandmark& lm_a = a.landmark(i); + const NormalizedLandmark& lm_b = b.landmark(i); + distances.push_back(Distance(lm_a, lm_b, width, height)); + } + return distances; +} + +// Calculates a baseline distance of a hand that can be used as a relative +// measure when calculating hand to hand similarity. +// +// Calculated as maximum of distances: 0->5, 5->17, 17->0, where 0, 5, 17 key +// points are depicted below: +// +// /Middle/ +// | +// /Index/ | /Ring/ +// | | | /Pinky/ +// V V V | +// V +// [8] [12] [16] +// | | | [20] +// | | | | +// /Thumb/ | | | | +// | [7] [11] [15] [19] +// V | | | | +// | | | | +// [4] | | | | +// | [6] [10] [14] [18] +// | | | | | +// | | | | | +// [3] | | | | +// | [5]----[9]---[13]---[17] +// . | | +// \ . | +// \ / | +// [2] | +// \ | +// \ | +// \ | +// [1] . +// \ / +// \ / +// ._____[0]_____. +// +// ^ +// | +// /Wrist/ +absl::StatusOr HandBaselineDistance( + const NormalizedLandmarkList& landmarks, int width, int height) { + RET_CHECK_EQ(landmarks.landmark_size(), 21); // Num of hand landmarks. 
+ constexpr int kWrist = 0; + constexpr int kIndexFingerMcp = 5; + constexpr int kPinkyMcp = 17; + float distance = Distance(landmarks.landmark(kWrist), + landmarks.landmark(kIndexFingerMcp), width, height); + distance = std::max(distance, + Distance(landmarks.landmark(kIndexFingerMcp), + landmarks.landmark(kPinkyMcp), width, height)); + distance = + std::max(distance, Distance(landmarks.landmark(kPinkyMcp), + landmarks.landmark(kWrist), width, height)); + return distance; +} + +Bound CalculateBound(const NormalizedLandmarkList& list) { + constexpr float kMinInitialValue = std::numeric_limits::max(); + constexpr float kMaxInitialValue = std::numeric_limits::lowest(); + + // Compute min and max values on landmarks (they will form + // bounding box) + float bounding_box_left = kMinInitialValue; + float bounding_box_top = kMinInitialValue; + float bounding_box_right = kMaxInitialValue; + float bounding_box_bottom = kMaxInitialValue; + for (const auto& landmark : list.landmark()) { + bounding_box_left = std::min(bounding_box_left, landmark.x()); + bounding_box_top = std::min(bounding_box_top, landmark.y()); + bounding_box_right = std::max(bounding_box_right, landmark.x()); + bounding_box_bottom = std::max(bounding_box_bottom, landmark.y()); + } + + // Populate normalized non rotated face bounding box + return {.left = bounding_box_left, + .top = bounding_box_top, + .right = bounding_box_right, + .bottom = bounding_box_bottom}; +} + +// Uses IoU and distance of some corresponding hand landmarks to detect +// duplicate / similar hands. IoU, distance thresholds, number of landmarks to +// match are found experimentally. Evaluated: +// - manually comparing side by side, before and after deduplication applied +// - generating gesture dataset, and checking select frames in baseline and +// "deduplicated" dataset +// - by confirming gesture training is better with use of deduplication using +// selected thresholds +class HandDuplicatesFinder : public DuplicatesFinder { + public: + explicit HandDuplicatesFinder(bool start_from_the_end) + : start_from_the_end_(start_from_the_end) {} + + absl::StatusOr> FindDuplicates( + const std::vector& multi_landmarks, + int input_width, int input_height) override { + absl::flat_hash_set retained_indices; + absl::flat_hash_set suppressed_indices; + + const int num = multi_landmarks.size(); + std::vector baseline_distances; + baseline_distances.reserve(num); + std::vector bounds; + bounds.reserve(num); + for (const NormalizedLandmarkList& list : multi_landmarks) { + ASSIGN_OR_RETURN(const float baseline_distance, + HandBaselineDistance(list, input_width, input_height)); + baseline_distances.push_back(baseline_distance); + bounds.push_back(CalculateBound(list)); + } + + for (int index = 0; index < num; ++index) { + const int i = start_from_the_end_ ? 
num - index - 1 : index; + const float stable_distance_i = baseline_distances[i]; + bool suppressed = false; + for (int j : retained_indices) { + const float stable_distance_j = baseline_distances[j]; + + constexpr float kAllowedBaselineDistanceRatio = 0.2f; + const float distance_threshold = + std::max(stable_distance_i, stable_distance_j) * + kAllowedBaselineDistanceRatio; + + ASSIGN_OR_RETURN(const std::vector distances, + Distances(multi_landmarks[i], multi_landmarks[j], + input_width, input_height)); + const int num_matched_landmarks = absl::c_count_if( + distances, + [&](float distance) { return distance < distance_threshold; }); + + const float iou = CalculateIOU(bounds[i], bounds[j]); + + constexpr int kNumMatchedLandmarksToSuppressHand = 10; // out of 21 + constexpr float kMinIouThresholdToSuppressHand = 0.2f; + if (num_matched_landmarks >= kNumMatchedLandmarksToSuppressHand && + iou > kMinIouThresholdToSuppressHand) { + suppressed = true; + break; + } + } + + if (suppressed) { + suppressed_indices.insert(i); + } else { + retained_indices.insert(i); + } + } + return suppressed_indices; + } + + private: + const bool start_from_the_end_; +}; + +template +absl::StatusOr> +VerifyNumAndMaybeInitOutput(const InputPortT& port, CalculatorContext* cc, + int num_expected_size) { + absl::optional output; + if (port(cc).IsConnected() && !port(cc).IsEmpty()) { + RET_CHECK_EQ(port(cc).Get().size(), num_expected_size); + typename InputPortT::PayloadT result; + return {{result}}; + } + return {absl::nullopt}; +} +} // namespace + +std::unique_ptr CreateHandDuplicatesFinder( + bool start_from_the_end) { + return absl::make_unique(start_from_the_end); +} + +absl::Status HandLandmarksDeduplicationCalculator::Process( + mediapipe::CalculatorContext* cc) { + if (kInLandmarks(cc).IsEmpty()) return absl::OkStatus(); + if (kInSize(cc).IsEmpty()) return absl::OkStatus(); + + const std::vector& in_landmarks = *kInLandmarks(cc); + const std::pair& image_size = *kInSize(cc); + + std::unique_ptr duplicates_finder = + CreateHandDuplicatesFinder(/*start_from_the_end=*/false); + ASSIGN_OR_RETURN(absl::flat_hash_set indices_to_remove, + duplicates_finder->FindDuplicates( + in_landmarks, image_size.first, image_size.second)); + + if (indices_to_remove.empty()) { + kOutLandmarks(cc).Send(kInLandmarks(cc)); + kOutRois(cc).Send(kInRois(cc)); + kOutWorldLandmarks(cc).Send(kInWorldLandmarks(cc)); + kOutClassifications(cc).Send(kInClassifications(cc)); + } else { + std::vector out_landmarks; + const int num = in_landmarks.size(); + + ASSIGN_OR_RETURN(absl::optional> out_rois, + VerifyNumAndMaybeInitOutput(kInRois, cc, num)); + ASSIGN_OR_RETURN( + absl::optional> out_world_landmarks, + VerifyNumAndMaybeInitOutput(kInWorldLandmarks, cc, num)); + ASSIGN_OR_RETURN( + absl::optional> out_classifications, + VerifyNumAndMaybeInitOutput(kInClassifications, cc, num)); + + for (int i = 0; i < num; ++i) { + if (indices_to_remove.find(i) != indices_to_remove.end()) continue; + + out_landmarks.push_back(in_landmarks[i]); + if (out_rois) { + out_rois->push_back(kInRois(cc).Get()[i]); + } + if (out_world_landmarks) { + out_world_landmarks->push_back(kInWorldLandmarks(cc).Get()[i]); + } + if (out_classifications) { + out_classifications->push_back(kInClassifications(cc).Get()[i]); + } + } + + if (!out_landmarks.empty()) { + kOutLandmarks(cc).Send(std::move(out_landmarks)); + } + if (out_rois && !out_rois->empty()) { + kOutRois(cc).Send(std::move(out_rois.value())); + } + if (out_world_landmarks && !out_world_landmarks->empty()) { + 
kOutWorldLandmarks(cc).Send(std::move(out_world_landmarks.value())); + } + if (out_classifications && !out_classifications->empty()) { + kOutClassifications(cc).Send(std::move(out_classifications.value())); + } + } + return absl::OkStatus(); +} +MEDIAPIPE_REGISTER_NODE(HandLandmarksDeduplicationCalculator); + +} // namespace mediapipe::api2 diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h new file mode 100644 index 000000000..d7b435487 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h @@ -0,0 +1,97 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_ + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h" + +namespace mediapipe::api2 { + +// Create a DuplicatesFinder dedicated for finding hand duplications. +std::unique_ptr +CreateHandDuplicatesFinder(bool start_from_the_end = false); + +// Filter duplicate hand landmarks by finding the overlapped hands. +// Inputs: +// MULTI_LANDMARKS - std::vector +// The hand landmarks to be filtered. +// MULTI_ROIS - std::vector +// The regions where each encloses the landmarks of a single hand. +// MULTI_WORLD_LANDMARKS - std::vector +// The hand landmarks to be filtered in world coordinates. +// MULTI_CLASSIFICATIONS - std::vector +// The handedness of hands. +// IMAGE_SIZE - std::pair +// The size of the image which the hand landmarks are detected on. +// +// Outputs: +// MULTI_LANDMARKS - std::vector +// The hand landmarks with duplication removed. +// MULTI_ROIS - std::vector +// The regions where each encloses the landmarks of a single hand with +// duplicate hands removed. +// MULTI_WORLD_LANDMARKS - std::vector +// The hand landmarks with duplication removed in world coordinates. +// MULTI_CLASSIFICATIONS - std::vector +// The handedness of hands with duplicate hands removed. 
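//
// For orientation, the same wiring can be expressed with the api2 builder,
// roughly as in the sketch below (stream names as documented above; the
// node config example that follows is the canonical form):
//
//   Graph graph;
//   auto& dedup = graph.AddNode("HandLandmarksDeduplicationCalculator");
//   landmarks >> dedup.In("MULTI_LANDMARKS");
//   rois >> dedup.In("MULTI_ROIS");
//   world_landmarks >> dedup.In("MULTI_WORLD_LANDMARKS");
//   handedness >> dedup.In("MULTI_CLASSIFICATIONS");
//   image_size >> dedup.In("IMAGE_SIZE");
//   auto filtered_landmarks = dedup.Out("MULTI_LANDMARKS");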
+// +// Example: +// node { +// calculator: "HandLandmarksDeduplicationCalculator" +// input_stream: "MULTI_LANDMARKS:landmarks_in" +// input_stream: "MULTI_ROIS:rois_in" +// input_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_in" +// input_stream: "MULTI_CLASSIFICATIONS:handedness_in" +// input_stream: "IMAGE_SIZE:image_size" +// output_stream: "MULTI_LANDMARKS:landmarks_out" +// output_stream: "MULTI_ROIS:rois_out" +// output_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_out" +// output_stream: "MULTI_CLASSIFICATIONS:handedness_out" +// } +class HandLandmarksDeduplicationCalculator : public Node { + public: + constexpr static Input> + kInLandmarks{"MULTI_LANDMARKS"}; + constexpr static Input>::Optional + kInRois{"MULTI_ROIS"}; + constexpr static Input>::Optional + kInWorldLandmarks{"MULTI_WORLD_LANDMARKS"}; + constexpr static Input>::Optional + kInClassifications{"MULTI_CLASSIFICATIONS"}; + constexpr static Input> kInSize{"IMAGE_SIZE"}; + + constexpr static Output> + kOutLandmarks{"MULTI_LANDMARKS"}; + constexpr static Output>::Optional + kOutRois{"MULTI_ROIS"}; + constexpr static Output>::Optional + kOutWorldLandmarks{"MULTI_WORLD_LANDMARKS"}; + constexpr static Output>::Optional + kOutClassifications{"MULTI_CLASSIFICATIONS"}; + MEDIAPIPE_NODE_CONTRACT(kInLandmarks, kInRois, kInWorldLandmarks, + kInClassifications, kInSize, kOutLandmarks, kOutRois, + kOutWorldLandmarks, kOutClassifications); + absl::Status Process(mediapipe::CalculatorContext* cc) override; +}; + +} // namespace mediapipe::api2 + +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 949c06520..ab5a453c5 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -247,11 +247,37 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { image_in >> hand_landmarks_detector_graph.In("IMAGE"); clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT"); + auto landmarks = hand_landmarks_detector_graph.Out(kLandmarksTag); + auto world_landmarks = + hand_landmarks_detector_graph.Out(kWorldLandmarksTag); auto hand_rects_for_next_frame = - hand_landmarks_detector_graph[Output>( - kHandRectNextFrameTag)]; + hand_landmarks_detector_graph.Out(kHandRectNextFrameTag); + auto handedness = hand_landmarks_detector_graph.Out(kHandednessTag); + + auto& image_property = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_property.In("IMAGE"); + auto image_size = image_property.Out("SIZE"); + + auto& deduplicate = graph.AddNode("HandLandmarksDeduplicationCalculator"); + landmarks >> deduplicate.In("MULTI_LANDMARKS"); + world_landmarks >> deduplicate.In("MULTI_WORLD_LANDMARKS"); + hand_rects_for_next_frame >> deduplicate.In("MULTI_ROIS"); + handedness >> deduplicate.In("MULTI_CLASSIFICATIONS"); + image_size >> deduplicate.In("IMAGE_SIZE"); + + auto filtered_landmarks = + deduplicate[Output>( + "MULTI_LANDMARKS")]; + auto filtered_world_landmarks = + deduplicate[Output>("MULTI_WORLD_LANDMARKS")]; + auto filtered_hand_rects_for_next_frame = + deduplicate[Output>("MULTI_ROIS")]; + auto filtered_handedness = + deduplicate[Output>( + "MULTI_CLASSIFICATIONS")]; + // Back edge. 
- hand_rects_for_next_frame >> previous_loopback.In("LOOP"); + filtered_hand_rects_for_next_frame >> previous_loopback.In("LOOP"); // TODO: Replace PassThroughCalculator with a calculator that // converts the pixel data to be stored on the target storage (CPU vs GPU). @@ -259,14 +285,10 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { image_in >> pass_through.In(""); return {{ - /* landmark_lists= */ hand_landmarks_detector_graph - [Output>(kLandmarksTag)], - /* world_landmark_lists= */ - hand_landmarks_detector_graph[Output>( - kWorldLandmarksTag)], - /* hand_rects_next_frame= */ hand_rects_for_next_frame, - hand_landmarks_detector_graph[Output>( - kHandednessTag)], + /* landmark_lists= */ filtered_landmarks, + /* world_landmark_lists= */ filtered_world_landmarks, + /* hand_rects_next_frame= */ filtered_hand_rects_for_next_frame, + /* handedness= */ filtered_handedness, /* palm_rects= */ hand_detector[Output>(kPalmRectsTag)], /* palm_detections */ diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 070a5a034..dcb2fddfc 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -208,7 +208,7 @@ TEST_F(CreateTest, FailsWithMissingModel) { EXPECT_THAT( image_classifier.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index b307a66b6..76315e230 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto"; +option java_outer_classname = "ImageClassifierGraphOptionsProto"; + message ImageClassifierGraphOptions { extend mediapipe.CalculatorOptions { optional ImageClassifierGraphOptions ext = 456383383; diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 08a0d6a25..db1019b33 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -140,7 +140,7 @@ TEST_F(CreateTest, FailsWithMissingModel) { EXPECT_THAT( image_embedder.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc 
b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 1d3f3e786..ab23a725c 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -191,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { EXPECT_THAT( segmenter_or.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 463c92566..bcc4c95ee 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -208,7 +208,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { EXPECT_THAT( object_detector.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(object_detector.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); diff --git a/mediapipe/tasks/cc/vision/utils/BUILD b/mediapipe/tasks/cc/vision/utils/BUILD index 3e5cfd2e9..c796798df 100644 --- a/mediapipe/tasks/cc/vision/utils/BUILD +++ b/mediapipe/tasks/cc/vision/utils/BUILD @@ -79,3 +79,30 @@ cc_library( "@stblib//:stb_image", ], ) + +cc_library( + name = "landmarks_duplicates_finder", + hdrs = ["landmarks_duplicates_finder.h"], + deps = [ + "//mediapipe/framework/formats:landmark_cc_proto", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "landmarks_utils", + srcs = ["landmarks_utils.cc"], + hdrs = ["landmarks_utils.h"], + deps = ["//mediapipe/tasks/cc/components/containers:landmarks_detection"], +) + +cc_test( + name = "landmarks_utils_test", + srcs = ["landmarks_utils_test.cc"], + deps = [ + ":landmarks_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/containers:landmarks_detection", + ], +) diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h b/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h new file mode 100644 index 000000000..e1632e6f0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe::tasks::vision::utils { + +class DuplicatesFinder { + public: + virtual ~DuplicatesFinder() = default; + // Returns indices of landmark lists to remove to make @multi_landmarks + // contain different enough (depending on the implementation) landmark lists + // only. + virtual absl::StatusOr> FindDuplicates( + const std::vector& multi_landmarks, + int input_width, int input_height) = 0; +}; + +} // namespace mediapipe::tasks::vision::utils + +#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_ diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc new file mode 100644 index 000000000..5ec898f15 --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -0,0 +1,48 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" + +#include +#include + +namespace mediapipe::tasks::vision::utils { + +using ::mediapipe::tasks::components::containers::Bound; + +float CalculateArea(const Bound& bound) { + return (bound.right - bound.left) * (bound.bottom - bound.top); +} + +float CalculateIntersectionArea(const Bound& a, const Bound& b) { + const float intersection_left = std::max(a.left, b.left); + const float intersection_top = std::max(a.top, b.top); + const float intersection_right = std::min(a.right, b.right); + const float intersection_bottom = std::min(a.bottom, b.bottom); + + return std::max(intersection_bottom - intersection_top, 0.0) * + std::max(intersection_right - intersection_left, 0.0); +} + +float CalculateIOU(const Bound& a, const Bound& b) { + const float area_a = CalculateArea(a); + const float area_b = CalculateArea(b); + if (area_a <= 0 || area_b <= 0) return 0.0; + + const float intersection_area = CalculateIntersectionArea(a, b); + return intersection_area / (area_a + area_b - intersection_area); +} + +} // namespace mediapipe::tasks::vision::utils diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h new file mode 100644 index 000000000..b42eae0b6 --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h" + +namespace mediapipe::tasks::vision::utils { + +// Calculates intersection over union for two bounds. +float CalculateIOU(const components::containers::Bound& a, + const components::containers::Bound& b); + +// Calculates the area of a bound. +float CalculateArea(const components::containers::Bound& bound); + +// Calculates the intersection area of two bounds. +float CalculateIntersectionArea(const components::containers::Bound& a, + const components::containers::Bound& b); +} // namespace mediapipe::tasks::vision::utils + +#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc new file mode 100644 index 000000000..c30a5225b --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe::tasks::vision::utils { +namespace { + +TEST(LandmarkUtilsTest, CalculateIOU) { + // Do not intersect + EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 2, 3, 3})); + // No x intersection + EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 0, 3, 1})); + // No y intersection + EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {0, 2, 1, 3})); + // Full intersection + EXPECT_EQ(1, CalculateIOU({0, 0, 2, 2}, {0, 0, 2, 2})); + + // Union is 4 intersection is 1 + EXPECT_EQ(0.25, CalculateIOU({0, 0, 3, 1}, {2, 0, 4, 1})); + + // Same in by y + EXPECT_EQ(0.25, CalculateIOU({0, 0, 1, 3}, {0, 2, 1, 4})); +} +} // namespace +} // namespace mediapipe::tasks::vision::utils diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD index 8f3c1539c..610bec911 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/BUILD @@ -34,3 +34,32 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +android_library( + name = "classification_entry", + srcs = ["ClassificationEntry.java"], + deps = [ + ":category", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "classifications", + srcs = ["Classifications.java"], + deps = [ + ":classification_entry", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) + +android_library( + name = "landmark", + srcs = ["Landmark.java"], + deps = [ + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java new file mode 100644 index 000000000..8fc1daa03 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/ClassificationEntry.java @@ -0,0 +1,48 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Collections; +import java.util.List; + +/** + * Represents a list of predicted categories with an optional timestamp. Typically used as result + * for classification tasks. + */ +@AutoValue +public abstract class ClassificationEntry { + /** + * Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional + * timestamp. + * + * @param categories the list of {@link Category} objects that contain category name, display + * name, score and label index. 
+ * @param timestampMs the {@link long} representing the timestamp for which these categories were + * obtained. + */ + public static ClassificationEntry create(List categories, long timestampMs) { + return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs); + } + + /** The list of predicted {@link Category} objects, sorted by descending score. */ + public abstract List categories(); + + /** + * The timestamp (in milliseconds) associated to the classification entry. This is useful for time + * series use cases, e.g. audio classification. + */ + public abstract long timestampMs(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java new file mode 100644 index 000000000..726578729 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Classifications.java @@ -0,0 +1,52 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Collections; +import java.util.List; + +/** + * Represents the list of classification for a given classifier head. Typically used as a result for + * classification tasks. + */ +@AutoValue +public abstract class Classifications { + + /** + * Creates a {@link Classifications} instance. + * + * @param entries the list of {@link ClassificationEntry} objects containing the predicted + * categories. + * @param headIndex the index of the classifier head. + * @param headName the name of the classifier head. + */ + public static Classifications create( + List entries, int headIndex, String headName) { + return new AutoValue_Classifications( + Collections.unmodifiableList(entries), headIndex, headName); + } + + /** A list of {@link ClassificationEntry} objects. */ + public abstract List entries(); + + /** + * The index of the classifier head these entries refer to. This is useful for multi-head models. + */ + public abstract int headIndex(); + + /** The name of the classifier head, which is the corresponding tensor metadata name. */ + public abstract String headName(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java new file mode 100644 index 000000000..3f96d7779 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -0,0 +1,66 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.containers; + +import com.google.auto.value.AutoValue; +import java.util.Objects; + +/** + * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the + * landmark coordinates are normalized with respect to the image dimensions, and the coordinate + * values are in the range of [0, 1]. Otherwise, it represents a point in world coordinates. + */ +@AutoValue +public abstract class Landmark { + private static final float TOLERANCE = 1e-6f; + + public static Landmark create(float x, float y, float z, boolean normalized) { + return new AutoValue_Landmark(x, y, z, normalized); + } + + // The x coordinate of the landmark. + public abstract float x(); + + // The y coordinate of the landmark. + public abstract float y(); + + // The z coordinate of the landmark. + public abstract float z(); + + // Whether this landmark is normalized with respect to the image size. + public abstract boolean normalized(); + + @Override + public final boolean equals(Object o) { + if (!(o instanceof Landmark)) { + return false; + } + Landmark other = (Landmark) o; + return other.normalized() == this.normalized() + && Math.abs(other.x() - this.x()) < TOLERANCE + && Math.abs(other.y() - this.y()) < TOLERANCE + && Math.abs(other.z() - this.z()) < TOLERANCE; + } + + @Override + public final int hashCode() { + return Objects.hash(x(), y(), z(), normalized()); + } + + @Override + public final String toString() { + return ""; + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD new file mode 100644 index 000000000..88516d806 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
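As an illustrative sketch only (not part of the patch), the new container classes above compose as follows; the Category.create signature is inferred from its usage in GestureRecognitionResult later in this change, and all values are placeholders.

import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import java.util.Arrays;

class ClassificationsExample {
  static Classifications buildExample() {
    // Category.create(score, index, categoryName, displayName), as used elsewhere in this change.
    Category category = Category.create(0.92f, 0, "pointing_up", "Pointing Up");
    // One entry holding the categories predicted at a given timestamp (0 ms for a single image).
    ClassificationEntry entry = ClassificationEntry.create(Arrays.asList(category), 0);
    // Wrap the entries with the classifier head they came from (head index and tensor metadata name).
    return Classifications.create(Arrays.asList(entry), 0, "probability");
  }
}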
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "classifieroptions", + srcs = ["ClassifierOptions.java"], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + deps = [ + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/ClassifierOptions.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/ClassifierOptions.java new file mode 100644 index 000000000..76da4b446 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors/ClassifierOptions.java @@ -0,0 +1,118 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.components.processors; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** Classifier options shared across MediaPipe Java classification tasks. */ +@AutoValue +public abstract class ClassifierOptions { + + /** Builder for {@link ClassifierOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. + */ + public abstract Builder setDisplayNamesLocale(String locale); + + /** + * Sets the optional maximum number of top-scored classification results to return. + * + *

If not set, all available results are returned. If set, must be > 0. + */ + public abstract Builder setMaxResults(Integer maxResults); + + /** + * Sets the optional score threshold. Results with score below this value are rejected. + * + *

Overrides the score threshold specified in the TFLite Model Metadata, if any. + */ + public abstract Builder setScoreThreshold(Float scoreThreshold); + + /** + * Sets the optional allowlist of category names. + * + *

If non-empty, detection results whose category name is not in this set will be filtered + * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryDenylist}. + */ + public abstract Builder setCategoryAllowlist(List categoryAllowlist); + + /** + * Sets the optional denylist of category names. + * + *

If non-empty, detection results whose category name is in this set will be filtered out. + * Duplicate or unknown category names are ignored. Mutually exclusive with {@code + * categoryAllowlist}. + */ + public abstract Builder setCategoryDenylist(List categoryDenylist); + + abstract ClassifierOptions autoBuild(); + + /** + * Validates and builds the {@link ClassifierOptions} instance. + * + * @throws IllegalArgumentException if {@link maxResults} is set to a value <= 0. + */ + public final ClassifierOptions build() { + ClassifierOptions options = autoBuild(); + if (options.maxResults().isPresent() && options.maxResults().get() <= 0) { + throw new IllegalArgumentException("If specified, maxResults must be > 0"); + } + return options; + } + } + + public abstract Optional displayNamesLocale(); + + public abstract Optional maxResults(); + + public abstract Optional scoreThreshold(); + + public abstract List categoryAllowlist(); + + public abstract List categoryDenylist(); + + public static Builder builder() { + return new AutoValue_ClassifierOptions.Builder() + .setCategoryAllowlist(Collections.emptyList()) + .setCategoryDenylist(Collections.emptyList()); + } + + /** + * Converts a {@link ClassifierOptions} object to a {@link + * ClassifierOptionsProto.ClassifierOptions} protobuf message. + */ + public ClassifierOptionsProto.ClassifierOptions convertToProto() { + ClassifierOptionsProto.ClassifierOptions.Builder builder = + ClassifierOptionsProto.ClassifierOptions.newBuilder(); + displayNamesLocale().ifPresent(builder::setDisplayNamesLocale); + maxResults().ifPresent(builder::setMaxResults); + scoreThreshold().ifPresent(builder::setScoreThreshold); + if (!categoryAllowlist().isEmpty()) { + builder.addAllCategoryAllowlist(categoryAllowlist()); + } + if (!categoryDenylist().isEmpty()) { + builder.addAllCategoryDenylist(categoryDenylist()); + } + return builder.build(); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD index 94f77ea68..8df9173b2 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BUILD @@ -19,8 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) android_library( name = "core", srcs = glob(["*.java"]), + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], deps = [ ":libmediapipe_tasks_vision_jni_lib", + "//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", @@ -36,6 +40,7 @@ cc_binary( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java index 92f64e898..7ab8e75a1 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core/BaseVisionTaskApi.java @@ -14,101 +14,247 @@ package 
com.google.mediapipe.tasks.vision.core; +import android.graphics.RectF; +import com.google.mediapipe.formats.proto.RectProto.NormalizedRect; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.ProtoUtil; import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskRunner; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** The base class of MediaPipe vision tasks. */ public class BaseVisionTaskApi implements AutoCloseable { private static final long MICROSECONDS_PER_MILLISECOND = 1000; private final TaskRunner runner; private final RunningMode runningMode; + private final String imageStreamName; + private final Optional normRectStreamName; static { System.loadLibrary("mediapipe_tasks_vision_jni"); + ProtoUtil.registerTypeName(NormalizedRect.class, "mediapipe.NormalizedRect"); } /** - * Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision - * task {@link RunningMode}. + * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input. * * @param runner a {@link TaskRunner}. * @param runningMode a mediapipe vision task {@link RunningMode}. + * @param imageStreamName the name of the input image stream. */ - public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) { + public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) { this.runner = runner; this.runningMode = runningMode; + this.imageStreamName = imageStreamName; + this.normRectStreamName = Optional.empty(); + } + + /** + * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as + * input. + * + * @param runner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + * @param imageStreamName the name of the input image stream. + * @param normRectStreamName the name of the input normalized rect image stream. + */ + public BaseVisionTaskApi( + TaskRunner runner, + RunningMode runningMode, + String imageStreamName, + String normRectStreamName) { + this.runner = runner; + this.runningMode = runningMode; + this.imageStreamName = imageStreamName; + this.normRectStreamName = Optional.of(normRectStreamName); } /** * A synchronous method to process single image inputs. The call blocks the current thread until a * failure status or a successful result is returned. * - * @param imageStreamName the image input stream name. * @param image a MediaPipe {@link Image} object for processing. - * @throws MediaPipeException if the task is not in the image mode. + * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect + * input. */ - protected TaskResult processImageData(String imageStreamName, Image image) { + protected TaskResult processImageData(Image image) { if (runningMode != RunningMode.IMAGE) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the image mode. 
Current running mode:" + runningMode.name()); } + if (normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task expects a normalized rect as input."); + } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); return runner.process(inputPackets); } + /** + * A synchronous method to process single image inputs. The call blocks the current thread until a + * failure status or a successful result is returned. + * + * @param image a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates + * are expected to be specified as normalized values in [0,1]. + * @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized + * rect. + */ + protected TaskResult processImageData(Image image, RectF roi) { + if (runningMode != RunningMode.IMAGE) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the image mode. Current running mode:" + + runningMode.name()); + } + if (!normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task doesn't expect a normalized rect as input."); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + normRectStreamName.get(), + runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + return runner.process(inputPackets); + } + /** * A synchronous method to process continuous video frames. The call blocks the current thread * until a failure status or a successful result is returned. * - * @param imageStreamName the image input stream name. * @param image a MediaPipe {@link Image} object for processing. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode. + * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect + * input. */ - protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) { + protected TaskResult processVideoData(Image image, long timestampMs) { if (runningMode != RunningMode.VIDEO) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the video mode. Current running mode:" + runningMode.name()); } + if (normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task expects a normalized rect as input."); + } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } + /** + * A synchronous method to process continuous video frames. The call blocks the current thread + * until a failure status or a successful result is returned. + * + * @param image a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates + * are expected to be specified as normalized values in [0,1]. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. 
+ * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized + * rect. + */ + protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) { + if (runningMode != RunningMode.VIDEO) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the video mode. Current running mode:" + + runningMode.name()); + } + if (!normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task doesn't expect a normalized rect as input."); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + normRectStreamName.get(), + runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + /** * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * available in the user-defined result listener. * - * @param imageStreamName the image input stream name. * @param image a MediaPipe {@link Image} object for processing. * @param timestampMs the corresponding timestamp of the input image in milliseconds. - * @throws MediaPipeException if the task is not in the video mode. + * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect + * input. */ - protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) { + protected void sendLiveStreamData(Image image, long timestampMs) { if (runningMode != RunningMode.LIVE_STREAM) { throw new MediaPipeException( MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), "Task is not initialized with the live stream mode. Current running mode:" + runningMode.name()); } + if (normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task expects a normalized rect as input."); + } Map inputPackets = new HashMap<>(); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); } + /** + * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be + * available in the user-defined result listener. + * + * @param image a MediaPipe {@link Image} object for processing. + * @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates + * are expected to be specified as normalized values in [0,1]. + * @param timestampMs the corresponding timestamp of the input image in milliseconds. + * @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized + * rect. + */ + protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) { + if (runningMode != RunningMode.LIVE_STREAM) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the live stream mode. 
Current running mode:" + + runningMode.name()); + } + if (!normRectStreamName.isPresent()) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task doesn't expect a normalized rect as input."); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); + inputPackets.put( + normRectStreamName.get(), + runner.getPacketCreator().createProto(convertToNormalizedRect(roi))); + runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); + } + /** Closes and cleans up the MediaPipe vision task. */ @Override public void close() { runner.close(); } + + /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */ + private static NormalizedRect convertToNormalizedRect(RectF rect) { + return NormalizedRect.newBuilder() + .setXCenter(rect.centerX()) + .setYCenter(rect.centerY()) + .setWidth(rect.width()) + .setHeight(rect.height()) + .build(); + } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml new file mode 100644 index 000000000..38f98f1a1 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD new file mode 100644 index 000000000..eb3eca52b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
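For illustration only, a hypothetical task class built on the BaseVisionTaskApi overloads above could feed a region of interest like this; the subclass name and stream names are placeholders, not taken from any real task.

import android.graphics.RectF;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;

class RoiTaskExample extends BaseVisionTaskApi {
  RoiTaskExample(TaskRunner runner) {
    // The four-argument constructor registers the NormalizedRect input stream.
    super(runner, RunningMode.IMAGE, "image_in", "norm_rect_in");
  }

  TaskResult processCenterCrop(Image image) {
    // Normalized [0, 1] coordinates; converted internally to a NormalizedRect
    // (x_center, y_center, width, height) before being sent to the graph.
    RectF roi = new RectF(0.25f, 0.25f, 0.75f, 0.75f);
    return processImageData(image, roi);
  }
}

The same pattern applies to processVideoData and sendLiveStreamData, which add a timestamp argument on top of the image and region of interest.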
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +android_library( + name = "gesturerecognizer", + srcs = [ + "GestureRecognitionResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = ":AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java new file mode 100644 index 000000000..fd764cb18 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognitionResult.java @@ -0,0 +1,128 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.gesturerecognizer; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.Landmark; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.formats.proto.ClassificationProto.Classification; +import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Represents the gesture recognition results generated by {@link GestureRecognizer}. */ +@AutoValue +public abstract class GestureRecognitionResult implements TaskResult { + + /** + * Creates a {@link GestureRecognitionResult} instance from the lists of landmarks, handedness, + * and gestures protobuf messages. 
+ * + * @param landmarksProto a List of {@link NormalizedLandmarkList} + * @param worldLandmarksProto a List of {@link LandmarkList} + * @param handednessesProto a List of {@link ClassificationList} + * @param gesturesProto a List of {@link ClassificationList} + */ + static GestureRecognitionResult create( + List landmarksProto, + List worldLandmarksProto, + List handednessesProto, + List gesturesProto, + long timestampMs) { + List> multiHandLandmarks = + new ArrayList<>(); + List> multiHandWorldLandmarks = + new ArrayList<>(); + List> multiHandHandednesses = new ArrayList<>(); + List> multiHandGestures = new ArrayList<>(); + for (NormalizedLandmarkList handLandmarksProto : landmarksProto) { + List handLandmarks = + new ArrayList<>(); + multiHandLandmarks.add(handLandmarks); + for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) { + handLandmarks.add( + com.google.mediapipe.tasks.components.containers.Landmark.create( + handLandmarkProto.getX(), + handLandmarkProto.getY(), + handLandmarkProto.getZ(), + true)); + } + } + for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) { + List handWorldLandmarks = + new ArrayList<>(); + multiHandWorldLandmarks.add(handWorldLandmarks); + for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) { + handWorldLandmarks.add( + com.google.mediapipe.tasks.components.containers.Landmark.create( + handWorldLandmarkProto.getX(), + handWorldLandmarkProto.getY(), + handWorldLandmarkProto.getZ(), + false)); + } + } + for (ClassificationList handednessProto : handednessesProto) { + List handedness = new ArrayList<>(); + multiHandHandednesses.add(handedness); + for (Classification classification : handednessProto.getClassificationList()) { + handedness.add( + Category.create( + classification.getScore(), + classification.getIndex(), + classification.getLabel(), + classification.getDisplayName())); + } + } + for (ClassificationList gestureProto : gesturesProto) { + List gestures = new ArrayList<>(); + multiHandGestures.add(gestures); + for (Classification classification : gestureProto.getClassificationList()) { + gestures.add( + Category.create( + classification.getScore(), + classification.getIndex(), + classification.getLabel(), + classification.getDisplayName())); + } + } + return new AutoValue_GestureRecognitionResult( + timestampMs, + Collections.unmodifiableList(multiHandLandmarks), + Collections.unmodifiableList(multiHandWorldLandmarks), + Collections.unmodifiableList(multiHandHandednesses), + Collections.unmodifiableList(multiHandGestures)); + } + + @Override + public abstract long timestampMs(); + + /** Hand landmarks of detected hands. */ + public abstract List> landmarks(); + + /** Hand landmarks in world coordniates of detected hands. */ + public abstract List> + worldLandmarks(); + + /** Handedness of detected hands. 
*/ + public abstract List> handednesses(); + + /** Recognized hand gestures of detected hands */ + public abstract List> gestures(); +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java index 9a0c7e8f6..108c021ea 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetectionResult.java @@ -38,7 +38,8 @@ public abstract class ObjectDetectionResult implements TaskResult { * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf * messages. * - * @param detectionList a list of {@link Detection} protobuf messages. + * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages. + * @param timestampMs a timestamp for this result. */ static ObjectDetectionResult create(List detectionList, long timestampMs) { List detections = new ArrayList<>(); diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java index 463ab4c43..b64992d3e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/ObjectDetector.java @@ -155,7 +155,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}. * * @param context an Android {@link Context}. - * @param detectorOptions a {@link ObjectDetectorOptions} instance. + * @param detectorOptions an {@link ObjectDetectorOptions} instance. * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. */ public static ObjectDetector createFromOptions( @@ -192,7 +192,6 @@ public final class ObjectDetector extends BaseVisionTaskApi { .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM) .build(), handler); - detectorOptions.errorListener().ifPresent(runner::setErrorListener); return new ObjectDetector(runner, detectorOptions.runningMode()); } @@ -204,7 +203,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @param runningMode a mediapipe vision task {@link RunningMode}. */ private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { - super(taskRunner, runningMode); + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); } /** @@ -221,7 +220,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public ObjectDetectionResult detect(Image inputImage) { - return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage); + return (ObjectDetectionResult) processImageData(inputImage); } /** @@ -242,8 +241,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. 
*/ public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) { - return (ObjectDetectionResult) - processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); + return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs); } /** @@ -265,7 +263,7 @@ public final class ObjectDetector extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public void detectAsync(Image inputImage, long inputTimestampMs) { - sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); + sendLiveStreamData(inputImage, inputTimestampMs); } /** Options for setting up an {@link ObjectDetector}. */ @@ -275,12 +273,12 @@ public final class ObjectDetector extends BaseVisionTaskApi { /** Builder for {@link ObjectDetectorOptions}. */ @AutoValue.Builder public abstract static class Builder { - /** Sets the base options for the object detector task. */ + /** Sets the {@link BaseOptions} for the object detector task. */ public abstract Builder setBaseOptions(BaseOptions value); /** - * Sets the running mode for the object detector task. Default to the image mode. Object - * detector has three modes: + * Sets the {@link RunningMode} for the object detector task. Default to the image mode. + * Object detector has three modes: * *

    *
  • IMAGE: The mode for detecting objects on single image inputs. @@ -293,8 +291,8 @@ public final class ObjectDetector extends BaseVisionTaskApi { public abstract Builder setRunningMode(RunningMode value); /** - * Sets the locale to use for display names specified through the TFLite Model Metadata, if - * any. Defaults to English. + * Sets the optional locale to use for display names specified through the TFLite Model + * Metadata, if any. */ public abstract Builder setDisplayNamesLocale(String value); @@ -331,12 +329,12 @@ public final class ObjectDetector extends BaseVisionTaskApi { public abstract Builder setCategoryDenylist(List value); /** - * Sets the result listener to receive the detection results asynchronously when the object - * detector is in the live stream mode. + * Sets the {@link ResultListener} to receive the detection results asynchronously when the + * object detector is in the live stream mode. */ public abstract Builder setResultListener(ResultListener value); - /** Sets an optional error listener. */ + /** Sets an optional {@link ErrorListener}}. */ public abstract Builder setErrorListener(ErrorListener value); abstract ObjectDetectorOptions autoBuild(); diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml index 3e5e81920..19bd638e9 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/objectdetector/AndroidManifest.xml @@ -11,7 +11,7 @@ android:targetSdkVersion="30" /> diff --git a/mediapipe/tasks/python/test/test_util.py b/mediapipe/tasks/python/test/test_util.py deleted file mode 100644 index 531a18f7a..000000000 --- a/mediapipe/tasks/python/test/test_util.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
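As a hedged usage sketch for the ObjectDetector options updated above: the builder() factory, BaseOptions.builder()/setModelAssetPath, and the model path itself are assumptions not shown in this patch, while setBaseOptions, setRunningMode, and setCategoryDenylist appear in the diff.

import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.util.Arrays;

class ObjectDetectorOptionsExample {
  static ObjectDetectorOptions buildOptions() {
    return ObjectDetectorOptions.builder()
        // Model path is a placeholder.
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("object_detector.tflite").build())
        // IMAGE is the default running mode; VIDEO and LIVE_STREAM are the other two.
        .setRunningMode(RunningMode.IMAGE)
        // Drop detections for these category names.
        .setCategoryDenylist(Arrays.asList("person"))
        .build();
  }
}

A detector would then be created with ObjectDetector.createFromOptions(context, options), as shown in the file above.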
-"""Test util for MediaPipe Tasks.""" - -import os - -from absl import flags - -from mediapipe.python._framework_bindings import image as image_module -from mediapipe.python._framework_bindings import image_frame as image_frame_module - -FLAGS = flags.FLAGS -_Image = image_module.Image -_ImageFormat = image_frame_module.ImageFormat -_RGB_CHANNELS = 3 - - -def test_srcdir(): - """Returns the path where to look for test data files.""" - if "test_srcdir" in flags.FLAGS: - return flags.FLAGS["test_srcdir"].value - elif "TEST_SRCDIR" in os.environ: - return os.environ["TEST_SRCDIR"] - else: - raise RuntimeError("Missing TEST_SRCDIR environment.") - - -def get_test_data_path(file_or_dirname: str) -> str: - """Returns full test data path.""" - for (directory, subdirs, files) in os.walk(test_srcdir()): - for f in subdirs + files: - if f.endswith(file_or_dirname): - return os.path.join(directory, f) - raise ValueError("No %s in test directory" % file_or_dirname) diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index 95b6bf867..d5cebd94b 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -119,7 +119,7 @@ class ObjectDetectorTest(parameterized.TestCase): with self.assertRaisesRegex( ValueError, r"ExternalFile must specify at least one of 'file_content', " - r"'file_name' or 'file_descriptor_meta'."): + r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."): base_options = _BaseOptions(model_asset_path='') options = _ObjectDetectorOptions(base_options=base_options) _ObjectDetector.create_from_options(options) diff --git a/mediapipe/tasks/testdata/text/BUILD b/mediapipe/tasks/testdata/text/BUILD index 80f1163db..1b569127b 100644 --- a/mediapipe/tasks/testdata/text/BUILD +++ b/mediapipe/tasks/testdata/text/BUILD @@ -27,6 +27,7 @@ mediapipe_files(srcs = [ "albert_with_metadata.tflite", "bert_text_classifier.tflite", "mobilebert_with_metadata.tflite", + "test_model_text_classifier_bool_output.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite", ]) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 5eda42601..290b29016 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -47,6 +47,7 @@ mediapipe_files(srcs = [ "mozart_square.jpg", "multi_objects.jpg", "palm_detection_full.tflite", + "pointing_up.jpg", "right_hands.jpg", "segmentation_golden_rotation0.png", "segmentation_input_rotation0.jpg", @@ -54,6 +55,7 @@ mediapipe_files(srcs = [ "selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_144_256_3.tflite", "selfie_segm_144_256_3_expected_mask.jpg", + "thumb_up.jpg", ]) exports_files( @@ -79,11 +81,13 @@ filegroup( "left_hands.jpg", "mozart_square.jpg", "multi_objects.jpg", + "pointing_up.jpg", "right_hands.jpg", "segmentation_golden_rotation0.png", "segmentation_input_rotation0.jpg", "selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_144_256_3_expected_mask.jpg", + "thumb_up.jpg", ], visibility = [ "//mediapipe/python:__subpackages__", diff --git a/mediapipe/tasks/testdata/vision/pointing_up_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/pointing_up_landmarks.pbtxt index fdd8b9c8d..05917af3e 100644 --- a/mediapipe/tasks/testdata/vision/pointing_up_landmarks.pbtxt +++ b/mediapipe/tasks/testdata/vision/pointing_up_landmarks.pbtxt @@ -8,216 +8,216 @@ classifications { landmarks { landmark { - x: 0.4749803 - y: 0.76872 
- z: 9.286178e-08 + x: 0.47923622 + y: 0.7426044 + z: 2.3221878e-07 } landmark { - x: 0.5466898 - y: 0.6706463 - z: -0.03454024 + x: 0.5403745 + y: 0.66178805 + z: -0.044572093 } landmark { - x: 0.5890165 - y: 0.5604909 - z: -0.055142127 + x: 0.5774534 + y: 0.5608346 + z: -0.07581605 } landmark { - x: 0.52780133 - y: 0.49855334 - z: -0.07846409 + x: 0.52648556 + y: 0.50247055 + z: -0.105467044 } landmark { - x: 0.44487286 - y: 0.49801928 - z: -0.10188004 + x: 0.44289914 + y: 0.49489295 + z: -0.13422011 } landmark { - x: 0.47572923 - y: 0.44477755 - z: -0.028345175 + x: 0.4728853 + y: 0.43925008 + z: -0.058122505 } landmark { - x: 0.48013464 - y: 0.32467923 - z: -0.06513901 + x: 0.4803168 + y: 0.32889345 + z: -0.101187326 } landmark { - x: 0.48351905 - y: 0.25804192 - z: -0.086756624 + x: 0.48436823 + y: 0.25876504 + z: -0.12840955 } landmark { - x: 0.47760454 - y: 0.19289327 - z: -0.10468461 + x: 0.47388697 + y: 0.19592366 + z: -0.15085006 } landmark { - x: 0.3993108 - y: 0.47566867 - z: -0.040357687 + x: 0.39129356 + y: 0.47211456 + z: -0.06835801 } landmark { - x: 0.42361537 - y: 0.42491958 - z: -0.103545874 + x: 0.41798547 + y: 0.42218646 + z: -0.12954563 } landmark { - x: 0.46059948 - y: 0.51723665 - z: -0.1214961 + x: 0.45758423 + y: 0.5232461 + z: -0.14131334 } landmark { - x: 0.4580545 - y: 0.55640894 - z: -0.12272568 + x: 0.45100626 + y: 0.5554065 + z: -0.13883406 } landmark { - x: 0.34109607 - y: 0.5184511 - z: -0.056422118 + x: 0.33133638 + y: 0.51777464 + z: -0.08227023 } landmark { - x: 0.36177525 - y: 0.48427337 - z: -0.12584248 + x: 0.35698116 + y: 0.48688585 + z: -0.14713185 } landmark { - x: 0.40706652 - y: 0.5700621 - z: -0.11658718 + x: 0.40754414 + y: 0.57370347 + z: -0.12981415 } landmark { - x: 0.40535083 - y: 0.6000496 - z: -0.09520916 + x: 0.40011865 + y: 0.5930706 + z: -0.10554546 } landmark { - x: 0.2872031 - y: 0.57303333 - z: -0.074813806 + x: 0.2783401 + y: 0.5735568 + z: -0.09971398 } landmark { - x: 0.30961618 - y: 0.533245 - z: -0.114366606 + x: 0.30884498 + y: 0.5394487 + z: -0.14033116 } landmark { - x: 0.35510173 - y: 0.5838698 - z: -0.096521005 + x: 0.35470563 + y: 0.5917965 + z: -0.11820527 } landmark { - x: 0.36053744 - y: 0.608682 - z: -0.07574715 + x: 0.34865493 + y: 0.61057556 + z: -0.09509217 } } world_landmarks { landmark { - x: 0.018890835 - y: 0.09005852 - z: 0.031907097 + x: 0.016918864 + y: 0.08634466 + z: 0.035783045 } landmark { - x: 0.04198891 - y: 0.061256267 - z: 0.017695501 + x: 0.04193685 + y: 0.056667875 + z: 0.019453367 } landmark { - x: 0.05044507 - y: 0.033841074 - z: 0.0015051212 + x: 0.050382353 + y: 0.031786427 + z: 0.0023380776 } landmark { - x: 0.039822325 - y: 0.0073827556 - z: -0.02168335 + x: 0.043284662 + y: 0.008976387 + z: -0.02496663 } landmark { - x: 0.012921701 - y: 0.0025111444 - z: -0.033813436 + x: 0.016010094 + y: 0.004991216 + z: -0.036876947 } landmark { - x: 0.023851154 - y: -0.011495698 - z: 0.0066048754 + x: 0.02450771 + y: -0.013496464 + z: 0.0041254223 } landmark { - x: 0.023206754 - y: -0.042496294 - z: -0.0026847485 + x: 0.024783865 + y: -0.041331705 + z: -0.0028748964 } landmark { - x: 0.02298078 - y: -0.062678955 - z: -0.013068148 + x: 0.025917178 + y: -0.06191107 + z: -0.010242647 } landmark { - x: 0.021972645 - y: -0.08151748 - z: -0.03677687 + x: 0.023101516 + y: -0.07967696 + z: -0.03152665 } landmark { - x: -0.00016964211 - y: -0.005549716 - z: 0.0058569373 + x: 0.0006629339 + y: -0.0060150283 + z: 0.004906766 } landmark { - x: 0.0075052455 - y: -0.020031122 - z: -0.027775772 + x: 0.0077093104 + y: 
+    y: -0.017035034
+    z: -0.029702934
   }
   landmark {
-    x: 0.017835317
-    y: 0.004899453
-    z: -0.037390795
+    x: 0.017517095
+    y: 0.008997183
+    z: -0.03692814
   }
   landmark {
-    x: 0.016913192
-    y: 0.018281722
-    z: -0.019302163
+    x: 0.0145079205
+    y: 0.017461296
+    z: -0.011290487
   }
   landmark {
-    x: -0.018799124
-    y: 0.0053577404
-    z: -0.0040608873
+    x: -0.018095909
+    y: 0.006112392
+    z: -0.0027157406
   }
   landmark {
-    x: -0.00747582
-    y: 0.0019600953
-    z: -0.034023333
+    x: -0.010212201
+    y: 0.0052777785
+    z: -0.034659054
   }
   landmark {
-    x: 0.0035368819
-    y: 0.025736088
-    z: -0.03452471
+    x: 0.0043836404
+    y: 0.028383566
+    z: -0.03296758
   }
   landmark {
-    x: 0.0080153765
-    y: 0.039885145
-    z: -0.013341276
+    x: 0.003886811
+    y: 0.036054
+    z: -0.0074628904
   }
   landmark {
-    x: -0.029628165
-    y: 0.028607829
-    z: -0.011377414
+    x: -0.03178849
+    y: 0.029854178
+    z: -0.008874044
   }
   landmark {
-    x: -0.023356002
-    y: 0.017514031
-    z: -0.029408533
+    x: -0.02403016
+    y: 0.021497255
+    z: -0.027618393
   }
   landmark {
-    x: -0.008503268
-    y: 0.027560957
-    z: -0.035641473
+    x: -0.008522437
+    y: 0.031886857
+    z: -0.032367583
   }
   landmark {
-    x: -0.0070180474
-    y: 0.039056484
-    z: -0.023629948
+    x: -0.012865841
+    y: 0.038687646
+    z: -0.017172804
   }
 }
diff --git a/mediapipe/tasks/testdata/vision/thumb_up_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/thumb_up_landmarks.pbtxt
index 00b47a3da..e73a69d31 100644
--- a/mediapipe/tasks/testdata/vision/thumb_up_landmarks.pbtxt
+++ b/mediapipe/tasks/testdata/vision/thumb_up_landmarks.pbtxt
@@ -8,216 +8,216 @@ classifications {
 landmarks {
   landmark {
-    x: 0.6065784
-    y: 0.7356081
-    z: -5.2289305e-08
+    x: 0.6387502
+    y: 0.67134184
+    z: -3.4044612e-07
   }
   landmark {
-    x: 0.6349347
-    y: 0.5735343
-    z: -0.047243003
+    x: 0.634891
+    y: 0.53670025
+    z: -0.06968865
   }
   landmark {
-    x: 0.5788341
-    y: 0.42688707
-    z: -0.036071796
+    x: 0.5746676
+    y: 0.41283816
+    z: -0.09383486
   }
   landmark {
-    x: 0.51322824
-    y: 0.3153786
-    z: -0.021018881
+    x: 0.49967948
+    y: 0.32550922
+    z: -0.10799447
   }
   landmark {
-    x: 0.49179295
-    y: 0.25291175
-    z: 0.0061425082
+    x: 0.47362617
+    y: 0.25102285
+    z: -0.10590933
   }
   landmark {
-    x: 0.49944243
-    y: 0.45409226
-    z: 0.06513325
+    x: 0.40749234
+    y: 0.47130388
+    z: -0.04694611
   }
   landmark {
-    x: 0.3822241
-    y: 0.45645967
-    z: 0.045028925
+    x: 0.3372087
+    y: 0.46742308
+    z: -0.0997342
   }
   landmark {
-    x: 0.4427338
-    y: 0.49150866
-    z: 0.024395633
+    x: 0.4418445
+    y: 0.50960016
+    z: -0.111206524
   }
   landmark {
-    x: 0.5015556
-    y: 0.4798539
-    z: 0.014423937
+    x: 0.48056933
+    y: 0.5187666
+    z: -0.11022365
   }
   landmark {
-    x: 0.46654877
-    y: 0.5420721
-    z: 0.08380699
+    x: 0.39218128
+    y: 0.5495232
+    z: -0.028925514
   }
   landmark {
-    x: 0.3540949
-    y: 0.545657
-    z: 0.056201216
+    x: 0.34047198
+    y: 0.55610204
+    z: -0.08213869
   }
   landmark {
-    x: 0.43828446
-    y: 0.5723222
-    z: 0.03073385
+    x: 0.46152583
+    y: 0.58310646
+    z: -0.08393028
   }
   landmark {
-    x: 0.4894746
-    y: 0.54662794
-    z: 0.016284892
+    x: 0.47058716
+    y: 0.56413835
+    z: -0.078857616
   }
   landmark {
-    x: 0.44287524
-    y: 0.6153337
-    z: 0.0878331
+    x: 0.39237642
+    y: 0.61864823
+    z: -0.022026168
   }
   landmark {
-    x: 0.3531985
-    y: 0.6305228
-    z: 0.048528627
+    x: 0.34304678
+    y: 0.62800515
+    z: -0.08132204
   }
   landmark {
-    x: 0.42727134
-    y: 0.64344436
-    z: 0.027383275
+    x: 0.45004016
+    y: 0.64300805
+    z: -0.06211204
   }
   landmark {
-    x: 0.46999624
-    y: 0.61115295
-    z: 0.021795912
+    x: 0.4640005
+    y: 0.6221539
+    z: -0.038953774
   }
   landmark {
-    x: 0.43323213
-    y: 0.6734935
-    z: 0.087731235
+    x: 0.39231628
+    y: 0.68187976
+    z: -0.020164328
   }
   landmark {
-    x: 0.3772134
-    y: 0.69590896
-    z: 0.07259013
+    x: 0.35785866
+    y: 0.6985842
+    z: -0.052247807
   }
   landmark {
-    x: 0.42301077
-    y: 0.70083475
-    z: 0.06279105
+    x: 0.42698768
+    y: 0.69892275
+    z: -0.037642766
   }
   landmark {
-    x: 0.45672464
-    y: 0.6844607
-    z: 0.059202813
+    x: 0.44422707
+    y: 0.6876204
+    z: -0.02034688
   }
 }
 world_landmarks {
   landmark {
-    x: 0.047059614
-    y: 0.04719348
-    z: 0.03951376
+    x: 0.06753889
+    y: 0.031051591
+    z: 0.05541924
   }
   landmark {
-    x: 0.050449535
-    y: 0.012183173
-    z: 0.016567508
+    x: 0.06327636
+    y: -0.003913434
+    z: 0.02125023
   }
   landmark {
-    x: 0.04375921
-    y: -0.020305036
-    z: 0.012189768
+    x: 0.05469646
+    y: -0.038668767
+    z: 0.01118496
   }
   landmark {
-    x: 0.022525383
-    y: -0.04830697
-    z: 0.008714083
+    x: 0.03557241
+    y: -0.06865983
+    z: 0.0029562893
   }
   landmark {
-    x: 0.011789754
-    y: -0.06952699
-    z: 0.0029319536
+    x: 0.019069858
+    y: -0.08740239
+    z: 0.007222481
   }
   landmark {
-    x: 0.009532374
-    y: -0.019510617
-    z: 0.0015609035
+    x: 0.0044852756
+    y: -0.02772763
+    z: -0.004234833
   }
   landmark {
-    x: -0.007894232
-    y: -0.022080563
-    z: -0.014592148
+    x: -0.0031203926
+    y: -0.024173645
+    z: -0.033932913
   }
   landmark {
-    x: -0.002826123
-    y: -0.019949362
-    z: -0.009392118
+    x: 0.0080217365
+    y: -0.018939625
+    z: -0.032623816
   }
   landmark {
-    x: 0.009066351
-    y: -0.016403511
-    z: 0.005516675
+    x: 0.025537387
+    y: -0.014517117
+    z: -0.004398854
   }
   landmark {
-    x: -0.0031000748
-    y: -0.003971943
-    z: 0.004851345
+    x: -0.004470923
+    y: -0.0040212176
+    z: 0.0025033879
   }
   landmark {
-    x: -0.016852753
-    y: -0.009905987
-    z: -0.016275175
+    x: -0.010845158
+    y: -0.0031857258
+    z: -0.036282137
   }
   landmark {
-    x: -0.006703893
-    y: -0.0026965735
-    z: -0.015606856
+    x: 0.016729971
+    y: 0.0028876318
+    z: -0.036264844
   }
   landmark {
-    x: 0.007890566
-    y: -0.010418876
-    z: 0.0050479355
+    x: 0.019928008
+    y: -0.0032422952
+    z: 0.004380459
   }
   landmark {
-    x: -0.007842411
-    y: 0.011552694
-    z: -0.0005755241
+    x: -0.005686749
+    y: 0.017101247
+    z: 0.0036791638
   }
   landmark {
-    x: -0.021125216
-    y: 0.009268615
-    z: -0.017993882
+    x: -0.010514952
+    y: 0.017355483
+    z: -0.02882688
   }
   landmark {
-    x: -0.006585305
-    y: 0.013378072
-    z: -0.01709412
+    x: 0.014503509
+    y: 0.019414417
+    z: -0.026207235
   }
   landmark {
-    x: 0.008140431
-    y: 0.008364402
-    z: -0.0051898304
+    x: 0.0211232
+    y: 0.014327417
+    z: 0.0011467658
   }
   landmark {
-    x: -0.01082343
-    y: 0.03213215
-    z: -0.00069864903
+    x: 0.0011399705
+    y: 0.043651186
+    z: 0.0068390737
   }
   landmark {
-    x: -0.0199164
-    y: 0.028296603
-    z: -0.01447433
+    x: -0.010388309
+    y: 0.03904784
+    z: -0.015677728
   }
   landmark {
-    x: -0.00960456
-    y: 0.026734762
-    z: -0.019243335
+    x: 0.006957108
+    y: 0.03613425
+    z: -0.028704688
   }
   landmark {
-    x: 0.0040425956
-    y: 0.025051914
-    z: -0.014775545
+    x: 0.012793289
+    y: 0.03930679
+    z: -0.012465539
   }
 }
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index 24ceba639..acf05c2fc 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -432,8 +432,8 @@ def external_files():
 
     http_file(
         name = "com_google_mediapipe_pointing_up_landmarks_pbtxt",
-        sha256 = "1255b6ba17b4ef7a9b3ce92c0a139e74fbcec272dc251b049b2f06732f9fed83",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1662650664573638"],
+        sha256 = "a3cd7f088a9e997dbb8f00d91dbf3faaacbdb262c8f2fde3c07a9d0656488065",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1665174976408451"],
     )
 
     http_file(
@@ -562,6 +562,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"], ) + http_file( + name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite", + sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310", + urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"], + ) + http_file( name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite", sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f", @@ -588,8 +594,8 @@ def external_files(): http_file( name = "com_google_mediapipe_thumb_up_landmarks_pbtxt", - sha256 = "bf1913df6ac7cc14b492c10411c827832839985c057b112789e04ce7c1fdd0fa", - urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1662650669387278"], + sha256 = "b129ae0536be4e25d6cdee74aabe9dedf1bcfe87430a40b68be4079db3a4d926", + urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1665174979747784"], ) http_file(