Merge branch 'google:master' into image-classification-python-impl

This commit is contained in:
Kinar R 2022-10-10 20:46:18 +05:30 committed by GitHub
commit a29035b91e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
77 changed files with 3079 additions and 580 deletions

View File

@ -320,6 +320,8 @@ cc_library(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable", "@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
], ],
) )

View File

@ -21,8 +21,10 @@
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe { namespace mediapipe {
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes()); std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
} }
template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>();
tflite::DynamicBuffer dynamic_buffer;
dynamic_buffer.AddString(input_tensor_buffer,
input_tensor.shape().num_elements());
dynamic_buffer.WriteToTensorAsVector(
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
}
template <typename T> template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
int output_tensor_index, int output_tensor_index,
@ -87,12 +102,12 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
break; break;
} }
case TfLiteType::kTfLiteUInt8: { case TfLiteType::kTfLiteUInt8: {
CopyTensorBufferToInterpreter<uint8>(input_tensors[i], CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
interpreter_.get(), i); interpreter_.get(), i);
break; break;
} }
case TfLiteType::kTfLiteInt8: { case TfLiteType::kTfLiteInt8: {
CopyTensorBufferToInterpreter<int8>(input_tensors[i], CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
interpreter_.get(), i); interpreter_.get(), i);
break; break;
} }
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
interpreter_.get(), i); interpreter_.get(), i);
break; break;
} }
case TfLiteType::kTfLiteString: {
CopyTensorBufferToInterpreter<char>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteBool:
// No current use-case for copying MediaPipe Tensors with bool type to
// TfLiteTensors.
default: default:
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Unsupported input tensor type:", input_tensor_type)); absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i, CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
&output_tensors.back()); &output_tensors.back());
break; break;
case TfLiteType::kTfLiteBool:
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
Tensor::QuantizationParameters{1.0f, 0});
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteString:
// No current use-case for copying TfLiteTensors with string type to
// MediaPipe Tensors.
default: default:
return absl::InvalidArgumentError( return absl::InvalidArgumentError(
absl::StrCat("Unsupported output tensor type:", absl::StrCat("Unsupported output tensor type:",

View File

@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
case Tensor::ElementType::kInt8: case Tensor::ElementType::kInt8:
Dequantize<int8>(input_tensor, &output_tensors->back()); Dequantize<int8>(input_tensor, &output_tensors->back());
break; break;
case Tensor::ElementType::kBool:
Dequantize<bool>(input_tensor, &output_tensors->back());
break;
default: default:
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"Unsupported input tensor type: ", input_tensor.element_type())); "Unsupported input tensor type: ", input_tensor.element_type()));

View File

@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
ValidateResult(GetOutput(), {-1.007874, 0, 1}); ValidateResult(GetOutput(), {-1.007874, 0, 1});
} }
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
std::vector<bool> tensor = {true, false, true};
PushTensor(Tensor::ElementType::kBool, tensor,
Tensor::QuantizationParameters{1.0f, 0});
MP_ASSERT_OK(runner_.Run());
ValidateResult(GetOutput(), {1, 0, 1});
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -1685,10 +1685,3 @@ cc_test(
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -112,6 +112,14 @@ class MultiPort : public Single {
std::vector<std::unique_ptr<Base>>& vec_; std::vector<std::unique_ptr<Base>>& vec_;
}; };
namespace internal_builder {
template <typename T, typename U>
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>>;
} // namespace internal_builder
// These classes wrap references to the underlying source/destination // These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API. // endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic> template <bool IsSide, typename T = internal::Generic>
@ -122,13 +130,14 @@ class DestinationImpl {
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec) explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {} : DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
explicit DestinationImpl(DestinationBase* base) : base_(*base) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
DestinationBase& base_;
};
template <bool IsSide, typename T> template <typename U,
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> { std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
public: DestinationImpl<IsSide, U> Cast() {
using MultiPort<DestinationImpl<IsSide, T>>::MultiPort; return DestinationImpl<IsSide, U>(&base_);
}
DestinationBase& base_;
}; };
template <bool IsSide, typename T = internal::Generic> template <bool IsSide, typename T = internal::Generic>
@ -171,12 +180,8 @@ class SourceImpl {
return AddTarget(dest); return AddTarget(dest);
} }
template <typename U> template <typename U,
struct AllowCast std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>> {};
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
SourceImpl<IsSide, U> Cast() { SourceImpl<IsSide, U> Cast() {
return SourceImpl<IsSide, U>(base_); return SourceImpl<IsSide, U>(base_);
} }
@ -186,12 +191,6 @@ class SourceImpl {
SourceBase* base_; SourceBase* base_;
}; };
template <bool IsSide, typename T>
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
public:
using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
};
// A source and a destination correspond to an output/input stream on a node, // A source and a destination correspond to an output/input stream on a node,
// and a side source and side destination correspond to an output/input side // and a side source and side destination correspond to an output/input side
// packet. // packet.
@ -201,20 +200,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
template <typename T = internal::Generic> template <typename T = internal::Generic>
using Source = SourceImpl<false, T>; using Source = SourceImpl<false, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSource = MultiSourceImpl<false, T>; using MultiSource = MultiPort<Source<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using SideSource = SourceImpl<true, T>; using SideSource = SourceImpl<true, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSideSource = MultiSourceImpl<true, T>; using MultiSideSource = MultiPort<SideSource<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using Destination = DestinationImpl<false, T>; using Destination = DestinationImpl<false, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using SideDestination = DestinationImpl<true, T>; using SideDestination = DestinationImpl<true, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiDestination = MultiDestinationImpl<false, T>; using MultiDestination = MultiPort<Destination<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSideDestination = MultiDestinationImpl<true, T>; using MultiSideDestination = MultiPort<SideDestination<T>>;
class NodeBase { class NodeBase {
public: public:

View File

@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>(); node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output"); any_type_output.SetName("any_type_output");
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
CalculatorGraphConfig expected = CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node { node {
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
output_stream: "ANY_OUTPUT:any_type_output" output_stream: "ANY_OUTPUT:any_type_output"
} }
input_stream: "GRAPH_ANY_INPUT:__stream_0" input_stream: "GRAPH_ANY_INPUT:__stream_0"
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
)pb"); )pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }

View File

@ -334,13 +334,6 @@ mediapipe_register_type(
deps = [":landmark_cc_proto"], deps = [":landmark_cc_proto"],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)
cc_library( cc_library(
name = "image", name = "image",
srcs = ["image.cc"], srcs = ["image.cc"],

View File

@ -33,10 +33,3 @@ mediapipe_proto_library(
srcs = ["rasterization.proto"], srcs = ["rasterization.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -97,8 +97,8 @@ class Tensor {
kUInt8, kUInt8,
kInt8, kInt8,
kInt32, kInt32,
// TODO: Update the inference runner to handle kTfLiteString. kChar,
kChar kBool
}; };
struct Shape { struct Shape {
Shape() = default; Shape() = default;
@ -330,6 +330,8 @@ class Tensor {
return sizeof(int32_t); return sizeof(int32_t);
case ElementType::kChar: case ElementType::kChar:
return sizeof(char); return sizeof(char);
case ElementType::kBool:
return sizeof(bool);
} }
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
} }
TEST(Cpu, TestMemoryAllocation) { TEST(Cpu, TestMemoryAllocation) {

View File

@ -150,7 +150,7 @@ cc_library(
name = "executor_util", name = "executor_util",
srcs = ["executor_util.cc"], srcs = ["executor_util.cc"],
hdrs = ["executor_util.h"], hdrs = ["executor_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto",

View File

@ -1050,7 +1050,7 @@ objc_library(
alwayslink = 1, alwayslink = 1,
) )
MIN_IOS_VERSION = "9.0" # For thread_local. MIN_IOS_VERSION = "11.0"
test_suite( test_suite(
name = "ios", name = "ios",

View File

@ -184,7 +184,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
EXPECT_THAT( EXPECT_THAT(
audio_classifier_or.status().message(), audio_classifier_or.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', " HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name' or 'file_descriptor_meta'.")); "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));

View File

@ -65,6 +65,8 @@ enum class MediaPipeTasksStatus {
kFileReadError, kFileReadError,
// I/O error when mmap-ing file. // I/O error when mmap-ing file.
kFileMmapError, kFileMmapError,
// ZIP I/O error when unpacking the zip file.
kFileZipError,
// TensorFlow Lite metadata error codes. // TensorFlow Lite metadata error codes.

View File

@ -0,0 +1,31 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "landmarks_detection",
hdrs = ["landmarks_detection.h"],
)
cc_library(
name = "gesture_recognition_result",
hdrs = ["gesture_recognition_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)

View File

@ -0,0 +1,46 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace containers {
// The gesture recognition result from GestureRecognizer, where each vector
// element represents a single hand detected in the image.
struct GestureRecognitionResult {
// Recognized hand gestures with sorted order such that the winning label is
// the first item in the list.
std::vector<mediapipe::ClassificationList> gestures;
// Classification of handedness.
std::vector<mediapipe::ClassificationList> handedness;
// Detected hand landmarks in normalized image coordinates.
std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
// Detected hand landmarks in world coordinates.
std::vector<mediapipe::LandmarkList> hand_world_landmarks;
};
} // namespace containers
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_

View File

@ -0,0 +1,43 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
#include <vector>
// Sturcts holding landmarks related data structure for hand landmarker, pose
// detector, face mesher, etc.
namespace mediapipe::tasks::components::containers {
// x and y are in [0,1] range with origin in top left in input image space.
// If model provides z, z is in the same scale as x. origin is in the center
// of the face.
struct Landmark {
float x;
float y;
float z;
};
// [0, 1] range in input image space
struct Bound {
float left;
float top;
float right;
float bottom;
};
} // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.containers.proto; package mediapipe.tasks.components.containers.proto;
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_outer_classname = "CategoryProto";
// A single classification result. // A single classification result.
message Category { message Category {
// The index of the category in the corresponding label map, usually packed in // The index of the category in the corresponding label map, usually packed in

View File

@ -19,6 +19,9 @@ package mediapipe.tasks.components.containers.proto;
import "mediapipe/tasks/cc/components/containers/proto/category.proto"; import "mediapipe/tasks/cc/components/containers/proto/category.proto";
option java_package = "com.google.mediapipe.tasks.components.container.proto";
option java_outer_classname = "ClassificationsProto";
// List of predicted categories with an optional timestamp. // List of predicted categories with an optional timestamp.
message ClassificationEntry { message ClassificationEntry {
// The array of predicted categories, usually sorted by descending scores, // The array of predicted categories, usually sorted by descending scores,

View File

@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
const auto* tensor = const auto* tensor =
primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i)); primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
if (tensor->type() != tflite::TensorType_FLOAT32 && if (tensor->type() != tflite::TensorType_FLOAT32 &&
tensor->type() != tflite::TensorType_UINT8) { tensor->type() != tflite::TensorType_UINT8 &&
tensor->type() != tflite::TensorType_BOOL) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
absl::StrFormat("Expected output tensor at index %d to have type " absl::StrFormat("Expected output tensor at index %d to have type "
"UINT8 or FLOAT32, found %s instead.", "UINT8 or FLOAT32 or BOOL, found %s instead.",
i, tflite::EnumNameTensorType(tensor->type())), i, tflite::EnumNameTensorType(tensor->type())),
MediaPipeTasksStatus::kInvalidOutputTensorTypeError); MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
} }
if (tensor->type() == tflite::TensorType_UINT8) { if (tensor->type() == tflite::TensorType_UINT8 ||
tensor->type() == tflite::TensorType_BOOL) {
num_quantized_tensors++; num_quantized_tensors++;
} }
} }
@ -282,6 +284,20 @@ absl::Status ConfigureScoreCalibrationIfAny(
return absl::OkStatus(); return absl::OkStatus();
} }
void ConfigureClassificationAggregationCalculator(
const ModelMetadataExtractor& metadata_extractor,
ClassificationAggregationCalculatorOptions* options) {
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr) {
return;
}
for (const auto& metadata : *output_tensors_metadata) {
options->add_head_names(metadata->name()->str());
}
}
} // namespace
// Fills in the TensorsToClassificationCalculatorOptions based on the // Fills in the TensorsToClassificationCalculatorOptions based on the
// classifier options and the (optional) output tensor metadata. // classifier options and the (optional) output tensor metadata.
absl::Status ConfigureTensorsToClassificationCalculator( absl::Status ConfigureTensorsToClassificationCalculator(
@ -333,20 +349,6 @@ absl::Status ConfigureTensorsToClassificationCalculator(
return absl::OkStatus(); return absl::OkStatus();
} }
void ConfigureClassificationAggregationCalculator(
const ModelMetadataExtractor& metadata_extractor,
ClassificationAggregationCalculatorOptions* options) {
auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
if (output_tensors_metadata == nullptr) {
return;
}
for (const auto& metadata : *output_tensors_metadata) {
options->add_head_names(metadata->name()->str());
}
}
} // namespace
absl::Status ConfigureClassificationPostprocessingGraph( absl::Status ConfigureClassificationPostprocessingGraph(
const ModelResources& model_resources, const ModelResources& model_resources,
const proto::ClassifierOptions& classifier_options, const proto::ClassifierOptions& classifier_options,

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph(
const proto::ClassifierOptions& classifier_options, const proto::ClassifierOptions& classifier_options,
proto::ClassificationPostprocessingGraphOptions* options); proto::ClassificationPostprocessingGraphOptions* options);
// Utility function to fill in the TensorsToClassificationCalculatorOptions
// based on the classifier options and the (optional) output tensor metadata.
// This is meant to be used by other graphs that may also rely on this
// calculator.
absl::Status ConfigureTensorsToClassificationCalculator(
const proto::ClassifierOptions& options,
const metadata::ModelMetadataExtractor& metadata_extractor,
int tensor_index,
mediapipe::TensorsToClassificationCalculatorOptions* calculator_options);
} // namespace processors } // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks

View File

@ -17,6 +17,9 @@ syntax = "proto2";
package mediapipe.tasks.components.processors.proto; package mediapipe.tasks.components.processors.proto;
option java_package = "com.google.mediapipe.tasks.components.processors.proto";
option java_outer_classname = "ClassifierOptionsProto";
// Shared options used by all classification tasks. // Shared options used by all classification tasks.
message ClassifierOptions { message ClassifierOptions {
// The locale to use for display names specified through the TFLite Model // The locale to use for display names specified through the TFLite Model

View File

@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
BERT_PREPROCESSOR = 1; BERT_PREPROCESSOR = 1;
// Used for the RegexPreprocessorCalculator. // Used for the RegexPreprocessorCalculator.
REGEX_PREPROCESSOR = 2; REGEX_PREPROCESSOR = 2;
// Used for the TextToTensorCalculator.
STRING_PREPROCESSOR = 3;
} }
optional PreprocessorType preprocessor_type = 1; optional PreprocessorType preprocessor_type = 1;

View File

@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
return "BertPreprocessorCalculator"; return "BertPreprocessorCalculator";
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
return "RegexPreprocessorCalculator"; return "RegexPreprocessorCalculator";
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
return "TextToTensorCalculator";
} }
} }
@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
MediaPipeTasksStatus::kInvalidInputTensorTypeError); MediaPipeTasksStatus::kInvalidInputTensorTypeError);
} }
if (all_string_tensors) { if (all_string_tensors) {
// TODO: Support a TextToTensor calculator for string tensors. return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"String tensors are not supported yet",
MediaPipeTasksStatus::kInvalidInputTensorTypeError);
} }
// Otherwise, all tensors should have type int32 // Otherwise, all tensors should have type int32
@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
GetPreprocessorType(model_resources)); GetPreprocessorType(model_resources));
options.set_preprocessor_type(preprocessor_type); options.set_preprocessor_type(preprocessor_type);
switch (preprocessor_type) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break;
}
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
int max_seq_len, int max_seq_len,
GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
options.set_max_seq_len(max_seq_len); options.set_max_seq_len(max_seq_len);
}
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
auto& text_preprocessor = graph.AddNode(preprocessor_name); auto& text_preprocessor = graph.AddNode(preprocessor_name);
switch (options.preprocessor_type()) { switch (options.preprocessor_type()) {
case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: { case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
break; break;
} }
case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {

View File

@ -92,13 +92,26 @@ absl::Status ExternalFileHandler::MapExternalFile() {
#else #else
if (!external_file_.file_content().empty()) { if (!external_file_.file_content().empty()) {
return absl::OkStatus(); return absl::OkStatus();
} else if (external_file_.has_file_pointer_meta()) {
if (external_file_.file_pointer_meta().pointer() == 0) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"Need to set the file pointer in external_file.file_pointer_meta.");
}
if (external_file_.file_pointer_meta().length() <= 0) {
return CreateStatusWithPayload(
StatusCode::kInvalidArgument,
"The length of the file in external_file.file_pointer_meta should be "
"positive.");
}
return absl::OkStatus();
} }
if (external_file_.file_name().empty() && if (external_file_.file_name().empty() &&
!external_file_.has_file_descriptor_meta()) { !external_file_.has_file_descriptor_meta()) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
StatusCode::kInvalidArgument, StatusCode::kInvalidArgument,
"ExternalFile must specify at least one of 'file_content', 'file_name' " "ExternalFile must specify at least one of 'file_content', "
"or 'file_descriptor_meta'.", "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.",
MediaPipeTasksStatus::kInvalidArgumentError); MediaPipeTasksStatus::kInvalidArgumentError);
} }
// Obtain file descriptor, offset and size. // Obtain file descriptor, offset and size.
@ -196,6 +209,11 @@ absl::Status ExternalFileHandler::MapExternalFile() {
absl::string_view ExternalFileHandler::GetFileContent() { absl::string_view ExternalFileHandler::GetFileContent() {
if (!external_file_.file_content().empty()) { if (!external_file_.file_content().empty()) {
return external_file_.file_content(); return external_file_.file_content();
} else if (external_file_.has_file_pointer_meta()) {
void* ptr =
reinterpret_cast<void*>(external_file_.file_pointer_meta().pointer());
return absl::string_view(static_cast<const char*>(ptr),
external_file_.file_pointer_meta().length());
} else { } else {
return absl::string_view(static_cast<const char*>(buffer_) + return absl::string_view(static_cast<const char*>(buffer_) +
buffer_offset_ - buffer_aligned_offset_, buffer_offset_ - buffer_aligned_offset_,

View File

@ -26,10 +26,11 @@ option java_outer_classname = "ExternalFileProto";
// (1) file contents loaded in `file_content`. // (1) file contents loaded in `file_content`.
// (2) file path in `file_name`. // (2) file path in `file_name`.
// (3) file descriptor through `file_descriptor_meta` as returned by open(2). // (3) file descriptor through `file_descriptor_meta` as returned by open(2).
// (4) file pointer and length in memory through `file_pointer_meta`.
// //
// If more than one field of these fields is provided, they are used in this // If more than one field of these fields is provided, they are used in this
// precedence order. // precedence order.
// Next id: 4 // Next id: 5
message ExternalFile { message ExternalFile {
// The file contents as a byte array. // The file contents as a byte array.
optional bytes file_content = 1; optional bytes file_content = 1;
@ -40,6 +41,13 @@ message ExternalFile {
// The file descriptor to a file opened with open(2), with optional additional // The file descriptor to a file opened with open(2), with optional additional
// offset and length information. // offset and length information.
optional FileDescriptorMeta file_descriptor_meta = 3; optional FileDescriptorMeta file_descriptor_meta = 3;
// The pointer points to location of a file in memory. Use the util method,
// `SetExternalFile` in [1], to configure `file_pointer_meta` from a
// `std::string_view` object.
//
// [1]: mediapipe/tasks/cc/metadata/utils/zip_utils.h
optional FilePointerMeta file_pointer_meta = 4;
} }
// A proto defining file descriptor metadata for mapping file into memory using // A proto defining file descriptor metadata for mapping file into memory using
@ -62,3 +70,14 @@ message FileDescriptorMeta {
// offset of a given asset obtained from AssetFileDescriptor#getStartOffset(). // offset of a given asset obtained from AssetFileDescriptor#getStartOffset().
optional int64 offset = 3; optional int64 offset = 3;
} }
// The pointer points to location of a file in memory. Make sure the file memory
// that it points locates on the same machine and it outlives this
// FilePointerMeta object.
message FilePointerMeta {
// Memory address of the file in decimal.
optional uint64 pointer = 1;
// File length.
optional int64 length = 2;
}

View File

@ -19,8 +19,9 @@ cc_library(
deps = [ deps = [
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/metadata/utils:zip_readonly_mem_file", "//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
@ -29,7 +30,6 @@ cc_library(
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@flatbuffers//:runtime_cc", "@flatbuffers//:runtime_cc",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs", "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
"@zlib//:zlib_minizip",
], ],
) )

View File

@ -17,16 +17,16 @@ limitations under the License.
#include <string> #include <string>
#include "absl/cleanup/cleanup.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "contrib/minizip/ioapi.h"
#include "contrib/minizip/unzip.h"
#include "flatbuffers/flatbuffers.h" #include "flatbuffers/flatbuffers.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
@ -53,72 +53,6 @@ const T* GetItemFromVector(
} }
return src_vector->Get(index); return src_vector->Get(index);
} }
// Wrapper function around calls to unzip to avoid repeating conversion logic
// from error code to Status.
absl::Status UnzipErrorToStatus(int error) {
if (error != UNZ_OK) {
return CreateStatusWithPayload(
StatusCode::kUnknown, "Unable to read associated file in zip archive.",
MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
}
return absl::OkStatus();
}
// Stores a file name, position in zip buffer and size.
struct ZipFileInfo {
std::string name;
ZPOS64_T position;
ZPOS64_T size;
};
// Returns the ZipFileInfo corresponding to the current file in the provided
// unzFile object.
absl::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
// Open file in raw mode, as data is expected to be uncompressed.
int method;
MP_RETURN_IF_ERROR(UnzipErrorToStatus(
unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
if (method != Z_NO_COMPRESSION) {
return CreateStatusWithPayload(
StatusCode::kUnknown, "Expected uncompressed zip archive.",
MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
}
// Get file info a first time to get filename size.
unz_file_info64 file_info;
MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
/*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
/*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
// Second call to get file name.
auto file_name_size = file_info.size_filename;
char* c_file_name = (char*)malloc(file_name_size);
MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
zf, &file_info, c_file_name, file_name_size,
/*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
/*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
std::string file_name = std::string(c_file_name, file_name_size);
free(c_file_name);
// Get position in file.
auto position = unzGetCurrentFileZStreamPos64(zf);
if (position == 0) {
return CreateStatusWithPayload(
StatusCode::kUnknown, "Unable to read file in zip archive.",
MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
}
// Close file and return.
MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
ZipFileInfo result{};
result.name = file_name;
result.position = position;
result.size = file_info.uncompressed_size;
return result;
}
} // namespace } // namespace
/* static */ /* static */
@ -238,47 +172,15 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
absl::Status ModelMetadataExtractor::ExtractAssociatedFiles( absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
const char* buffer_data, size_t buffer_size) { const char* buffer_data, size_t buffer_size) {
// Create in-memory read-only zip file. auto status =
ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size); ExtractFilesfromZipFile(buffer_data, buffer_size, &associated_files_);
// Open zip. if (!status.ok() &&
unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def()); absl::StrContains(status.message(), "Unable to open zip archive.")) {
if (zf == nullptr) {
// It's OK if it fails: this means there are no associated files with this // It's OK if it fails: this means there are no associated files with this
// model. // model.
return absl::OkStatus(); return absl::OkStatus();
} }
// Get number of files. return status;
unz_global_info global_info;
if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
return CreateStatusWithPayload(
StatusCode::kUnknown, "Unable to get zip archive info.",
MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
}
// Browse through files in archive.
if (global_info.number_entry > 0) {
int error = unzGoToFirstFile(zf);
while (error == UNZ_OK) {
ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
// Store result in map.
associated_files_[zip_file_info.name] = absl::string_view(
buffer_data + zip_file_info.position, zip_file_info.size);
error = unzGoToNextFile(zf);
}
if (error != UNZ_END_OF_LIST_OF_FILE) {
return CreateStatusWithPayload(
StatusCode::kUnknown,
"Unable to read associated file in zip archive.",
MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
}
}
// Close zip.
if (unzClose(zf) != UNZ_OK) {
return CreateStatusWithPayload(
StatusCode::kUnknown, "Unable to close zip archive.",
MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
}
return absl::OkStatus();
} }
absl::StatusOr<absl::string_view> ModelMetadataExtractor::GetAssociatedFile( absl::StatusOr<absl::string_view> ModelMetadataExtractor::GetAssociatedFile(

View File

@ -24,3 +24,20 @@ cc_library(
"@zlib//:zlib_minizip", "@zlib//:zlib_minizip",
], ],
) )
cc_library(
name = "zip_utils",
srcs = ["zip_utils.cc"],
hdrs = ["zip_utils.h"],
deps = [
":zip_readonly_mem_file",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@zlib//:zlib_minizip",
],
)

View File

@ -0,0 +1,175 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "contrib/minizip/ioapi.h"
#include "contrib/minizip/unzip.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h"
namespace mediapipe {
namespace tasks {
namespace metadata {
namespace {
using ::absl::StatusCode;
// Wrapper function around calls to unzip to avoid repeating conversion logic
// from error code to Status.
absl::Status UnzipErrorToStatus(int error) {
if (error != UNZ_OK) {
return CreateStatusWithPayload(StatusCode::kUnknown,
"Unable to read the file in zip archive.",
MediaPipeTasksStatus::kFileZipError);
}
return absl::OkStatus();
}
// Stores a file name, position in zip buffer and size.
struct ZipFileInfo {
std::string name;
ZPOS64_T position;
ZPOS64_T size;
};
// Returns the ZipFileInfo corresponding to the current file in the provided
// unzFile object.
absl::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
// Open file in raw mode, as data is expected to be uncompressed.
int method;
MP_RETURN_IF_ERROR(UnzipErrorToStatus(
unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
absl::Cleanup unzipper_closer = [zf]() {
auto status = UnzipErrorToStatus(unzCloseCurrentFile(zf));
if (!status.ok()) {
LOG(ERROR) << "Failed to close the current zip file: " << status;
}
};
if (method != Z_NO_COMPRESSION) {
return CreateStatusWithPayload(StatusCode::kUnknown,
"Expected uncompressed zip archive.",
MediaPipeTasksStatus::kFileZipError);
}
// Get file info a first time to get filename size.
unz_file_info64 file_info;
MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
/*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
/*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
// Second call to get file name.
auto file_name_size = file_info.size_filename;
char* c_file_name = (char*)malloc(file_name_size);
MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
zf, &file_info, c_file_name, file_name_size,
/*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
/*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
std::string file_name = std::string(c_file_name, file_name_size);
free(c_file_name);
// Get position in file.
auto position = unzGetCurrentFileZStreamPos64(zf);
if (position == 0) {
return CreateStatusWithPayload(StatusCode::kUnknown,
"Unable to read file in zip archive.",
MediaPipeTasksStatus::kFileZipError);
}
// Perform the cleanup manually for error propagation.
std::move(unzipper_closer).Cancel();
// Close file and return.
MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
ZipFileInfo result{};
result.name = file_name;
result.position = position;
result.size = file_info.uncompressed_size;
return result;
}
} // namespace
absl::Status ExtractFilesfromZipFile(
const char* buffer_data, const size_t buffer_size,
absl::flat_hash_map<std::string, absl::string_view>* files) {
// Create in-memory read-only zip file.
ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
// Open zip.
unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
if (zf == nullptr) {
return CreateStatusWithPayload(StatusCode::kUnknown,
"Unable to open zip archive.",
MediaPipeTasksStatus::kFileZipError);
}
absl::Cleanup unzipper_closer = [zf]() {
if (unzClose(zf) != UNZ_OK) {
LOG(ERROR) << "Unable to close zip archive.";
}
};
// Get number of files.
unz_global_info global_info;
if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
return CreateStatusWithPayload(StatusCode::kUnknown,
"Unable to get zip archive info.",
MediaPipeTasksStatus::kFileZipError);
}
// Browse through files in archive.
if (global_info.number_entry > 0) {
int error = unzGoToFirstFile(zf);
while (error == UNZ_OK) {
ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
// Store result in map.
(*files)[zip_file_info.name] = absl::string_view(
buffer_data + zip_file_info.position, zip_file_info.size);
error = unzGoToNextFile(zf);
}
if (error != UNZ_END_OF_LIST_OF_FILE) {
return CreateStatusWithPayload(
StatusCode::kUnknown,
"Unable to read associated file in zip archive.",
MediaPipeTasksStatus::kFileZipError);
}
}
// Perform the cleanup manually for error propagation.
std::move(unzipper_closer).Cancel();
// Close zip.
if (unzClose(zf) != UNZ_OK) {
return CreateStatusWithPayload(StatusCode::kUnknown,
"Unable to close zip archive.",
MediaPipeTasksStatus::kFileZipError);
}
return absl::OkStatus();
}
void SetExternalFile(const std::string_view& file_content,
core::proto::ExternalFile* model_file) {
auto pointer = reinterpret_cast<uint64_t>(file_content.data());
model_file->mutable_file_pointer_meta()->set_pointer(pointer);
model_file->mutable_file_pointer_meta()->set_length(file_content.length());
}
} // namespace metadata
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,47 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_
#define MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_
#include <string>
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
namespace mediapipe {
namespace tasks {
namespace metadata {
// Extract files from the zip file.
// Input: Pointer and length of the zip file in memory.
// Outputs: A map with the filename as key and a pointer to the file contents
// as value. The file contents returned by this function are only guaranteed to
// stay valid while buffer_data is alive.
absl::Status ExtractFilesfromZipFile(
const char* buffer_data, const size_t buffer_size,
absl::flat_hash_map<std::string, absl::string_view>* files);
// Set file_pointer_meta in ExternalFile which is the pointer points to location
// of a file in memory by file_content.
void SetExternalFile(const std::string_view& file_content,
core::proto::ExternalFile* model_file);
} // namespace metadata
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_

View File

@ -44,8 +44,12 @@ cc_library(
name = "hand_gesture_recognizer_graph", name = "hand_gesture_recognizer_graph",
srcs = ["hand_gesture_recognizer_graph.cc"], srcs = ["hand_gesture_recognizer_graph.cc"],
deps = [ deps = [
"//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:concatenate_vector_calculator", "//mediapipe/calculators/core:concatenate_vector_calculator",
"//mediapipe/calculators/core:end_loop_calculator",
"//mediapipe/calculators/core:get_vector_item_calculator",
"//mediapipe/calculators/tensor:tensor_converter_calculator", "//mediapipe/calculators/tensor:tensor_converter_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
@ -55,7 +59,6 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
@ -67,10 +70,81 @@ cc_library(
"//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/utils:image_tensor_specs", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "gesture_recognizer_graph",
srcs = ["gesture_recognizer_graph.cc"],
deps = [
":hand_gesture_recognizer_graph",
"//mediapipe/calculators/core:vector_indices_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(
name = "gesture_recognizer",
srcs = ["gesture_recognizer.cc"],
hdrs = ["gesture_recognizer.h"],
deps = [
":gesture_recognizer_graph",
":hand_gesture_recognizer_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)

View File

@ -0,0 +1,282 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h"
#include <memory>
#include <type_traits>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
namespace {
using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision::
gesture_recognizer::proto::GestureRecognizerGraphOptions;
using ::mediapipe::tasks::components::containers::GestureRecognitionResult;
constexpr char kHandGestureSubgraphTypeName[] =
"mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandGesturesStreamName[] = "hand_gestures";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kHandednessStreamName[] = "handedness";
constexpr char kHandLandmarksTag[] = "LANDMARKS";
constexpr char kHandLandmarksStreamName[] = "landmarks";
constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running
// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the
// number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<GestureRecognizerGraphOptionsProto> options,
bool enable_flow_limiting) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName);
subgraph.GetOptions<GestureRecognizerGraphOptionsProto>().Swap(options.get());
graph.In(kImageTag).SetName(kImageInStreamName);
subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >>
graph.Out(kHandGesturesTag);
subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>
graph.Out(kHandednessTag);
subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >>
graph.Out(kHandLandmarksTag);
subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >>
graph.Out(kHandWorldLandmarksTag);
subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
if (enable_flow_limiting) {
return tasks::core::AddFlowLimiterCalculator(graph, subgraph, {kImageTag},
kHandGesturesTag);
}
graph.In(kImageTag) >> subgraph.In(kImageTag);
return graph.GetConfig();
}
// Converts the user-facing GestureRecognizerOptions struct to the internal
// GestureRecognizerGraphOptions proto.
std::unique_ptr<GestureRecognizerGraphOptionsProto>
ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>();
bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE;
// TODO remove these workarounds for base options of subgraphs.
// Configure hand detector options.
auto base_options_proto_for_hand_detector =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_hand_detector)));
base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode);
auto* hand_detector_graph_options =
options_proto->mutable_hand_landmarker_graph_options()
->mutable_hand_detector_graph_options();
hand_detector_graph_options->mutable_base_options()->Swap(
base_options_proto_for_hand_detector.get());
hand_detector_graph_options->set_num_hands(options->num_hands);
hand_detector_graph_options->set_min_detection_confidence(
options->min_hand_detection_confidence);
// Configure hand landmark detector options.
auto base_options_proto_for_hand_landmarker =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_hand_landmarker)));
base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode);
auto* hand_landmarks_detector_graph_options =
options_proto->mutable_hand_landmarker_graph_options()
->mutable_hand_landmarks_detector_graph_options();
hand_landmarks_detector_graph_options->mutable_base_options()->Swap(
base_options_proto_for_hand_landmarker.get());
hand_landmarks_detector_graph_options->set_min_detection_confidence(
options->min_hand_presence_confidence);
auto* hand_landmarker_graph_options =
options_proto->mutable_hand_landmarker_graph_options();
hand_landmarker_graph_options->set_min_tracking_confidence(
options->min_tracking_confidence);
// Configure hand gesture recognizer options.
auto base_options_proto_for_gesture_recognizer =
std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(
&(options->base_options_for_gesture_recognizer)));
base_options_proto_for_gesture_recognizer->set_use_stream_mode(
use_stream_mode);
auto* hand_gesture_recognizer_graph_options =
options_proto->mutable_hand_gesture_recognizer_graph_options();
hand_gesture_recognizer_graph_options->mutable_base_options()->Swap(
base_options_proto_for_gesture_recognizer.get());
if (options->min_gesture_confidence >= 0) {
hand_gesture_recognizer_graph_options->mutable_classifier_options()
->set_score_threshold(options->min_gesture_confidence);
}
return options_proto;
}
} // namespace
absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
std::unique_ptr<GestureRecognizerOptions> options) {
auto options_proto = ConvertGestureRecognizerGraphOptionsProto(options.get());
tasks::core::PacketsCallback packets_callback = nullptr;
if (options->result_callback) {
auto result_callback = options->result_callback;
packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
status_or_packets) {
if (!status_or_packets.ok()) {
Image image;
result_callback(status_or_packets.status(), image,
Timestamp::Unset().Value());
return;
}
if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
return;
}
Packet gesture_packet =
status_or_packets.value()[kHandGesturesStreamName];
Packet handedness_packet =
status_or_packets.value()[kHandednessStreamName];
Packet hand_landmarks_packet =
status_or_packets.value()[kHandLandmarksStreamName];
Packet hand_world_landmarks_packet =
status_or_packets.value()[kHandWorldLandmarksStreamName];
Packet image_packet = status_or_packets.value()[kImageOutStreamName];
result_callback(
{{gesture_packet.Get<std::vector<ClassificationList>>(),
handedness_packet.Get<std::vector<ClassificationList>>(),
hand_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()}},
image_packet.Get<Image>(),
gesture_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
};
}
return core::VisionTaskApiFactory::Create<GestureRecognizer,
GestureRecognizerGraphOptionsProto>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback));
}
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
mediapipe::Image image) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"GPU input images are currently not supported.",
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(auto output_packets,
ProcessImageData({{kImageInStreamName,
MakePacket<Image>(std::move(image))}}));
return {
{/* gestures= */ {output_packets[kHandGesturesStreamName]
.Get<std::vector<ClassificationList>>()},
/* handedness= */
{output_packets[kHandednessStreamName]
.Get<std::vector<mediapipe::ClassificationList>>()},
/* hand_landmarks= */
{output_packets[kHandLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
/* hand_world_landmarks */
{output_packets[kHandWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>()}},
};
}
absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
mediapipe::Image image, int64 timestamp_ms) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
ASSIGN_OR_RETURN(
auto output_packets,
ProcessVideoData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
return {
{/* gestures= */ {output_packets[kHandGesturesStreamName]
.Get<std::vector<ClassificationList>>()},
/* handedness= */
{output_packets[kHandednessStreamName]
.Get<std::vector<mediapipe::ClassificationList>>()},
/* hand_landmarks= */
{output_packets[kHandLandmarksStreamName]
.Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
/* hand_world_landmarks */
{output_packets[kHandWorldLandmarksStreamName]
.Get<std::vector<mediapipe::LandmarkList>>()}},
};
}
absl::Status GestureRecognizer::RecognizeAsync(mediapipe::Image image,
int64 timestamp_ms) {
if (image.UsesGpu()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrCat("GPU input images are currently not supported."),
MediaPipeTasksStatus::kRunnerUnexpectedInputError);
}
return SendLiveStreamData(
{{kImageInStreamName,
MakePacket<Image>(std::move(image))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,172 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_
#include <memory>
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
struct GestureRecognizerOptions {
// Base options for configuring Task library, such as specifying the TfLite
// model file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// TODO: remove these. Temporary solutions before bundle asset is
// ready.
tasks::core::BaseOptions base_options_for_hand_landmarker;
tasks::core::BaseOptions base_options_for_hand_detector;
tasks::core::BaseOptions base_options_for_gesture_recognizer;
// The running mode of the task. Default to the image mode.
// GestureRecognizer has three running modes:
// 1) The image mode for recognizing hand gestures on single image inputs.
// 2) The video mode for recognizing hand gestures on the decoded frames of a
// video.
// 3) The live stream mode for recognizing hand gestures on the live stream of
// input data, such as from camera. In this mode, the "result_callback"
// below must be specified to receive the detection results asynchronously.
core::RunningMode running_mode = core::RunningMode::IMAGE;
// The maximum number of hands can be detected by the GestureRecognizer.
int num_hands = 1;
// The minimum confidence score for the hand detection to be considered
// successfully.
float min_hand_detection_confidence = 0.5;
// The minimum confidence score of hand presence score in the hand landmark
// detection.
float min_hand_presence_confidence = 0.5;
// The minimum confidence score for the hand tracking to be considered
// successfully.
float min_tracking_confidence = 0.5;
// The minimum confidence score for the gestures to be considered
// successfully. If < 0, the gesture confidence thresholds in the model
// metadata are used.
// TODO Note this option is subject to change, after scoring
// merging calculator is implemented.
float min_gesture_confidence = -1;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(
absl::StatusOr<components::containers::GestureRecognitionResult>,
const Image&, int64)>
result_callback = nullptr;
};
// Performs hand gesture recognition on the given image.
//
// TODO add the link to DevSite.
// This API expects expects a pre-trained hand gesture model asset bundle, or a
// custom one created using Model Maker. See <link to the DevSite documentation
// page>.
//
// Inputs:
// Image
// - The image that gesture recognition runs on.
// Outputs:
// GestureRecognitionResult
// - The hand gesture recognition results.
class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
public:
using BaseVisionTaskApi::BaseVisionTaskApi;
// Creates a GestureRecognizer from a GestureRecognizerhOptions to process
// image data or streaming data. Gesture recognizer can be created with one of
// the following three running modes:
// 1) Image mode for recognizing gestures on single image inputs.
// Users provide mediapipe::Image to the `Recognize` method, and will
// receive the recognized hand gesture results as the return value.
// 2) Video mode for recognizing gestures on the decoded frames of a video.
// 3) Live stream mode for recognizing gestures on the live stream of the
// input data, such as from camera. Users call `RecognizeAsync` to push the
// image data into the GestureRecognizer, the recognized results along with
// the input timestamp and the image that gesture recognizer runs on will
// be available in the result callback when the gesture recognizer finishes
// the work.
static absl::StatusOr<std::unique_ptr<GestureRecognizer>> Create(
std::unique_ptr<GestureRecognizerOptions> options);
// Performs hand gesture recognition on the given image.
// Only use this method when the GestureRecognizer is created with the image
// running mode.
//
// image - mediapipe::Image
// Image to perform hand gesture recognition on.
//
// The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented.
absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
Image image);
// Performs gesture recognition on the provided video frame.
// Only use this method when the GestureRecognizer is created with the video
// running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
absl::StatusOr<components::containers::GestureRecognitionResult>
RecognizeForVideo(Image image, int64 timestamp_ms);
// Sends live image data to perform gesture recognition, and the results will
// be available via the "result_callback" provided in the
// GestureRecognizerOptions. Only use this method when the GestureRecognizer
// is created with the live stream running mode.
//
// The image can be of any size with format RGB or RGBA. It's required to
// provide a timestamp (in milliseconds) to indicate when the input image is
// sent to the gesture recognizer. The input timestamps must be monotonically
// increasing.
//
// The "result_callback" provides
// - A vector of GestureRecognitionResult, each is the recognized results
// for a input frame.
// - The const reference to the corresponding input image that the gesture
// recognizer runs on. Note that the const reference to the image will no
// longer be valid when the callback returns. To access the image data
// outside of the callback, callers need to make a copy of the image.
// - The input timestamp in milliseconds.
absl::Status RecognizeAsync(Image image, int64 timestamp_ms);
// Shuts down the GestureRecognizer when all works are done.
absl::Status Close() { return runner_->Close(); }
};
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_

View File

@ -0,0 +1,211 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <type_traits>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
GestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarkerGraphOptions;
constexpr char kImageTag[] = "IMAGE";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
struct GestureRecognizerOutputs {
Source<std::vector<ClassificationList>> gesture;
Source<std::vector<ClassificationList>> handedness;
Source<std::vector<NormalizedLandmarkList>> hand_landmarks;
Source<std::vector<LandmarkList>> hand_world_landmarks;
Source<Image> image;
};
} // namespace
// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs
// hand gesture recognition.
//
// Inputs:
// IMAGE - Image
// Image to perform hand gesture recognition on.
//
// Outputs:
// HAND_GESTURES - std::vector<ClassificationList>
// Recognized hand gestures with sorted order such that the winning label is
// the first item in the list.
// LANDMARKS: - std::vector<NormalizedLandmarkList>
// Detected hand landmarks.
// WORLD_LANDMARKS - std::vector<LandmarkList>
// Detected hand landmarks in world coordinates.
// HAND_RECT_NEXT_FRAME - std::vector<NormalizedRect>
// The predicted Rect enclosing the hand RoI for landmark detection on the
// next frame.
// HANDEDNESS - std::vector<ClassificationList>
// Classification of handedness.
// IMAGE - mediapipe::Image
// The image that gesture recognizer runs on and has the pixel data stored
// on the target storage (CPU vs GPU).
//
//
// Example:
// node {
// calculator:
// "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
// input_stream: "IMAGE:image_in"
// output_stream: "HAND_GESTURES:hand_gestures"
// output_stream: "LANDMARKS:hand_landmarks"
// output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
// output_stream: "HANDEDNESS:handedness"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.gesture_recognizer.proto.GestureRecognizerGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "hand_gesture.tflite"
// }
// }
// hand_landmark_detector_options {
// base_options {
// model_asset {
// file_name: "hand_landmark.tflite"
// }
// }
// }
// }
// }
// }
class GestureRecognizerGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
BuildGestureRecognizerGraph(
*sc->MutableOptions<GestureRecognizerGraphOptions>(),
graph[Input<Image>(kImageTag)], graph));
hand_gesture_recognition_output.gesture >>
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
hand_gesture_recognition_output.handedness >>
graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
hand_gesture_recognition_output.hand_landmarks >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_gesture_recognition_output.hand_world_landmarks >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)];
return graph.GetConfig();
}
private:
absl::StatusOr<GestureRecognizerOutputs> BuildGestureRecognizerGraph(
GestureRecognizerGraphOptions& graph_options, Source<Image> image_in,
Graph& graph) {
auto& image_property = graph.AddNode("ImagePropertiesCalculator");
image_in >> image_property.In("IMAGE");
auto image_size = image_property.Out("SIZE");
// Hand landmarker graph.
auto& hand_landmarker_graph = graph.AddNode(
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
auto& hand_landmarker_graph_options =
hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
hand_landmarker_graph_options.Swap(
graph_options.mutable_hand_landmarker_graph_options());
image_in >> hand_landmarker_graph.In(kImageTag);
auto hand_landmarks =
hand_landmarker_graph[Output<std::vector<NormalizedLandmarkList>>(
kLandmarksTag)];
auto hand_world_landmarks =
hand_landmarker_graph[Output<std::vector<LandmarkList>>(
kWorldLandmarksTag)];
auto handedness =
hand_landmarker_graph[Output<std::vector<ClassificationList>>(
kHandednessTag)];
auto& vector_indices =
graph.AddNode("NormalizedLandmarkListVectorIndicesCalculator");
hand_landmarks >> vector_indices.In("VECTOR");
auto hand_landmarks_id = vector_indices.Out("INDICES");
// Hand gesture recognizer subgraph.
auto& hand_gesture_subgraph = graph.AddNode(
"mediapipe.tasks.vision.gesture_recognizer."
"MultipleHandGestureRecognizerGraph");
hand_gesture_subgraph.GetOptions<HandGestureRecognizerGraphOptions>().Swap(
graph_options.mutable_hand_gesture_recognizer_graph_options());
hand_landmarks >> hand_gesture_subgraph.In(kLandmarksTag);
hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag);
handedness >> hand_gesture_subgraph.In(kHandednessTag);
image_size >> hand_gesture_subgraph.In(kImageSizeTag);
hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag);
auto hand_gestures =
hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
kHandGesturesTag)];
return {{.gesture = hand_gestures,
.handedness = handedness,
.hand_landmarks = hand_landmarks,
.hand_world_landmarks = hand_world_landmarks,
.image = hand_landmarker_graph[Output<Image>(kImageTag)]}};
}
};
// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerGraph); // NOLINT
// clang-format on
} // namespace gesture_recognizer
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -27,7 +27,6 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
@ -36,7 +35,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe { namespace mediapipe {
@ -50,7 +48,8 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::components::processors::
ConfigureTensorsToClassificationCalculator;
using ::mediapipe::tasks::vision::gesture_recognizer::proto:: using ::mediapipe::tasks::vision::gesture_recognizer::proto::
HandGestureRecognizerGraphOptions; HandGestureRecognizerGraphOptions;
@ -95,15 +94,14 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
// The size of image from which the landmarks detected from. // The size of image from which the landmarks detected from.
// //
// Outputs: // Outputs:
// HAND_GESTURES - ClassificationResult // HAND_GESTURES - ClassificationList
// Recognized hand gestures with sorted order such that the winning label is // Recognized hand gestures with sorted order such that the winning label is
// the first item in the list. // the first item in the list.
// //
// //
// Example: // Example:
// node { // node {
// calculator: // calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerGraph"
// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph"
// input_stream: "HANDEDNESS:handedness" // input_stream: "HANDEDNESS:handedness"
// input_stream: "LANDMARKS:landmarks" // input_stream: "LANDMARKS:landmarks"
// input_stream: "WORLD_LANDMARKS:world_landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks"
@ -136,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
graph[Input<NormalizedLandmarkList>(kLandmarksTag)], graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
graph[Input<LandmarkList>(kWorldLandmarksTag)], graph[Input<LandmarkList>(kWorldLandmarksTag)],
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph)); graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
hand_gestures >> graph[Output<ClassificationResult>(kHandGesturesTag)]; hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
private: private:
absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph( absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
const HandGestureRecognizerGraphOptions& graph_options, const HandGestureRecognizerGraphOptions& graph_options,
const core::ModelResources& model_resources, const core::ModelResources& model_resources,
Source<ClassificationList> handedness, Source<ClassificationList> handedness,
@ -201,25 +199,24 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
auto concatenated_tensors = concatenate_tensor_vector.Out(""); auto concatenated_tensors = concatenate_tensor_vector.Out("");
// Inference for static hand gesture recognition. // Inference for static hand gesture recognition.
// TODO add embedding step.
auto& inference = AddInference( auto& inference = AddInference(
model_resources, graph_options.base_options().acceleration(), graph); model_resources, graph_options.base_options().acceleration(), graph);
concatenated_tensors >> inference.In(kTensorsTag); concatenated_tensors >> inference.In(kTensorsTag);
auto inference_output_tensors = inference.Out(kTensorsTag); auto inference_output_tensors = inference.Out(kTensorsTag);
auto& postprocessing = graph.AddNode( auto& tensors_to_classification =
"mediapipe.tasks.components.processors." graph.AddNode("TensorsToClassificationCalculator");
"ClassificationPostprocessingGraph"); MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
MP_RETURN_IF_ERROR( graph_options.classifier_options(),
components::processors::ConfigureClassificationPostprocessingGraph( *model_resources.GetMetadataExtractor(), 0,
model_resources, graph_options.classifier_options(), &tensors_to_classification.GetOptions<
&postprocessing mediapipe::TensorsToClassificationCalculatorOptions>()));
.GetOptions<components::processors::proto:: inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
ClassificationPostprocessingGraphOptions>())); auto classification_list =
inference_output_tensors >> postprocessing.In(kTensorsTag); tensors_to_classification[Output<ClassificationList>(
auto classification_result = "CLASSIFICATIONS")];
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")]; return classification_list;
return classification_result;
} }
}; };
@ -247,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH(
// index corresponding to the same hand if the graph runs multiple times. // index corresponding to the same hand if the graph runs multiple times.
// //
// Outputs: // Outputs:
// HAND_GESTURES - std::vector<ClassificationResult> // HAND_GESTURES - std::vector<ClassificationList>
// A vector of recognized hand gestures. Each vector element is the // A vector of recognized hand gestures. Each vector element is the
// ClassificationResult of the hand in input vector. // ClassificationList of the hand in input vector.
// //
// //
// Example: // Example:
@ -288,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph[Input<std::pair<int, int>>(kImageSizeTag)],
graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph)); graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
multi_hand_gestures >> multi_hand_gestures >>
graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)]; graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
private: private:
absl::StatusOr<Source<std::vector<ClassificationResult>>> absl::StatusOr<Source<std::vector<ClassificationList>>>
BuildMultiGestureRecognizerSubraph( BuildMultiGestureRecognizerSubraph(
const HandGestureRecognizerGraphOptions& graph_options, const HandGestureRecognizerGraphOptions& graph_options,
Source<std::vector<ClassificationList>> multi_handedness, Source<std::vector<ClassificationList>> multi_handedness,
@ -346,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag); image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag); auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
auto& end_loop_classification_results = auto& end_loop_classification_lists =
graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator"); graph.AddNode("EndLoopClassificationListCalculator");
batch_end >> end_loop_classification_results.In(kBatchEndTag); batch_end >> end_loop_classification_lists.In(kBatchEndTag);
hand_gestures >> end_loop_classification_results.In(kItemTag); hand_gestures >> end_loop_classification_lists.In(kItemTag);
auto multi_hand_gestures = end_loop_classification_results auto multi_hand_gestures =
[Output<std::vector<ClassificationResult>>(kIterableTag)]; end_loop_classification_lists[Output<std::vector<ClassificationList>>(
kIterableTag)];
return multi_hand_gestures; return multi_hand_gestures;
} }

View File

@ -31,8 +31,8 @@ mediapipe_proto_library(
) )
mediapipe_proto_library( mediapipe_proto_library(
name = "hand_gesture_recognizer_graph_options_proto", name = "gesture_classifier_graph_options_proto",
srcs = ["hand_gesture_recognizer_graph_options.proto"], srcs = ["gesture_classifier_graph_options.proto"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
@ -40,3 +40,28 @@ mediapipe_proto_library(
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )
mediapipe_proto_library(
name = "hand_gesture_recognizer_graph_options_proto",
srcs = ["hand_gesture_recognizer_graph_options.proto"],
deps = [
":gesture_classifier_graph_options_proto",
":gesture_embedder_graph_options_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
mediapipe_proto_library(
name = "gesture_recognizer_graph_options_proto",
srcs = ["gesture_recognizer_graph_options.proto"],
deps = [
":hand_gesture_recognizer_graph_options_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_proto",
],
)

View File

@ -0,0 +1,33 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message GestureClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureClassifierGraphOptions ext = 478825465;
}
// Base options for configuring hand gesture recognition subgraph, such as
// specifying the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
optional components.processors.proto.ClassifierOptions classifier_options = 2;
}

View File

@ -0,0 +1,43 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer";
option java_outer_classname = "GestureRecognizerGraphOptionsProto";
message GestureRecognizerGraphOptions {
extend mediapipe.CalculatorOptions {
optional GestureRecognizerGraphOptions ext = 479097054;
}
// Base options for configuring gesture recognizer graph, such as specifying
// the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for configuring hand landmarker graph.
optional hand_landmarker.proto.HandLandmarkerGraphOptions
hand_landmarker_graph_options = 2;
// Options for configuring hand gesture recognizer graph.
optional HandGestureRecognizerGraphOptions
hand_gesture_recognizer_graph_options = 3;
}

View File

@ -20,6 +20,8 @@ package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
message HandGestureRecognizerGraphOptions { message HandGestureRecognizerGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
@ -29,11 +31,18 @@ message HandGestureRecognizerGraphOptions {
// specifying the TfLite model file with metadata, accelerator options, etc. // specifying the TfLite model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1; optional core.proto.BaseOptions base_options = 1;
// Options for configuring the gesture classifier behavior, such as score // Options for GestureEmbedder.
// threshold, number of results, etc. optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2;
optional components.processors.proto.ClassifierOptions classifier_options = 2;
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // Options for GestureClassifier of default gestures.
// considered tracked successfully optional GestureClassifierGraphOptions
optional float min_tracking_confidence = 3 [default = 0.0]; canned_gesture_classifier_graph_options = 3;
// Options for GestureClassifier of custom gestures.
optional GestureClassifierGraphOptions
custom_gesture_classifier_graph_options = 4;
// TODO: remove these. Temporary solutions before bundle asset is
// ready.
optional components.processors.proto.ClassifierOptions classifier_options = 5;
} }

View File

@ -80,6 +80,7 @@ cc_library(
"//mediapipe/calculators/core:gate_calculator_cc_proto", "//mediapipe/calculators/core:gate_calculator_cc_proto",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator", "//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator",
"//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
@ -98,6 +99,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
], ],

View File

@ -15,7 +15,6 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [ package(default_visibility = [
"//mediapipe/app/xeno:__subpackages__",
"//mediapipe/tasks:internal", "//mediapipe/tasks:internal",
]) ])
@ -46,4 +45,26 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
# TODO: Enable this test cc_library(
name = "hand_landmarks_deduplication_calculator",
srcs = ["hand_landmarks_deduplication_calculator.cc"],
hdrs = ["hand_landmarks_deduplication_calculator.h"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components/containers:landmarks_detection",
"//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder",
"//mediapipe/tasks/cc/vision/utils:landmarks_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:optional",
],
alwayslink = 1,
)

View File

@ -0,0 +1,310 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"
namespace mediapipe::api2 {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::Bound;
using ::mediapipe::tasks::vision::utils::CalculateIOU;
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
float Distance(const NormalizedLandmark& lm_a, const NormalizedLandmark& lm_b,
int width, int height) {
return std::sqrt(std::pow((lm_a.x() - lm_b.x()) * width, 2) +
std::pow((lm_a.y() - lm_b.y()) * height, 2));
}
absl::StatusOr<std::vector<float>> Distances(const NormalizedLandmarkList& a,
const NormalizedLandmarkList& b,
int width, int height) {
const int num = a.landmark_size();
RET_CHECK_EQ(b.landmark_size(), num);
std::vector<float> distances;
distances.reserve(num);
for (int i = 0; i < num; ++i) {
const NormalizedLandmark& lm_a = a.landmark(i);
const NormalizedLandmark& lm_b = b.landmark(i);
distances.push_back(Distance(lm_a, lm_b, width, height));
}
return distances;
}
// Calculates a baseline distance of a hand that can be used as a relative
// measure when calculating hand to hand similarity.
//
// Calculated as maximum of distances: 0->5, 5->17, 17->0, where 0, 5, 17 key
// points are depicted below:
//
// /Middle/
// |
// /Index/ | /Ring/
// | | | /Pinky/
// V V V |
// V
// [8] [12] [16]
// | | | [20]
// | | | |
// /Thumb/ | | | |
// | [7] [11] [15] [19]
// V | | | |
// | | | |
// [4] | | | |
// | [6] [10] [14] [18]
// | | | | |
// | | | | |
// [3] | | | |
// | [5]----[9]---[13]---[17]
// . | |
// \ . |
// \ / |
// [2] |
// \ |
// \ |
// \ |
// [1] .
// \ /
// \ /
// ._____[0]_____.
//
// ^
// |
// /Wrist/
absl::StatusOr<float> HandBaselineDistance(
const NormalizedLandmarkList& landmarks, int width, int height) {
RET_CHECK_EQ(landmarks.landmark_size(), 21); // Num of hand landmarks.
constexpr int kWrist = 0;
constexpr int kIndexFingerMcp = 5;
constexpr int kPinkyMcp = 17;
float distance = Distance(landmarks.landmark(kWrist),
landmarks.landmark(kIndexFingerMcp), width, height);
distance = std::max(distance,
Distance(landmarks.landmark(kIndexFingerMcp),
landmarks.landmark(kPinkyMcp), width, height));
distance =
std::max(distance, Distance(landmarks.landmark(kPinkyMcp),
landmarks.landmark(kWrist), width, height));
return distance;
}
Bound CalculateBound(const NormalizedLandmarkList& list) {
constexpr float kMinInitialValue = std::numeric_limits<float>::max();
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
// Compute min and max values on landmarks (they will form
// bounding box)
float bounding_box_left = kMinInitialValue;
float bounding_box_top = kMinInitialValue;
float bounding_box_right = kMaxInitialValue;
float bounding_box_bottom = kMaxInitialValue;
for (const auto& landmark : list.landmark()) {
bounding_box_left = std::min(bounding_box_left, landmark.x());
bounding_box_top = std::min(bounding_box_top, landmark.y());
bounding_box_right = std::max(bounding_box_right, landmark.x());
bounding_box_bottom = std::max(bounding_box_bottom, landmark.y());
}
// Populate normalized non rotated face bounding box
return {.left = bounding_box_left,
.top = bounding_box_top,
.right = bounding_box_right,
.bottom = bounding_box_bottom};
}
// Uses IoU and distance of some corresponding hand landmarks to detect
// duplicate / similar hands. IoU, distance thresholds, number of landmarks to
// match are found experimentally. Evaluated:
// - manually comparing side by side, before and after deduplication applied
// - generating gesture dataset, and checking select frames in baseline and
// "deduplicated" dataset
// - by confirming gesture training is better with use of deduplication using
// selected thresholds
class HandDuplicatesFinder : public DuplicatesFinder {
public:
explicit HandDuplicatesFinder(bool start_from_the_end)
: start_from_the_end_(start_from_the_end) {}
absl::StatusOr<absl::flat_hash_set<int>> FindDuplicates(
const std::vector<NormalizedLandmarkList>& multi_landmarks,
int input_width, int input_height) override {
absl::flat_hash_set<int> retained_indices;
absl::flat_hash_set<int> suppressed_indices;
const int num = multi_landmarks.size();
std::vector<float> baseline_distances;
baseline_distances.reserve(num);
std::vector<Bound> bounds;
bounds.reserve(num);
for (const NormalizedLandmarkList& list : multi_landmarks) {
ASSIGN_OR_RETURN(const float baseline_distance,
HandBaselineDistance(list, input_width, input_height));
baseline_distances.push_back(baseline_distance);
bounds.push_back(CalculateBound(list));
}
for (int index = 0; index < num; ++index) {
const int i = start_from_the_end_ ? num - index - 1 : index;
const float stable_distance_i = baseline_distances[i];
bool suppressed = false;
for (int j : retained_indices) {
const float stable_distance_j = baseline_distances[j];
constexpr float kAllowedBaselineDistanceRatio = 0.2f;
const float distance_threshold =
std::max(stable_distance_i, stable_distance_j) *
kAllowedBaselineDistanceRatio;
ASSIGN_OR_RETURN(const std::vector<float> distances,
Distances(multi_landmarks[i], multi_landmarks[j],
input_width, input_height));
const int num_matched_landmarks = absl::c_count_if(
distances,
[&](float distance) { return distance < distance_threshold; });
const float iou = CalculateIOU(bounds[i], bounds[j]);
constexpr int kNumMatchedLandmarksToSuppressHand = 10; // out of 21
constexpr float kMinIouThresholdToSuppressHand = 0.2f;
if (num_matched_landmarks >= kNumMatchedLandmarksToSuppressHand &&
iou > kMinIouThresholdToSuppressHand) {
suppressed = true;
break;
}
}
if (suppressed) {
suppressed_indices.insert(i);
} else {
retained_indices.insert(i);
}
}
return suppressed_indices;
}
private:
const bool start_from_the_end_;
};
template <typename InputPortT>
absl::StatusOr<absl::optional<typename InputPortT::PayloadT>>
VerifyNumAndMaybeInitOutput(const InputPortT& port, CalculatorContext* cc,
int num_expected_size) {
absl::optional<typename InputPortT::PayloadT> output;
if (port(cc).IsConnected() && !port(cc).IsEmpty()) {
RET_CHECK_EQ(port(cc).Get().size(), num_expected_size);
typename InputPortT::PayloadT result;
return {{result}};
}
return {absl::nullopt};
}
} // namespace
std::unique_ptr<DuplicatesFinder> CreateHandDuplicatesFinder(
bool start_from_the_end) {
return absl::make_unique<HandDuplicatesFinder>(start_from_the_end);
}
absl::Status HandLandmarksDeduplicationCalculator::Process(
mediapipe::CalculatorContext* cc) {
if (kInLandmarks(cc).IsEmpty()) return absl::OkStatus();
if (kInSize(cc).IsEmpty()) return absl::OkStatus();
const std::vector<NormalizedLandmarkList>& in_landmarks = *kInLandmarks(cc);
const std::pair<int, int>& image_size = *kInSize(cc);
std::unique_ptr<DuplicatesFinder> duplicates_finder =
CreateHandDuplicatesFinder(/*start_from_the_end=*/false);
ASSIGN_OR_RETURN(absl::flat_hash_set<int> indices_to_remove,
duplicates_finder->FindDuplicates(
in_landmarks, image_size.first, image_size.second));
if (indices_to_remove.empty()) {
kOutLandmarks(cc).Send(kInLandmarks(cc));
kOutRois(cc).Send(kInRois(cc));
kOutWorldLandmarks(cc).Send(kInWorldLandmarks(cc));
kOutClassifications(cc).Send(kInClassifications(cc));
} else {
std::vector<NormalizedLandmarkList> out_landmarks;
const int num = in_landmarks.size();
ASSIGN_OR_RETURN(absl::optional<std::vector<NormalizedRect>> out_rois,
VerifyNumAndMaybeInitOutput(kInRois, cc, num));
ASSIGN_OR_RETURN(
absl::optional<std::vector<LandmarkList>> out_world_landmarks,
VerifyNumAndMaybeInitOutput(kInWorldLandmarks, cc, num));
ASSIGN_OR_RETURN(
absl::optional<std::vector<ClassificationList>> out_classifications,
VerifyNumAndMaybeInitOutput(kInClassifications, cc, num));
for (int i = 0; i < num; ++i) {
if (indices_to_remove.find(i) != indices_to_remove.end()) continue;
out_landmarks.push_back(in_landmarks[i]);
if (out_rois) {
out_rois->push_back(kInRois(cc).Get()[i]);
}
if (out_world_landmarks) {
out_world_landmarks->push_back(kInWorldLandmarks(cc).Get()[i]);
}
if (out_classifications) {
out_classifications->push_back(kInClassifications(cc).Get()[i]);
}
}
if (!out_landmarks.empty()) {
kOutLandmarks(cc).Send(std::move(out_landmarks));
}
if (out_rois && !out_rois->empty()) {
kOutRois(cc).Send(std::move(out_rois.value()));
}
if (out_world_landmarks && !out_world_landmarks->empty()) {
kOutWorldLandmarks(cc).Send(std::move(out_world_landmarks.value()));
}
if (out_classifications && !out_classifications->empty()) {
kOutClassifications(cc).Send(std::move(out_classifications.value()));
}
}
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(HandLandmarksDeduplicationCalculator);
} // namespace mediapipe::api2

View File

@ -0,0 +1,97 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"
namespace mediapipe::api2 {
// Create a DuplicatesFinder dedicated for finding hand duplications.
std::unique_ptr<tasks::vision::utils::DuplicatesFinder>
CreateHandDuplicatesFinder(bool start_from_the_end = false);
// Filter duplicate hand landmarks by finding the overlapped hands.
// Inputs:
// MULTI_LANDMARKS - std::vector<NormalizedLandmarkList>
// The hand landmarks to be filtered.
// MULTI_ROIS - std::vector<NormalizedRect>
// The regions where each encloses the landmarks of a single hand.
// MULTI_WORLD_LANDMARKS - std::vector<LandmarkList>
// The hand landmarks to be filtered in world coordinates.
// MULTI_CLASSIFICATIONS - std::vector<ClassificationList>
// The handedness of hands.
// IMAGE_SIZE - std::pair<int, int>
// The size of the image which the hand landmarks are detected on.
//
// Outputs:
// MULTI_LANDMARKS - std::vector<NormalizedLandmarkList>
// The hand landmarks with duplication removed.
// MULTI_ROIS - std::vector<NormalizedRect>
// The regions where each encloses the landmarks of a single hand with
// duplicate hands removed.
// MULTI_WORLD_LANDMARKS - std::vector<LandmarkList>
// The hand landmarks with duplication removed in world coordinates.
// MULTI_CLASSIFICATIONS - std::vector<ClassificationList>
// The handedness of hands with duplicate hands removed.
//
// Example:
// node {
// calculator: "HandLandmarksDeduplicationCalculator"
// input_stream: "MULTI_LANDMARKS:landmarks_in"
// input_stream: "MULTI_ROIS:rois_in"
// input_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_in"
// input_stream: "MULTI_CLASSIFICATIONS:handedness_in"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "MULTI_LANDMARKS:landmarks_out"
// output_stream: "MULTI_ROIS:rois_out"
// output_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_out"
// output_stream: "MULTI_CLASSIFICATIONS:handedness_out"
// }
class HandLandmarksDeduplicationCalculator : public Node {
public:
constexpr static Input<std::vector<mediapipe::NormalizedLandmarkList>>
kInLandmarks{"MULTI_LANDMARKS"};
constexpr static Input<std::vector<mediapipe::NormalizedRect>>::Optional
kInRois{"MULTI_ROIS"};
constexpr static Input<std::vector<mediapipe::LandmarkList>>::Optional
kInWorldLandmarks{"MULTI_WORLD_LANDMARKS"};
constexpr static Input<std::vector<mediapipe::ClassificationList>>::Optional
kInClassifications{"MULTI_CLASSIFICATIONS"};
constexpr static Input<std::pair<int, int>> kInSize{"IMAGE_SIZE"};
constexpr static Output<std::vector<mediapipe::NormalizedLandmarkList>>
kOutLandmarks{"MULTI_LANDMARKS"};
constexpr static Output<std::vector<mediapipe::NormalizedRect>>::Optional
kOutRois{"MULTI_ROIS"};
constexpr static Output<std::vector<mediapipe::LandmarkList>>::Optional
kOutWorldLandmarks{"MULTI_WORLD_LANDMARKS"};
constexpr static Output<std::vector<mediapipe::ClassificationList>>::Optional
kOutClassifications{"MULTI_CLASSIFICATIONS"};
MEDIAPIPE_NODE_CONTRACT(kInLandmarks, kInRois, kInWorldLandmarks,
kInClassifications, kInSize, kOutLandmarks, kOutRois,
kOutWorldLandmarks, kOutClassifications);
absl::Status Process(mediapipe::CalculatorContext* cc) override;
};
} // namespace mediapipe::api2
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_

View File

@ -247,11 +247,37 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
image_in >> hand_landmarks_detector_graph.In("IMAGE"); image_in >> hand_landmarks_detector_graph.In("IMAGE");
clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT"); clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT");
auto landmarks = hand_landmarks_detector_graph.Out(kLandmarksTag);
auto world_landmarks =
hand_landmarks_detector_graph.Out(kWorldLandmarksTag);
auto hand_rects_for_next_frame = auto hand_rects_for_next_frame =
hand_landmarks_detector_graph[Output<std::vector<NormalizedRect>>( hand_landmarks_detector_graph.Out(kHandRectNextFrameTag);
kHandRectNextFrameTag)]; auto handedness = hand_landmarks_detector_graph.Out(kHandednessTag);
auto& image_property = graph.AddNode("ImagePropertiesCalculator");
image_in >> image_property.In("IMAGE");
auto image_size = image_property.Out("SIZE");
auto& deduplicate = graph.AddNode("HandLandmarksDeduplicationCalculator");
landmarks >> deduplicate.In("MULTI_LANDMARKS");
world_landmarks >> deduplicate.In("MULTI_WORLD_LANDMARKS");
hand_rects_for_next_frame >> deduplicate.In("MULTI_ROIS");
handedness >> deduplicate.In("MULTI_CLASSIFICATIONS");
image_size >> deduplicate.In("IMAGE_SIZE");
auto filtered_landmarks =
deduplicate[Output<std::vector<NormalizedLandmarkList>>(
"MULTI_LANDMARKS")];
auto filtered_world_landmarks =
deduplicate[Output<std::vector<LandmarkList>>("MULTI_WORLD_LANDMARKS")];
auto filtered_hand_rects_for_next_frame =
deduplicate[Output<std::vector<NormalizedRect>>("MULTI_ROIS")];
auto filtered_handedness =
deduplicate[Output<std::vector<ClassificationList>>(
"MULTI_CLASSIFICATIONS")];
// Back edge. // Back edge.
hand_rects_for_next_frame >> previous_loopback.In("LOOP"); filtered_hand_rects_for_next_frame >> previous_loopback.In("LOOP");
// TODO: Replace PassThroughCalculator with a calculator that // TODO: Replace PassThroughCalculator with a calculator that
// converts the pixel data to be stored on the target storage (CPU vs GPU). // converts the pixel data to be stored on the target storage (CPU vs GPU).
@ -259,14 +285,10 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
image_in >> pass_through.In(""); image_in >> pass_through.In("");
return {{ return {{
/* landmark_lists= */ hand_landmarks_detector_graph /* landmark_lists= */ filtered_landmarks,
[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)], /* world_landmark_lists= */ filtered_world_landmarks,
/* world_landmark_lists= */ /* hand_rects_next_frame= */ filtered_hand_rects_for_next_frame,
hand_landmarks_detector_graph[Output<std::vector<LandmarkList>>( /* handedness= */ filtered_handedness,
kWorldLandmarksTag)],
/* hand_rects_next_frame= */ hand_rects_for_next_frame,
hand_landmarks_detector_graph[Output<std::vector<ClassificationList>>(
kHandednessTag)],
/* palm_rects= */ /* palm_rects= */
hand_detector[Output<std::vector<NormalizedRect>>(kPalmRectsTag)], hand_detector[Output<std::vector<NormalizedRect>>(kPalmRectsTag)],
/* palm_detections */ /* palm_detections */

View File

@ -208,7 +208,7 @@ TEST_F(CreateTest, FailsWithMissingModel) {
EXPECT_THAT( EXPECT_THAT(
image_classifier.status().message(), image_classifier.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', " HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name' or 'file_descriptor_meta'.")); "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));

View File

@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto";
option java_outer_classname = "ImageClassifierGraphOptionsProto";
message ImageClassifierGraphOptions { message ImageClassifierGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ImageClassifierGraphOptions ext = 456383383; optional ImageClassifierGraphOptions ext = 456383383;

View File

@ -140,7 +140,7 @@ TEST_F(CreateTest, FailsWithMissingModel) {
EXPECT_THAT( EXPECT_THAT(
image_embedder.status().message(), image_embedder.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', " HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name' or 'file_descriptor_meta'.")); "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));

View File

@ -191,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
EXPECT_THAT( EXPECT_THAT(
segmenter_or.status().message(), segmenter_or.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', " HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name' or 'file_descriptor_meta'.")); "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));

View File

@ -208,7 +208,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
EXPECT_THAT( EXPECT_THAT(
object_detector.status().message(), object_detector.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', " HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name' or 'file_descriptor_meta'.")); "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(object_detector.status().GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(object_detector.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat( Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError)))); MediaPipeTasksStatus::kRunnerInitializationError))));

View File

@ -79,3 +79,30 @@ cc_library(
"@stblib//:stb_image", "@stblib//:stb_image",
], ],
) )
cc_library(
name = "landmarks_duplicates_finder",
hdrs = ["landmarks_duplicates_finder.h"],
deps = [
"//mediapipe/framework/formats:landmark_cc_proto",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status:statusor",
],
)
cc_library(
name = "landmarks_utils",
srcs = ["landmarks_utils.cc"],
hdrs = ["landmarks_utils.h"],
deps = ["//mediapipe/tasks/cc/components/containers:landmarks_detection"],
)
cc_test(
name = "landmarks_utils_test",
srcs = ["landmarks_utils_test.cc"],
deps = [
":landmarks_utils",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/components/containers:landmarks_detection",
],
)

View File

@ -0,0 +1,40 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe::tasks::vision::utils {
class DuplicatesFinder {
public:
virtual ~DuplicatesFinder() = default;
// Returns indices of landmark lists to remove to make @multi_landmarks
// contain different enough (depending on the implementation) landmark lists
// only.
virtual absl::StatusOr<absl::flat_hash_set<int>> FindDuplicates(
const std::vector<mediapipe::NormalizedLandmarkList>& multi_landmarks,
int input_width, int input_height) = 0;
};
} // namespace mediapipe::tasks::vision::utils
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_

View File

@ -0,0 +1,48 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"
#include <algorithm>
#include <vector>
namespace mediapipe::tasks::vision::utils {
using ::mediapipe::tasks::components::containers::Bound;
float CalculateArea(const Bound& bound) {
return (bound.right - bound.left) * (bound.bottom - bound.top);
}
float CalculateIntersectionArea(const Bound& a, const Bound& b) {
const float intersection_left = std::max<float>(a.left, b.left);
const float intersection_top = std::max<float>(a.top, b.top);
const float intersection_right = std::min<float>(a.right, b.right);
const float intersection_bottom = std::min<float>(a.bottom, b.bottom);
return std::max<float>(intersection_bottom - intersection_top, 0.0) *
std::max<float>(intersection_right - intersection_left, 0.0);
}
float CalculateIOU(const Bound& a, const Bound& b) {
const float area_a = CalculateArea(a);
const float area_b = CalculateArea(b);
if (area_a <= 0 || area_b <= 0) return 0.0;
const float intersection_area = CalculateIntersectionArea(a, b);
return intersection_area / (area_a + area_b - intersection_area);
}
} // namespace mediapipe::tasks::vision::utils

View File

@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
namespace mediapipe::tasks::vision::utils {
// Calculates intersection over union for two bounds.
float CalculateIOU(const components::containers::Bound& a,
const components::containers::Bound& b);
// Calculates area for face bound
float CalculateArea(const components::containers::Bound& bound);
// Calucates intersection area of two face bounds
float CalculateIntersectionArea(const components::containers::Bound& a,
const components::containers::Bound& b);
} // namespace mediapipe::tasks::vision::utils
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_

View File

@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe::tasks::vision::utils {
namespace {
TEST(LandmarkUtilsTest, CalculateIOU) {
// Do not intersect
EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 2, 3, 3}));
// No x intersection
EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 0, 3, 1}));
// No y intersection
EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {0, 2, 1, 3}));
// Full intersection
EXPECT_EQ(1, CalculateIOU({0, 0, 2, 2}, {0, 0, 2, 2}));
// Union is 4 intersection is 1
EXPECT_EQ(0.25, CalculateIOU({0, 0, 3, 1}, {2, 0, 4, 1}));
// Same in by y
EXPECT_EQ(0.25, CalculateIOU({0, 0, 1, 3}, {0, 2, 1, 4}));
}
} // namespace
} // namespace mediapipe::tasks::vision::utils

View File

@ -34,3 +34,32 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
android_library(
name = "classification_entry",
srcs = ["ClassificationEntry.java"],
deps = [
":category",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "classifications",
srcs = ["Classifications.java"],
deps = [
":classification_entry",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "landmark",
srcs = ["Landmark.java"],
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,48 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents a list of predicted categories with an optional timestamp. Typically used as result
* for classification tasks.
*/
@AutoValue
public abstract class ClassificationEntry {
/**
* Creates a {@link ClassificationEntry} instance from a list of {@link Category} and optional
* timestamp.
*
* @param categories the list of {@link Category} objects that contain category name, display
* name, score and label index.
* @param timestampMs the {@link long} representing the timestamp for which these categories were
* obtained.
*/
public static ClassificationEntry create(List<Category> categories, long timestampMs) {
return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
}
/** The list of predicted {@link Category} objects, sorted by descending score. */
public abstract List<Category> categories();
/**
* The timestamp (in milliseconds) associated to the classification entry. This is useful for time
* series use cases, e.g. audio classification.
*/
public abstract long timestampMs();
}

View File

@ -0,0 +1,52 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents the list of classification for a given classifier head. Typically used as a result for
* classification tasks.
*/
@AutoValue
public abstract class Classifications {
/**
* Creates a {@link Classifications} instance.
*
* @param entries the list of {@link ClassificationEntry} objects containing the predicted
* categories.
* @param headIndex the index of the classifier head.
* @param headName the name of the classifier head.
*/
public static Classifications create(
List<ClassificationEntry> entries, int headIndex, String headName) {
return new AutoValue_Classifications(
Collections.unmodifiableList(entries), headIndex, headName);
}
/** A list of {@link ClassificationEntry} objects. */
public abstract List<ClassificationEntry> entries();
/**
* The index of the classifier head these entries refer to. This is useful for multi-head models.
*/
public abstract int headIndex();
/** The name of the classifier head, which is the corresponding tensor metadata name. */
public abstract String headName();
}

View File

@ -0,0 +1,66 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Objects;
/**
* Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the
* landmark coordinates is normalized respect to the dimension of image, and the coordinates values
* are in the range of [0,1]. Otherwise, it represenet a point in world coordinates.
*/
@AutoValue
public abstract class Landmark {
private static final float TOLERANCE = 1e-6f;
public static Landmark create(float x, float y, float z, boolean normalized) {
return new AutoValue_Landmark(x, y, z, normalized);
}
// The x coordniates of the landmark.
public abstract float x();
// The y coordniates of the landmark.
public abstract float y();
// The z coordniates of the landmark.
public abstract float z();
// Whether this landmark is normalized with respect to the image size.
public abstract boolean normalized();
@Override
public final boolean equals(Object o) {
if (!(o instanceof Landmark)) {
return false;
}
Landmark other = (Landmark) o;
return other.normalized() == this.normalized()
&& Math.abs(other.x() - this.x()) < TOLERANCE
&& Math.abs(other.x() - this.y()) < TOLERANCE
&& Math.abs(other.x() - this.z()) < TOLERANCE;
}
@Override
public final int hashCode() {
return Objects.hash(x(), y(), z(), normalized());
}
@Override
public final String toString() {
return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + " normalized=" + normalized() + ")>";
}
}

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "classifieroptions",
srcs = ["ClassifierOptions.java"],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,118 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.processors;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/** Classifier options shared across MediaPipe Java classification tasks. */
@AutoValue
public abstract class ClassifierOptions {
/** Builder for {@link ClassifierOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets the optional locale to use for display names specified through the TFLite Model
* Metadata, if any.
*/
public abstract Builder setDisplayNamesLocale(String locale);
/**
* Sets the optional maximum number of top-scored classification results to return.
*
* <p>If not set, all available results are returned. If set, must be > 0.
*/
public abstract Builder setMaxResults(Integer maxResults);
/**
* Sets the optional score threshold. Results with score below this value are rejected.
*
* <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
*/
public abstract Builder setScoreThreshold(Float scoreThreshold);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> categoryDenylist);
abstract ClassifierOptions autoBuild();
/**
* Validates and builds the {@link ClassifierOptions} instance.
*
* @throws IllegalArgumentException if {@link maxResults} is set to a value <= 0.
*/
public final ClassifierOptions build() {
ClassifierOptions options = autoBuild();
if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
throw new IllegalArgumentException("If specified, maxResults must be > 0");
}
return options;
}
}
public abstract Optional<String> displayNamesLocale();
public abstract Optional<Integer> maxResults();
public abstract Optional<Float> scoreThreshold();
public abstract List<String> categoryAllowlist();
public abstract List<String> categoryDenylist();
public static Builder builder() {
return new AutoValue_ClassifierOptions.Builder()
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
/**
* Converts a {@link ClassifierOptions} object to a {@link
* ClassifierOptionsProto.ClassifierOptions} protobuf message.
*/
public ClassifierOptionsProto.ClassifierOptions convertToProto() {
ClassifierOptionsProto.ClassifierOptions.Builder builder =
ClassifierOptionsProto.ClassifierOptions.newBuilder();
displayNamesLocale().ifPresent(builder::setDisplayNamesLocale);
maxResults().ifPresent(builder::setMaxResults);
scoreThreshold().ifPresent(builder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
builder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
builder.addAllCategoryDenylist(categoryDenylist());
}
return builder.build();
}
}

View File

@ -19,8 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
android_library( android_library(
name = "core", name = "core",
srcs = glob(["*.java"]), srcs = glob(["*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [ deps = [
":libmediapipe_tasks_vision_jni_lib", ":libmediapipe_tasks_vision_jni_lib",
"//mediapipe/framework/formats:rect_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff", "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
@ -36,6 +40,7 @@ cc_binary(
deps = [ deps = [
"//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
], ],

View File

@ -14,101 +14,247 @@
package com.google.mediapipe.tasks.vision.core; package com.google.mediapipe.tasks.vision.core;
import android.graphics.RectF;
import com.google.mediapipe.formats.proto.RectProto.NormalizedRect;
import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult; import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner; import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Optional;
/** The base class of MediaPipe vision tasks. */ /** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable { public class BaseVisionTaskApi implements AutoCloseable {
private static final long MICROSECONDS_PER_MILLISECOND = 1000; private static final long MICROSECONDS_PER_MILLISECOND = 1000;
private final TaskRunner runner; private final TaskRunner runner;
private final RunningMode runningMode; private final RunningMode runningMode;
private final String imageStreamName;
private final Optional<String> normRectStreamName;
static { static {
System.loadLibrary("mediapipe_tasks_vision_jni"); System.loadLibrary("mediapipe_tasks_vision_jni");
ProtoUtil.registerTypeName(NormalizedRect.class, "mediapipe.NormalizedRect");
} }
/** /**
* Constructor to initialize an {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
* task {@link RunningMode}.
* *
* @param runner a {@link TaskRunner}. * @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}. * @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
*/ */
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) { public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
this.runner = runner; this.runner = runner;
this.runningMode = runningMode; this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.empty();
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
* input.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
* @param imageStreamName the name of the input image stream.
* @param normRectStreamName the name of the input normalized rect image stream.
*/
public BaseVisionTaskApi(
TaskRunner runner,
RunningMode runningMode,
String imageStreamName,
String normRectStreamName) {
this.runner = runner;
this.runningMode = runningMode;
this.imageStreamName = imageStreamName;
this.normRectStreamName = Optional.of(normRectStreamName);
} }
/** /**
* A synchronous method to process single image inputs. The call blocks the current thread until a * A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned. * failure status or a successful result is returned.
* *
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing. * @param image a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if the task is not in the image mode. * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
* input.
*/ */
protected TaskResult processImageData(String imageStreamName, Image image) { protected TaskResult processImageData(Image image) {
if (runningMode != RunningMode.IMAGE) { if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:" "Task is not initialized with the image mode. Current running mode:"
+ runningMode.name()); + runningMode.name());
} }
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets); return runner.process(inputPackets);
} }
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* @param image a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @throws MediaPipeException if the task is not in the image mode or doesn't require a normalized
* rect.
*/
protected TaskResult processImageData(Image image, RectF roi) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
return runner.process(inputPackets);
}
/** /**
* A synchronous method to process continuous video frames. The call blocks the current thread * A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned. * until a failure status or a successful result is returned.
* *
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing. * @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds. * @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode. * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
* input.
*/ */
protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) { protected TaskResult processVideoData(Image image, long timestampMs) {
if (runningMode != RunningMode.VIDEO) { if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:" "Task is not initialized with the video mode. Current running mode:"
+ runningMode.name()); + runningMode.name());
} }
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
} }
/**
* A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned.
*
* @param image a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
* rect.
*/
protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/** /**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener. * available in the user-defined result listener.
* *
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing. * @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds. * @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode. * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
* input.
*/ */
protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) { protected void sendLiveStreamData(Image image, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) { if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException( throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:" "Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name()); + runningMode.name());
} }
if (normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task expects a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>(); Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image)); inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND); runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
} }
/**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param image a MediaPipe {@link Image} object for processing.
* @param roi a {@link RectF} defining the region-of-interest to process in the image. Coordinates
* are expected to be specified as normalized values in [0,1].
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode or doesn't require a normalized
* rect.
*/
protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode:"
+ runningMode.name());
}
if (!normRectStreamName.isPresent()) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task doesn't expect a normalized rect as input.");
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
inputPackets.put(
normRectStreamName.get(),
runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/** Closes and cleans up the MediaPipe vision task. */ /** Closes and cleans up the MediaPipe vision task. */
@Override @Override
public void close() { public void close() {
runner.close(); runner.close();
} }
/** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
private static NormalizedRect convertToNormalizedRect(RectF rect) {
return NormalizedRect.newBuilder()
.setXCenter(rect.centerX())
.setYCenter(rect.centerY())
.setWidth(rect.width())
.setHeight(rect.height())
.build();
}
} }

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.gesturerecognizer">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "gesturerecognizer",
srcs = [
"GestureRecognitionResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = ":AndroidManifest.xml",
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,128 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.gesturerecognizer;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the gesture recognition results generated by {@link GestureRecognizer}. */
@AutoValue
public abstract class GestureRecognitionResult implements TaskResult {
/**
* Creates a {@link GestureRecognitionResult} instance from the lists of landmarks, handedness,
* and gestures protobuf messages.
*
* @param landmarksProto a List of {@link NormalizedLandmarkList}
* @param worldLandmarksProto a List of {@link LandmarkList}
* @param handednessesProto a List of {@link ClassificationList}
* @param gesturesProto a List of {@link ClassificationList}
*/
static GestureRecognitionResult create(
List<NormalizedLandmarkList> landmarksProto,
List<LandmarkList> worldLandmarksProto,
List<ClassificationList> handednessesProto,
List<ClassificationList> gesturesProto,
long timestampMs) {
List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
new ArrayList<>();
List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
new ArrayList<>();
List<List<Category>> multiHandHandednesses = new ArrayList<>();
List<List<Category>> multiHandGestures = new ArrayList<>();
for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
new ArrayList<>();
multiHandLandmarks.add(handLandmarks);
for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
handLandmarks.add(
com.google.mediapipe.tasks.components.containers.Landmark.create(
handLandmarkProto.getX(),
handLandmarkProto.getY(),
handLandmarkProto.getZ(),
true));
}
}
for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks =
new ArrayList<>();
multiHandWorldLandmarks.add(handWorldLandmarks);
for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
handWorldLandmarks.add(
com.google.mediapipe.tasks.components.containers.Landmark.create(
handWorldLandmarkProto.getX(),
handWorldLandmarkProto.getY(),
handWorldLandmarkProto.getZ(),
false));
}
}
for (ClassificationList handednessProto : handednessesProto) {
List<Category> handedness = new ArrayList<>();
multiHandHandednesses.add(handedness);
for (Classification classification : handednessProto.getClassificationList()) {
handedness.add(
Category.create(
classification.getScore(),
classification.getIndex(),
classification.getLabel(),
classification.getDisplayName()));
}
}
for (ClassificationList gestureProto : gesturesProto) {
List<Category> gestures = new ArrayList<>();
multiHandGestures.add(gestures);
for (Classification classification : gestureProto.getClassificationList()) {
gestures.add(
Category.create(
classification.getScore(),
classification.getIndex(),
classification.getLabel(),
classification.getDisplayName()));
}
}
return new AutoValue_GestureRecognitionResult(
timestampMs,
Collections.unmodifiableList(multiHandLandmarks),
Collections.unmodifiableList(multiHandWorldLandmarks),
Collections.unmodifiableList(multiHandHandednesses),
Collections.unmodifiableList(multiHandGestures));
}
@Override
public abstract long timestampMs();
/** Hand landmarks of detected hands. */
public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();
/** Hand landmarks in world coordniates of detected hands. */
public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
worldLandmarks();
/** Handedness of detected hands. */
public abstract List<List<Category>> handednesses();
/** Recognized hand gestures of detected hands */
public abstract List<List<Category>> gestures();
}

View File

@ -38,7 +38,8 @@ public abstract class ObjectDetectionResult implements TaskResult {
* Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
* messages. * messages.
* *
* @param detectionList a list of {@link Detection} protobuf messages. * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
* @param timestampMs a timestamp for this result.
*/ */
static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) { static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>(); List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();

View File

@ -155,7 +155,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}. * Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
* *
* @param context an Android {@link Context}. * @param context an Android {@link Context}.
* @param detectorOptions a {@link ObjectDetectorOptions} instance. * @param detectorOptions an {@link ObjectDetectorOptions} instance.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation. * @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/ */
public static ObjectDetector createFromOptions( public static ObjectDetector createFromOptions(
@ -192,7 +192,6 @@ public final class ObjectDetector extends BaseVisionTaskApi {
.setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM) .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(), .build(),
handler); handler);
detectorOptions.errorListener().ifPresent(runner::setErrorListener);
return new ObjectDetector(runner, detectorOptions.runningMode()); return new ObjectDetector(runner, detectorOptions.runningMode());
} }
@ -204,7 +203,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @param runningMode a mediapipe vision task {@link RunningMode}. * @param runningMode a mediapipe vision task {@link RunningMode}.
*/ */
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) { private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode); super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
} }
/** /**
@ -221,7 +220,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detect(Image inputImage) { public ObjectDetectionResult detect(Image inputImage) {
return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage); return (ObjectDetectionResult) processImageData(inputImage);
} }
/** /**
@ -242,8 +241,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) { public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
return (ObjectDetectionResult) return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
} }
/** /**
@ -265,7 +263,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
* @throws MediaPipeException if there is an internal error. * @throws MediaPipeException if there is an internal error.
*/ */
public void detectAsync(Image inputImage, long inputTimestampMs) { public void detectAsync(Image inputImage, long inputTimestampMs) {
sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs); sendLiveStreamData(inputImage, inputTimestampMs);
} }
/** Options for setting up an {@link ObjectDetector}. */ /** Options for setting up an {@link ObjectDetector}. */
@ -275,12 +273,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
/** Builder for {@link ObjectDetectorOptions}. */ /** Builder for {@link ObjectDetectorOptions}. */
@AutoValue.Builder @AutoValue.Builder
public abstract static class Builder { public abstract static class Builder {
/** Sets the base options for the object detector task. */ /** Sets the {@link BaseOptions} for the object detector task. */
public abstract Builder setBaseOptions(BaseOptions value); public abstract Builder setBaseOptions(BaseOptions value);
/** /**
* Sets the running mode for the object detector task. Default to the image mode. Object * Sets the {@link RunningMode} for the object detector task. Default to the image mode.
* detector has three modes: * Object detector has three modes:
* *
* <ul> * <ul>
* <li>IMAGE: The mode for detecting objects on single image inputs. * <li>IMAGE: The mode for detecting objects on single image inputs.
@ -293,8 +291,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public abstract Builder setRunningMode(RunningMode value); public abstract Builder setRunningMode(RunningMode value);
/** /**
* Sets the locale to use for display names specified through the TFLite Model Metadata, if * Sets the optional locale to use for display names specified through the TFLite Model
* any. Defaults to English. * Metadata, if any.
*/ */
public abstract Builder setDisplayNamesLocale(String value); public abstract Builder setDisplayNamesLocale(String value);
@ -331,12 +329,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
public abstract Builder setCategoryDenylist(List<String> value); public abstract Builder setCategoryDenylist(List<String> value);
/** /**
* Sets the result listener to receive the detection results asynchronously when the object * Sets the {@link ResultListener} to receive the detection results asynchronously when the
* detector is in the live stream mode. * object detector is in the live stream mode.
*/ */
public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value); public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);
/** Sets an optional error listener. */ /** Sets an optional {@link ErrorListener}}. */
public abstract Builder setErrorListener(ErrorListener value); public abstract Builder setErrorListener(ErrorListener value);
abstract ObjectDetectorOptions autoBuild(); abstract ObjectDetectorOptions autoBuild();

View File

@ -11,7 +11,7 @@
android:targetSdkVersion="30" /> android:targetSdkVersion="30" />
<application <application
android:label="facedetectiontest" android:label="objectdetectortest"
android:name="android.support.multidex.MultiDexApplication" android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity=""> android:taskAffinity="">
<uses-library android:name="android.test.runner" /> <uses-library android:name="android.test.runner" />

View File

@ -1,45 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test util for MediaPipe Tasks."""
import os
from absl import flags
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module
FLAGS = flags.FLAGS
_Image = image_module.Image
_ImageFormat = image_frame_module.ImageFormat
_RGB_CHANNELS = 3
def test_srcdir():
"""Returns the path where to look for test data files."""
if "test_srcdir" in flags.FLAGS:
return flags.FLAGS["test_srcdir"].value
elif "TEST_SRCDIR" in os.environ:
return os.environ["TEST_SRCDIR"]
else:
raise RuntimeError("Missing TEST_SRCDIR environment.")
def get_test_data_path(file_or_dirname: str) -> str:
"""Returns full test data path."""
for (directory, subdirs, files) in os.walk(test_srcdir()):
for f in subdirs + files:
if f.endswith(file_or_dirname):
return os.path.join(directory, f)
raise ValueError("No %s in test directory" % file_or_dirname)

View File

@ -119,7 +119,7 @@ class ObjectDetectorTest(parameterized.TestCase):
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, ValueError,
r"ExternalFile must specify at least one of 'file_content', " r"ExternalFile must specify at least one of 'file_content', "
r"'file_name' or 'file_descriptor_meta'."): r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
base_options = _BaseOptions(model_asset_path='') base_options = _BaseOptions(model_asset_path='')
options = _ObjectDetectorOptions(base_options=base_options) options = _ObjectDetectorOptions(base_options=base_options)
_ObjectDetector.create_from_options(options) _ObjectDetector.create_from_options(options)

View File

@ -27,6 +27,7 @@ mediapipe_files(srcs = [
"albert_with_metadata.tflite", "albert_with_metadata.tflite",
"bert_text_classifier.tflite", "bert_text_classifier.tflite",
"mobilebert_with_metadata.tflite", "mobilebert_with_metadata.tflite",
"test_model_text_classifier_bool_output.tflite",
"test_model_text_classifier_with_regex_tokenizer.tflite", "test_model_text_classifier_with_regex_tokenizer.tflite",
]) ])

View File

@ -47,6 +47,7 @@ mediapipe_files(srcs = [
"mozart_square.jpg", "mozart_square.jpg",
"multi_objects.jpg", "multi_objects.jpg",
"palm_detection_full.tflite", "palm_detection_full.tflite",
"pointing_up.jpg",
"right_hands.jpg", "right_hands.jpg",
"segmentation_golden_rotation0.png", "segmentation_golden_rotation0.png",
"segmentation_input_rotation0.jpg", "segmentation_input_rotation0.jpg",
@ -54,6 +55,7 @@ mediapipe_files(srcs = [
"selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_128_128_3_expected_mask.jpg",
"selfie_segm_144_256_3.tflite", "selfie_segm_144_256_3.tflite",
"selfie_segm_144_256_3_expected_mask.jpg", "selfie_segm_144_256_3_expected_mask.jpg",
"thumb_up.jpg",
]) ])
exports_files( exports_files(
@ -79,11 +81,13 @@ filegroup(
"left_hands.jpg", "left_hands.jpg",
"mozart_square.jpg", "mozart_square.jpg",
"multi_objects.jpg", "multi_objects.jpg",
"pointing_up.jpg",
"right_hands.jpg", "right_hands.jpg",
"segmentation_golden_rotation0.png", "segmentation_golden_rotation0.png",
"segmentation_input_rotation0.jpg", "segmentation_input_rotation0.jpg",
"selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_128_128_3_expected_mask.jpg",
"selfie_segm_144_256_3_expected_mask.jpg", "selfie_segm_144_256_3_expected_mask.jpg",
"thumb_up.jpg",
], ],
visibility = [ visibility = [
"//mediapipe/python:__subpackages__", "//mediapipe/python:__subpackages__",

View File

@ -8,216 +8,216 @@ classifications {
landmarks { landmarks {
landmark { landmark {
x: 0.4749803 x: 0.47923622
y: 0.76872 y: 0.7426044
z: 9.286178e-08 z: 2.3221878e-07
} }
landmark { landmark {
x: 0.5466898 x: 0.5403745
y: 0.6706463 y: 0.66178805
z: -0.03454024 z: -0.044572093
} }
landmark { landmark {
x: 0.5890165 x: 0.5774534
y: 0.5604909 y: 0.5608346
z: -0.055142127 z: -0.07581605
} }
landmark { landmark {
x: 0.52780133 x: 0.52648556
y: 0.49855334 y: 0.50247055
z: -0.07846409 z: -0.105467044
} }
landmark { landmark {
x: 0.44487286 x: 0.44289914
y: 0.49801928 y: 0.49489295
z: -0.10188004 z: -0.13422011
} }
landmark { landmark {
x: 0.47572923 x: 0.4728853
y: 0.44477755 y: 0.43925008
z: -0.028345175 z: -0.058122505
} }
landmark { landmark {
x: 0.48013464 x: 0.4803168
y: 0.32467923 y: 0.32889345
z: -0.06513901 z: -0.101187326
} }
landmark { landmark {
x: 0.48351905 x: 0.48436823
y: 0.25804192 y: 0.25876504
z: -0.086756624 z: -0.12840955
} }
landmark { landmark {
x: 0.47760454 x: 0.47388697
y: 0.19289327 y: 0.19592366
z: -0.10468461 z: -0.15085006
} }
landmark { landmark {
x: 0.3993108 x: 0.39129356
y: 0.47566867 y: 0.47211456
z: -0.040357687 z: -0.06835801
} }
landmark { landmark {
x: 0.42361537 x: 0.41798547
y: 0.42491958 y: 0.42218646
z: -0.103545874 z: -0.12954563
} }
landmark { landmark {
x: 0.46059948 x: 0.45758423
y: 0.51723665 y: 0.5232461
z: -0.1214961 z: -0.14131334
} }
landmark { landmark {
x: 0.4580545 x: 0.45100626
y: 0.55640894 y: 0.5554065
z: -0.12272568 z: -0.13883406
} }
landmark { landmark {
x: 0.34109607 x: 0.33133638
y: 0.5184511 y: 0.51777464
z: -0.056422118 z: -0.08227023
} }
landmark { landmark {
x: 0.36177525 x: 0.35698116
y: 0.48427337 y: 0.48688585
z: -0.12584248 z: -0.14713185
} }
landmark { landmark {
x: 0.40706652 x: 0.40754414
y: 0.5700621 y: 0.57370347
z: -0.11658718 z: -0.12981415
} }
landmark { landmark {
x: 0.40535083 x: 0.40011865
y: 0.6000496 y: 0.5930706
z: -0.09520916 z: -0.10554546
} }
landmark { landmark {
x: 0.2872031 x: 0.2783401
y: 0.57303333 y: 0.5735568
z: -0.074813806 z: -0.09971398
} }
landmark { landmark {
x: 0.30961618 x: 0.30884498
y: 0.533245 y: 0.5394487
z: -0.114366606 z: -0.14033116
} }
landmark { landmark {
x: 0.35510173 x: 0.35470563
y: 0.5838698 y: 0.5917965
z: -0.096521005 z: -0.11820527
} }
landmark { landmark {
x: 0.36053744 x: 0.34865493
y: 0.608682 y: 0.61057556
z: -0.07574715 z: -0.09509217
} }
} }
world_landmarks { world_landmarks {
landmark { landmark {
x: 0.018890835 x: 0.016918864
y: 0.09005852 y: 0.08634466
z: 0.031907097 z: 0.035783045
} }
landmark { landmark {
x: 0.04198891 x: 0.04193685
y: 0.061256267 y: 0.056667875
z: 0.017695501 z: 0.019453367
} }
landmark { landmark {
x: 0.05044507 x: 0.050382353
y: 0.033841074 y: 0.031786427
z: 0.0015051212 z: 0.0023380776
} }
landmark { landmark {
x: 0.039822325 x: 0.043284662
y: 0.0073827556 y: 0.008976387
z: -0.02168335 z: -0.02496663
} }
landmark { landmark {
x: 0.012921701 x: 0.016010094
y: 0.0025111444 y: 0.004991216
z: -0.033813436 z: -0.036876947
} }
landmark { landmark {
x: 0.023851154 x: 0.02450771
y: -0.011495698 y: -0.013496464
z: 0.0066048754 z: 0.0041254223
} }
landmark { landmark {
x: 0.023206754 x: 0.024783865
y: -0.042496294 y: -0.041331705
z: -0.0026847485 z: -0.0028748964
} }
landmark { landmark {
x: 0.02298078 x: 0.025917178
y: -0.062678955 y: -0.06191107
z: -0.013068148 z: -0.010242647
} }
landmark { landmark {
x: 0.021972645 x: 0.023101516
y: -0.08151748 y: -0.07967696
z: -0.03677687 z: -0.03152665
} }
landmark { landmark {
x: -0.00016964211 x: 0.0006629339
y: -0.005549716 y: -0.0060150283
z: 0.0058569373 z: 0.004906766
} }
landmark { landmark {
x: 0.0075052455 x: 0.0077093104
y: -0.020031122 y: -0.017035034
z: -0.027775772 z: -0.029702934
} }
landmark { landmark {
x: 0.017835317 x: 0.017517095
y: 0.004899453 y: 0.008997183
z: -0.037390795 z: -0.03692814
} }
landmark { landmark {
x: 0.016913192 x: 0.0145079205
y: 0.018281722 y: 0.017461296
z: -0.019302163 z: -0.011290487
} }
landmark { landmark {
x: -0.018799124 x: -0.018095909
y: 0.0053577404 y: 0.006112392
z: -0.0040608873 z: -0.0027157406
} }
landmark { landmark {
x: -0.00747582 x: -0.010212201
y: 0.0019600953 y: 0.0052777785
z: -0.034023333 z: -0.034659054
} }
landmark { landmark {
x: 0.0035368819 x: 0.0043836404
y: 0.025736088 y: 0.028383566
z: -0.03452471 z: -0.03296758
} }
landmark { landmark {
x: 0.0080153765 x: 0.003886811
y: 0.039885145 y: 0.036054
z: -0.013341276 z: -0.0074628904
} }
landmark { landmark {
x: -0.029628165 x: -0.03178849
y: 0.028607829 y: 0.029854178
z: -0.011377414 z: -0.008874044
} }
landmark { landmark {
x: -0.023356002 x: -0.02403016
y: 0.017514031 y: 0.021497255
z: -0.029408533 z: -0.027618393
} }
landmark { landmark {
x: -0.008503268 x: -0.008522437
y: 0.027560957 y: 0.031886857
z: -0.035641473 z: -0.032367583
} }
landmark { landmark {
x: -0.0070180474 x: -0.012865841
y: 0.039056484 y: 0.038687646
z: -0.023629948 z: -0.017172804
} }
} }

View File

@ -8,216 +8,216 @@ classifications {
landmarks { landmarks {
landmark { landmark {
x: 0.6065784 x: 0.6387502
y: 0.7356081 y: 0.67134184
z: -5.2289305e-08 z: -3.4044612e-07
} }
landmark { landmark {
x: 0.6349347 x: 0.634891
y: 0.5735343 y: 0.53670025
z: -0.047243003 z: -0.06968865
} }
landmark { landmark {
x: 0.5788341 x: 0.5746676
y: 0.42688707 y: 0.41283816
z: -0.036071796 z: -0.09383486
} }
landmark { landmark {
x: 0.51322824 x: 0.49967948
y: 0.3153786 y: 0.32550922
z: -0.021018881 z: -0.10799447
} }
landmark { landmark {
x: 0.49179295 x: 0.47362617
y: 0.25291175 y: 0.25102285
z: 0.0061425082 z: -0.10590933
} }
landmark { landmark {
x: 0.49944243 x: 0.40749234
y: 0.45409226 y: 0.47130388
z: 0.06513325 z: -0.04694611
} }
landmark { landmark {
x: 0.3822241 x: 0.3372087
y: 0.45645967 y: 0.46742308
z: 0.045028925 z: -0.0997342
} }
landmark { landmark {
x: 0.4427338 x: 0.4418445
y: 0.49150866 y: 0.50960016
z: 0.024395633 z: -0.111206524
} }
landmark { landmark {
x: 0.5015556 x: 0.48056933
y: 0.4798539 y: 0.5187666
z: 0.014423937 z: -0.11022365
} }
landmark { landmark {
x: 0.46654877 x: 0.39218128
y: 0.5420721 y: 0.5495232
z: 0.08380699 z: -0.028925514
} }
landmark { landmark {
x: 0.3540949 x: 0.34047198
y: 0.545657 y: 0.55610204
z: 0.056201216 z: -0.08213869
} }
landmark { landmark {
x: 0.43828446 x: 0.46152583
y: 0.5723222 y: 0.58310646
z: 0.03073385 z: -0.08393028
} }
landmark { landmark {
x: 0.4894746 x: 0.47058716
y: 0.54662794 y: 0.56413835
z: 0.016284892 z: -0.078857616
} }
landmark { landmark {
x: 0.44287524 x: 0.39237642
y: 0.6153337 y: 0.61864823
z: 0.0878331 z: -0.022026168
} }
landmark { landmark {
x: 0.3531985 x: 0.34304678
y: 0.6305228 y: 0.62800515
z: 0.048528627 z: -0.08132204
} }
landmark { landmark {
x: 0.42727134 x: 0.45004016
y: 0.64344436 y: 0.64300805
z: 0.027383275 z: -0.06211204
} }
landmark { landmark {
x: 0.46999624 x: 0.4640005
y: 0.61115295 y: 0.6221539
z: 0.021795912 z: -0.038953774
} }
landmark { landmark {
x: 0.43323213 x: 0.39231628
y: 0.6734935 y: 0.68187976
z: 0.087731235 z: -0.020164328
} }
landmark { landmark {
x: 0.3772134 x: 0.35785866
y: 0.69590896 y: 0.6985842
z: 0.07259013 z: -0.052247807
} }
landmark { landmark {
x: 0.42301077 x: 0.42698768
y: 0.70083475 y: 0.69892275
z: 0.06279105 z: -0.037642766
} }
landmark { landmark {
x: 0.45672464 x: 0.44422707
y: 0.6844607 y: 0.6876204
z: 0.059202813 z: -0.02034688
} }
} }
world_landmarks { world_landmarks {
landmark { landmark {
x: 0.047059614 x: 0.06753889
y: 0.04719348 y: 0.031051591
z: 0.03951376 z: 0.05541924
} }
landmark { landmark {
x: 0.050449535 x: 0.06327636
y: 0.012183173 y: -0.003913434
z: 0.016567508 z: 0.02125023
} }
landmark { landmark {
x: 0.04375921 x: 0.05469646
y: -0.020305036 y: -0.038668767
z: 0.012189768 z: 0.01118496
} }
landmark { landmark {
x: 0.022525383 x: 0.03557241
y: -0.04830697 y: -0.06865983
z: 0.008714083 z: 0.0029562893
} }
landmark { landmark {
x: 0.011789754 x: 0.019069858
y: -0.06952699 y: -0.08740239
z: 0.0029319536 z: 0.007222481
} }
landmark { landmark {
x: 0.009532374 x: 0.0044852756
y: -0.019510617 y: -0.02772763
z: 0.0015609035 z: -0.004234833
} }
landmark { landmark {
x: -0.007894232 x: -0.0031203926
y: -0.022080563 y: -0.024173645
z: -0.014592148 z: -0.033932913
} }
landmark { landmark {
x: -0.002826123 x: 0.0080217365
y: -0.019949362 y: -0.018939625
z: -0.009392118 z: -0.032623816
} }
landmark { landmark {
x: 0.009066351 x: 0.025537387
y: -0.016403511 y: -0.014517117
z: 0.005516675 z: -0.004398854
} }
landmark { landmark {
x: -0.0031000748 x: -0.004470923
y: -0.003971943 y: -0.0040212176
z: 0.004851345 z: 0.0025033879
} }
landmark { landmark {
x: -0.016852753 x: -0.010845158
y: -0.009905987 y: -0.0031857258
z: -0.016275175 z: -0.036282137
} }
landmark { landmark {
x: -0.006703893 x: 0.016729971
y: -0.0026965735 y: 0.0028876318
z: -0.015606856 z: -0.036264844
} }
landmark { landmark {
x: 0.007890566 x: 0.019928008
y: -0.010418876 y: -0.0032422952
z: 0.0050479355 z: 0.004380459
} }
landmark { landmark {
x: -0.007842411 x: -0.005686749
y: 0.011552694 y: 0.017101247
z: -0.0005755241 z: 0.0036791638
} }
landmark { landmark {
x: -0.021125216 x: -0.010514952
y: 0.009268615 y: 0.017355483
z: -0.017993882 z: -0.02882688
} }
landmark { landmark {
x: -0.006585305 x: 0.014503509
y: 0.013378072 y: 0.019414417
z: -0.01709412 z: -0.026207235
} }
landmark { landmark {
x: 0.008140431 x: 0.0211232
y: 0.008364402 y: 0.014327417
z: -0.0051898304 z: 0.0011467658
} }
landmark { landmark {
x: -0.01082343 x: 0.0011399705
y: 0.03213215 y: 0.043651186
z: -0.00069864903 z: 0.0068390737
} }
landmark { landmark {
x: -0.0199164 x: -0.010388309
y: 0.028296603 y: 0.03904784
z: -0.01447433 z: -0.015677728
} }
landmark { landmark {
x: -0.00960456 x: 0.006957108
y: 0.026734762 y: 0.03613425
z: -0.019243335 z: -0.028704688
} }
landmark { landmark {
x: 0.0040425956 x: 0.012793289
y: 0.025051914 y: 0.03930679
z: -0.014775545 z: -0.012465539
} }
} }

View File

@ -432,8 +432,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_pointing_up_landmarks_pbtxt", name = "com_google_mediapipe_pointing_up_landmarks_pbtxt",
sha256 = "1255b6ba17b4ef7a9b3ce92c0a139e74fbcec272dc251b049b2f06732f9fed83", sha256 = "a3cd7f088a9e997dbb8f00d91dbf3faaacbdb262c8f2fde3c07a9d0656488065",
urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1662650664573638"], urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1665174976408451"],
) )
http_file( http_file(
@ -562,6 +562,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"], urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
) )
http_file(
name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
)
http_file( http_file(
name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite", name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f", sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",
@ -588,8 +594,8 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_thumb_up_landmarks_pbtxt", name = "com_google_mediapipe_thumb_up_landmarks_pbtxt",
sha256 = "bf1913df6ac7cc14b492c10411c827832839985c057b112789e04ce7c1fdd0fa", sha256 = "b129ae0536be4e25d6cdee74aabe9dedf1bcfe87430a40b68be4079db3a4d926",
urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1662650669387278"], urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1665174979747784"],
) )
http_file( http_file(