Merge branch 'google:master' into image-classification-python-impl

commit a29035b91e
@@ -320,6 +320,8 @@ cc_library(
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite:framework_stable",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/c:c_api_types",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)
@@ -21,8 +21,10 @@
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"

namespace mediapipe {
@@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
   std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
 }
 
+template <>
+void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
+                                         tflite::Interpreter* interpreter,
+                                         int input_tensor_index) {
+  const char* input_tensor_buffer =
+      input_tensor.GetCpuReadView().buffer<char>();
+  tflite::DynamicBuffer dynamic_buffer;
+  dynamic_buffer.AddString(input_tensor_buffer,
+                           input_tensor.shape().num_elements());
+  dynamic_buffer.WriteToTensorAsVector(
+      interpreter->tensor(interpreter->inputs()[input_tensor_index]));
+}
+
 template <typename T>
 void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                      int output_tensor_index,
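For context, a hedged sketch (not code from this PR; the helper name and shape handling are illustrative) of how a caller could package a string payload into the kind of kChar MediaPipe tensor that this specialization consumes:

#include <cstring>
#include <string>

#include "mediapipe/framework/formats/tensor.h"

// Hypothetical helper: wrap a string payload in a kChar Tensor so that the
// specialization above can forward it to a kTfLiteString input tensor.
mediapipe::Tensor MakeStringTensor(const std::string& payload) {
  mediapipe::Tensor tensor(
      mediapipe::Tensor::ElementType::kChar,
      mediapipe::Tensor::Shape{static_cast<int>(payload.size())});
  std::memcpy(tensor.GetCpuWriteView().buffer<char>(), payload.data(),
              payload.size());
  return tensor;
}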
@@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
         break;
       }
       case TfLiteType::kTfLiteUInt8: {
-        CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
-                                             interpreter_.get(), i);
+        CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
+                                               interpreter_.get(), i);
         break;
       }
       case TfLiteType::kTfLiteInt8: {
-        CopyTensorBufferToInterpreter<int8>(input_tensors[i],
-                                            interpreter_.get(), i);
+        CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
+                                              interpreter_.get(), i);
         break;
       }
       case TfLiteType::kTfLiteInt32: {
@@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
                                               interpreter_.get(), i);
         break;
       }
+      case TfLiteType::kTfLiteString: {
+        CopyTensorBufferToInterpreter<char>(input_tensors[i],
+                                            interpreter_.get(), i);
+        break;
+      }
+      case TfLiteType::kTfLiteBool:
+        // No current use-case for copying MediaPipe Tensors with bool type to
+        // TfLiteTensors.
       default:
         return absl::InvalidArgumentError(
             absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
         CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
                                                  &output_tensors.back());
         break;
+      case TfLiteType::kTfLiteBool:
+        output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
+                                    Tensor::QuantizationParameters{1.0f, 0});
+        CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
+                                              &output_tensors.back());
+        break;
+      case TfLiteType::kTfLiteString:
+        // No current use-case for copying TfLiteTensors with string type to
+        // MediaPipe Tensors.
       default:
         return absl::InvalidArgumentError(
             absl::StrCat("Unsupported output tensor type:",
@@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
     case Tensor::ElementType::kInt8:
       Dequantize<int8>(input_tensor, &output_tensors->back());
       break;
+    case Tensor::ElementType::kBool:
+      Dequantize<bool>(input_tensor, &output_tensors->back());
+      break;
     default:
       return absl::InvalidArgumentError(absl::StrCat(
          "Unsupported input tensor type: ", input_tensor.element_type()));
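For reference, dequantization maps a quantized value q with parameters (scale, zero_point) to scale * (q - zero_point), so with the {1.0f, 0} parameters this PR uses for bool tensors, true and false come out as 1.0f and 0.0f. A minimal self-contained sketch of the per-element operation (the calculator's actual internal helper signature is not shown in this diff):

// Illustrative per-element dequantization; not the PR's internal helper.
template <typename T>
float DequantizeValue(T value, float scale, int zero_point) {
  return scale * (static_cast<float>(value) - zero_point);
}
// DequantizeValue<bool>(true, /*scale=*/1.0f, /*zero_point=*/0) == 1.0f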
@@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
   ValidateResult(GetOutput(), {-1.007874, 0, 1});
 }
 
+TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
+  std::vector<bool> tensor = {true, false, true};
+  PushTensor(Tensor::ElementType::kBool, tensor,
+             Tensor::QuantizationParameters{1.0f, 0});
+
+  MP_ASSERT_OK(runner_.Run());
+
+  ValidateResult(GetOutput(), {1, 0, 1});
+}
+
 }  // namespace
 }  // namespace mediapipe
@@ -1685,10 +1685,3 @@ cc_test(
         "@com_google_absl//absl/strings:str_format",
     ],
 )
-
-# Expose the proto source files for building mediapipe AAR.
-filegroup(
-    name = "protos_src",
-    srcs = glob(["*.proto"]),
-    visibility = ["//mediapipe:__subpackages__"],
-)
@@ -112,6 +112,14 @@ class MultiPort : public Single {
   std::vector<std::unique_ptr<Base>>& vec_;
 };
 
+namespace internal_builder {
+
+template <typename T, typename U>
+using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
+                                                   !std::is_same_v<T, U>>;
+
+}  // namespace internal_builder
+
 // These classes wrap references to the underlying source/destination
 // endpoints, adding type information and the user-visible API.
 template <bool IsSide, typename T = internal::Generic>
@@ -122,13 +130,14 @@ class DestinationImpl {
   explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
       : DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
   explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
+
+  template <typename U,
+            std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
+  DestinationImpl<IsSide, U> Cast() {
+    return DestinationImpl<IsSide, U>(&base_);
+  }
+
   DestinationBase& base_;
 };
 
-template <bool IsSide, typename T>
-class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
- public:
-  using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
-};
-
 template <bool IsSide, typename T = internal::Generic>
@@ -171,12 +180,8 @@ class SourceImpl {
     return AddTarget(dest);
   }
 
-  template <typename U>
-  struct AllowCast
-      : public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
-                                                !std::is_same_v<T, U>> {};
-
-  template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
+  template <typename U,
+            std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
   SourceImpl<IsSide, U> Cast() {
     return SourceImpl<IsSide, U>(base_);
   }
@@ -186,12 +191,6 @@ class SourceImpl {
   SourceBase* base_;
 };
 
-template <bool IsSide, typename T>
-class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
- public:
-  using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
-};
-
 // A source and a destination correspond to an output/input stream on a node,
 // and a side source and side destination correspond to an output/input side
 // packet.
@@ -201,20 +200,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
 template <typename T = internal::Generic>
 using Source = SourceImpl<false, T>;
 template <typename T = internal::Generic>
-using MultiSource = MultiSourceImpl<false, T>;
+using MultiSource = MultiPort<Source<T>>;
 template <typename T = internal::Generic>
 using SideSource = SourceImpl<true, T>;
 template <typename T = internal::Generic>
-using MultiSideSource = MultiSourceImpl<true, T>;
+using MultiSideSource = MultiPort<SideSource<T>>;
 
 template <typename T = internal::Generic>
 using Destination = DestinationImpl<false, T>;
 template <typename T = internal::Generic>
 using SideDestination = DestinationImpl<true, T>;
 template <typename T = internal::Generic>
-using MultiDestination = MultiDestinationImpl<false, T>;
+using MultiDestination = MultiPort<Destination<T>>;
 template <typename T = internal::Generic>
-using MultiSideDestination = MultiDestinationImpl<true, T>;
+using MultiSideDestination = MultiPort<SideDestination<T>>;
 
 class NodeBase {
  public:
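For context, a condensed sketch of what the relocated trait enables, mirroring the BuilderTest change shown below: Cast<U>() is available wherever the port was declared with AnyType, now including graph-level outputs.

Graph graph;
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
Source<double> any_type_output =
    node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
// Newly possible: re-type the graph boundary itself.
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();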
@@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
       node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
   any_type_output.SetName("any_type_output");
 
+  any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
+
   CalculatorGraphConfig expected =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
         node {
@@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
           output_stream: "ANY_OUTPUT:any_type_output"
         }
         input_stream: "GRAPH_ANY_INPUT:__stream_0"
+        output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
       )pb");
   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
 }
@@ -334,13 +334,6 @@ mediapipe_register_type(
     deps = [":landmark_cc_proto"],
 )
 
-# Expose the proto source files for building mediapipe AAR.
-filegroup(
-    name = "protos_src",
-    srcs = glob(["*.proto"]),
-    visibility = ["//mediapipe:__subpackages__"],
-)
-
 cc_library(
     name = "image",
     srcs = ["image.cc"],
@@ -33,10 +33,3 @@ mediapipe_proto_library(
     srcs = ["rasterization.proto"],
     visibility = ["//visibility:public"],
 )
-
-# Expose the proto source files for building mediapipe AAR.
-filegroup(
-    name = "protos_src",
-    srcs = glob(["*.proto"]),
-    visibility = ["//mediapipe:__subpackages__"],
-)
@@ -97,8 +97,8 @@ class Tensor {
     kUInt8,
     kInt8,
     kInt32,
-    // TODO: Update the inference runner to handle kTfLiteString.
-    kChar
+    kChar,
+    kBool
   };
   struct Shape {
     Shape() = default;
@@ -330,6 +330,8 @@ class Tensor {
         return sizeof(int32_t);
       case ElementType::kChar:
         return sizeof(char);
+      case ElementType::kBool:
+        return sizeof(bool);
     }
   }
   int bytes() const { return shape_.num_elements() * element_size(); }
@@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
   Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
   EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
 
+  Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
+  EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
 }
 
 TEST(Cpu, TestMemoryAllocation) {
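A hedged sketch (not from this PR) of allocating and filling such a kBool tensor on CPU, following the same view API the kChar path uses elsewhere in this change:

#include "mediapipe/framework/formats/tensor.h"

mediapipe::Tensor t_bool(mediapipe::Tensor::ElementType::kBool,
                         mediapipe::Tensor::Shape{2, 3});
auto view = t_bool.GetCpuWriteView();
bool* data = view.buffer<bool>();
for (int i = 0; i < t_bool.shape().num_elements(); ++i) {
  data[i] = (i % 2 == 0);  // arbitrary example values
}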
@@ -150,7 +150,7 @@ cc_library(
     name = "executor_util",
     srcs = ["executor_util.cc"],
     hdrs = ["executor_util.h"],
-    visibility = ["//mediapipe/framework:mediapipe_internal"],
+    visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework:mediapipe_options_cc_proto",
@@ -1050,7 +1050,7 @@ objc_library(
     alwayslink = 1,
 )
 
-MIN_IOS_VERSION = "9.0"  # For thread_local.
+MIN_IOS_VERSION = "11.0"
 
 test_suite(
     name = "ios",
@@ -184,7 +184,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
   EXPECT_THAT(
       audio_classifier_or.status().message(),
       HasSubstr("ExternalFile must specify at least one of 'file_content', "
-                "'file_name' or 'file_descriptor_meta'."));
+                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
   EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload),
               Optional(absl::Cord(absl::StrCat(
                   MediaPipeTasksStatus::kRunnerInitializationError))));
@@ -65,6 +65,8 @@ enum class MediaPipeTasksStatus {
   kFileReadError,
   // I/O error when mmap-ing file.
   kFileMmapError,
+  // ZIP I/O error when unpacking the zip file.
+  kFileZipError,
 
   // TensorFlow Lite metadata error codes.
mediapipe/tasks/cc/components/containers/BUILD (new file, 31 lines)

@@ -0,0 +1,31 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library(
    name = "landmarks_detection",
    hdrs = ["landmarks_detection.h"],
)

cc_library(
    name = "gesture_recognition_result",
    hdrs = ["gesture_recognition_result.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
    ],
)
mediapipe/tasks/cc/components/containers/gesture_recognition_result.h (new file, 46 lines)

@@ -0,0 +1,46 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"

namespace mediapipe {
namespace tasks {
namespace components {
namespace containers {

// The gesture recognition result from GestureRecognizer, where each vector
// element represents a single hand detected in the image.
struct GestureRecognitionResult {
  // Recognized hand gestures with sorted order such that the winning label is
  // the first item in the list.
  std::vector<mediapipe::ClassificationList> gestures;
  // Classification of handedness.
  std::vector<mediapipe::ClassificationList> handedness;
  // Detected hand landmarks in normalized image coordinates.
  std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks;
  // Detected hand landmarks in world coordinates.
  std::vector<mediapipe::LandmarkList> hand_world_landmarks;
};

}  // namespace containers
}  // namespace components
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_
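A hedged usage sketch (the helper function is hypothetical; it only exercises the fields and the documented sort order above):

#include <string>
#include <vector>

// Collect the winning gesture label for each detected hand.
std::vector<std::string> TopGesturePerHand(
    const mediapipe::tasks::components::containers::GestureRecognitionResult&
        result) {
  std::vector<std::string> labels;
  for (const auto& gestures : result.gestures) {
    if (gestures.classification_size() > 0) {
      // Gestures are sorted, so the winning label is first.
      labels.push_back(gestures.classification(0).label());
    }
  }
  return labels;
}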
mediapipe/tasks/cc/components/containers/landmarks_detection.h (new file, 43 lines)

@@ -0,0 +1,43 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_

#include <vector>

// Structs holding landmarks-related data structures for hand landmarker, pose
// detector, face mesher, etc.
namespace mediapipe::tasks::components::containers {

// x and y are in [0,1] range with origin in top left in input image space.
// If the model provides z, z is in the same scale as x. The origin is at the
// center of the face.
struct Landmark {
  float x;
  float y;
  float z;
};

// [0, 1] range in input image space
struct Bound {
  float left;
  float top;
  float right;
  float bottom;
};

}  // namespace mediapipe::tasks::components::containers
#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
@@ -17,6 +17,9 @@ syntax = "proto2";
 
 package mediapipe.tasks.components.containers.proto;
 
+option java_package = "com.google.mediapipe.tasks.components.container.proto";
+option java_outer_classname = "CategoryProto";
+
 // A single classification result.
 message Category {
   // The index of the category in the corresponding label map, usually packed in
@@ -19,6 +19,9 @@ package mediapipe.tasks.components.containers.proto;
 
 import "mediapipe/tasks/cc/components/containers/proto/category.proto";
 
+option java_package = "com.google.mediapipe.tasks.components.container.proto";
+option java_outer_classname = "ClassificationsProto";
+
 // List of predicted categories with an optional timestamp.
 message ClassificationEntry {
   // The array of predicted categories, usually sorted by descending scores,
@@ -123,15 +123,17 @@ absl::StatusOr<ClassificationHeadsProperties> GetClassificationHeadsProperties(
     const auto* tensor =
         primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i));
     if (tensor->type() != tflite::TensorType_FLOAT32 &&
-        tensor->type() != tflite::TensorType_UINT8) {
+        tensor->type() != tflite::TensorType_UINT8 &&
+        tensor->type() != tflite::TensorType_BOOL) {
       return CreateStatusWithPayload(
           absl::StatusCode::kInvalidArgument,
           absl::StrFormat("Expected output tensor at index %d to have type "
-                          "UINT8 or FLOAT32, found %s instead.",
+                          "UINT8 or FLOAT32 or BOOL, found %s instead.",
                           i, tflite::EnumNameTensorType(tensor->type())),
           MediaPipeTasksStatus::kInvalidOutputTensorTypeError);
     }
-    if (tensor->type() == tflite::TensorType_UINT8) {
+    if (tensor->type() == tflite::TensorType_UINT8 ||
+        tensor->type() == tflite::TensorType_BOOL) {
       num_quantized_tensors++;
     }
   }
@@ -282,6 +284,20 @@ absl::Status ConfigureScoreCalibrationIfAny(
   return absl::OkStatus();
 }
 
+void ConfigureClassificationAggregationCalculator(
+    const ModelMetadataExtractor& metadata_extractor,
+    ClassificationAggregationCalculatorOptions* options) {
+  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
+  if (output_tensors_metadata == nullptr) {
+    return;
+  }
+  for (const auto& metadata : *output_tensors_metadata) {
+    options->add_head_names(metadata->name()->str());
+  }
+}
+
 }  // namespace
 
+// Fills in the TensorsToClassificationCalculatorOptions based on the
+// classifier options and the (optional) output tensor metadata.
 absl::Status ConfigureTensorsToClassificationCalculator(
@@ -333,20 +349,6 @@ absl::Status ConfigureTensorsToClassificationCalculator(
   return absl::OkStatus();
 }
 
-void ConfigureClassificationAggregationCalculator(
-    const ModelMetadataExtractor& metadata_extractor,
-    ClassificationAggregationCalculatorOptions* options) {
-  auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata();
-  if (output_tensors_metadata == nullptr) {
-    return;
-  }
-  for (const auto& metadata : *output_tensors_metadata) {
-    options->add_head_names(metadata->name()->str());
-  }
-}
-
 }  // namespace
 
 absl::Status ConfigureClassificationPostprocessingGraph(
     const ModelResources& model_resources,
     const proto::ClassifierOptions& classifier_options,
@@ -20,6 +20,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace tasks {
@@ -55,6 +56,16 @@ absl::Status ConfigureClassificationPostprocessingGraph(
     const proto::ClassifierOptions& classifier_options,
     proto::ClassificationPostprocessingGraphOptions* options);
 
+// Utility function to fill in the TensorsToClassificationCalculatorOptions
+// based on the classifier options and the (optional) output tensor metadata.
+// This is meant to be used by other graphs that may also rely on this
+// calculator.
+absl::Status ConfigureTensorsToClassificationCalculator(
+    const proto::ClassifierOptions& options,
+    const metadata::ModelMetadataExtractor& metadata_extractor,
+    int tensor_index,
+    mediapipe::TensorsToClassificationCalculatorOptions* calculator_options);
+
 }  // namespace processors
 }  // namespace components
 }  // namespace tasks
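A hedged sketch of calling the newly exposed helper from another graph builder (the surrounding graph, classifier_options, and model_resources variables, and the metadata-extractor accessor name, are assumptions for illustration):

auto& tensors_to_classification =
    graph.AddNode("TensorsToClassificationCalculator");
MP_RETURN_IF_ERROR(
    components::processors::ConfigureTensorsToClassificationCalculator(
        classifier_options, *model_resources.GetMetadataExtractor(),
        /*tensor_index=*/0,
        &tensors_to_classification.GetOptions<
            mediapipe::TensorsToClassificationCalculatorOptions>()));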
@@ -17,6 +17,9 @@ syntax = "proto2";
 
 package mediapipe.tasks.components.processors.proto;
 
+option java_package = "com.google.mediapipe.tasks.components.processors.proto";
+option java_outer_classname = "ClassifierOptionsProto";
+
 // Shared options used by all classification tasks.
 message ClassifierOptions {
   // The locale to use for display names specified through the TFLite Model
@@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions {
     BERT_PREPROCESSOR = 1;
     // Used for the RegexPreprocessorCalculator.
     REGEX_PREPROCESSOR = 2;
+    // Used for the TextToTensorCalculator.
+    STRING_PREPROCESSOR = 3;
   }
   optional PreprocessorType preprocessor_type = 1;
@@ -65,6 +65,8 @@ absl::StatusOr<std::string> GetCalculatorNameFromPreprocessorType(
       return "BertPreprocessorCalculator";
     case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR:
       return "RegexPreprocessorCalculator";
+    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR:
+      return "TextToTensorCalculator";
   }
 }
@@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) {
         MediaPipeTasksStatus::kInvalidInputTensorTypeError);
   }
   if (all_string_tensors) {
-    // TODO: Support a TextToTensor calculator for string tensors.
-    return CreateStatusWithPayload(
-        absl::StatusCode::kInvalidArgument,
-        "String tensors are not supported yet",
-        MediaPipeTasksStatus::kInvalidInputTensorTypeError);
+    return TextPreprocessingGraphOptions::STRING_PREPROCESSOR;
   }
 
   // Otherwise, all tensors should have type int32
@@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph(
       TextPreprocessingGraphOptions::PreprocessorType preprocessor_type,
       GetPreprocessorType(model_resources));
   options.set_preprocessor_type(preprocessor_type);
-  ASSIGN_OR_RETURN(
-      int max_seq_len,
-      GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
-  options.set_max_seq_len(max_seq_len);
+  switch (preprocessor_type) {
+    case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+    case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
+      break;
+    }
+    case TextPreprocessingGraphOptions::BERT_PREPROCESSOR:
+    case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: {
+      ASSIGN_OR_RETURN(
+          int max_seq_len,
+          GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0]));
+      options.set_max_seq_len(max_seq_len);
+    }
+  }
 
   return absl::OkStatus();
 }
@@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph {
         GetCalculatorNameFromPreprocessorType(options.preprocessor_type()));
     auto& text_preprocessor = graph.AddNode(preprocessor_name);
     switch (options.preprocessor_type()) {
-      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: {
+      case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR:
+      case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: {
         break;
       }
       case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: {
@@ -92,13 +92,26 @@ absl::Status ExternalFileHandler::MapExternalFile() {
 #else
   if (!external_file_.file_content().empty()) {
     return absl::OkStatus();
+  } else if (external_file_.has_file_pointer_meta()) {
+    if (external_file_.file_pointer_meta().pointer() == 0) {
+      return CreateStatusWithPayload(
+          StatusCode::kInvalidArgument,
+          "Need to set the file pointer in external_file.file_pointer_meta.");
+    }
+    if (external_file_.file_pointer_meta().length() <= 0) {
+      return CreateStatusWithPayload(
+          StatusCode::kInvalidArgument,
+          "The length of the file in external_file.file_pointer_meta should be "
+          "positive.");
+    }
+    return absl::OkStatus();
+  }
   if (external_file_.file_name().empty() &&
       !external_file_.has_file_descriptor_meta()) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
-        "ExternalFile must specify at least one of 'file_content', 'file_name' "
-        "or 'file_descriptor_meta'.",
+        "ExternalFile must specify at least one of 'file_content', "
+        "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.",
         MediaPipeTasksStatus::kInvalidArgumentError);
   }
   // Obtain file descriptor, offset and size.
@@ -196,6 +209,11 @@ absl::Status ExternalFileHandler::MapExternalFile() {
 absl::string_view ExternalFileHandler::GetFileContent() {
   if (!external_file_.file_content().empty()) {
     return external_file_.file_content();
+  } else if (external_file_.has_file_pointer_meta()) {
+    void* ptr =
+        reinterpret_cast<void*>(external_file_.file_pointer_meta().pointer());
+    return absl::string_view(static_cast<const char*>(ptr),
+                             external_file_.file_pointer_meta().length());
   } else {
     return absl::string_view(static_cast<const char*>(buffer_) +
                                  buffer_offset_ - buffer_aligned_offset_,
@@ -26,10 +26,11 @@ option java_outer_classname = "ExternalFileProto";
 // (1) file contents loaded in `file_content`.
 // (2) file path in `file_name`.
 // (3) file descriptor through `file_descriptor_meta` as returned by open(2).
+// (4) file pointer and length in memory through `file_pointer_meta`.
 //
 // If more than one of these fields is provided, they are used in this
 // precedence order.
-// Next id: 4
+// Next id: 5
 message ExternalFile {
   // The file contents as a byte array.
   optional bytes file_content = 1;
@@ -40,6 +41,13 @@ message ExternalFile {
   // The file descriptor to a file opened with open(2), with optional additional
   // offset and length information.
   optional FileDescriptorMeta file_descriptor_meta = 3;
+
+  // Pointer to the location of a file in memory. Use the util method
+  // `SetExternalFile` in [1] to configure `file_pointer_meta` from a
+  // `std::string_view` object.
+  //
+  // [1]: mediapipe/tasks/cc/metadata/utils/zip_utils.h
+  optional FilePointerMeta file_pointer_meta = 4;
 }
 
 // A proto defining file descriptor metadata for mapping file into memory using
@@ -62,3 +70,14 @@ message FileDescriptorMeta {
   // offset of a given asset obtained from AssetFileDescriptor#getStartOffset().
   optional int64 offset = 3;
 }
+
+// Pointer to the location of a file in memory. Make sure the file memory it
+// points to is located on the same machine, and that it outlives this
+// FilePointerMeta object.
+message FilePointerMeta {
+  // Memory address of the file in decimal.
+  optional uint64 pointer = 1;
+
+  // File length.
+  optional int64 length = 2;
+}
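A hedged usage sketch of the `SetExternalFile` helper referenced in the comment above (how the bytes get into memory is elided; the proto stores only a raw pointer and length, so `model_blob` must outlive every use of `external_file`):

#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"

std::string model_blob = /* file bytes already resident in memory */ "";
mediapipe::tasks::core::proto::ExternalFile external_file;
mediapipe::tasks::metadata::SetExternalFile(model_blob, &external_file);
// external_file.file_pointer_meta() now carries the address and length.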
@@ -19,8 +19,9 @@ cc_library(
     deps = [
         "//mediapipe/framework/port:status",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/metadata/utils:zip_readonly_mem_file",
+        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
         "//mediapipe/tasks/metadata:metadata_schema_cc",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/status",
@@ -29,7 +30,6 @@ cc_library(
         "@com_google_absl//absl/strings:str_format",
         "@flatbuffers//:runtime_cc",
         "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
-        "@zlib//:zlib_minizip",
     ],
 )
@@ -17,16 +17,16 @@ limitations under the License.
 
 #include <string>
 
 #include "absl/cleanup/cleanup.h"
 #include "absl/memory/memory.h"
 #include "absl/status/status.h"
+#include "absl/strings/match.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
-#include "contrib/minizip/ioapi.h"
-#include "contrib/minizip/unzip.h"
 #include "flatbuffers/flatbuffers.h"
 #include "mediapipe/framework/port/status_macros.h"
 #include "mediapipe/tasks/cc/common.h"
 #include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h"
+#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
 #include "tensorflow/lite/schema/schema_generated.h"
@@ -53,72 +53,6 @@ const T* GetItemFromVector(
   }
   return src_vector->Get(index);
 }
-
-// Wrapper function around calls to unzip to avoid repeating conversion logic
-// from error code to Status.
-absl::Status UnzipErrorToStatus(int error) {
-  if (error != UNZ_OK) {
-    return CreateStatusWithPayload(
-        StatusCode::kUnknown, "Unable to read associated file in zip archive.",
-        MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
-  }
-  return absl::OkStatus();
-}
-
-// Stores a file name, position in zip buffer and size.
-struct ZipFileInfo {
-  std::string name;
-  ZPOS64_T position;
-  ZPOS64_T size;
-};
-
-// Returns the ZipFileInfo corresponding to the current file in the provided
-// unzFile object.
-absl::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
-  // Open file in raw mode, as data is expected to be uncompressed.
-  int method;
-  MP_RETURN_IF_ERROR(UnzipErrorToStatus(
-      unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
-  if (method != Z_NO_COMPRESSION) {
-    return CreateStatusWithPayload(
-        StatusCode::kUnknown, "Expected uncompressed zip archive.",
-        MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
-  }
-
-  // Get file info a first time to get filename size.
-  unz_file_info64 file_info;
-  MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
-      zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
-      /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
-      /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
-
-  // Second call to get file name.
-  auto file_name_size = file_info.size_filename;
-  char* c_file_name = (char*)malloc(file_name_size);
-  MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
-      zf, &file_info, c_file_name, file_name_size,
-      /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
-      /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
-  std::string file_name = std::string(c_file_name, file_name_size);
-  free(c_file_name);
-
-  // Get position in file.
-  auto position = unzGetCurrentFileZStreamPos64(zf);
-  if (position == 0) {
-    return CreateStatusWithPayload(
-        StatusCode::kUnknown, "Unable to read file in zip archive.",
-        MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
-  }
-
-  // Close file and return.
-  MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));
-
-  ZipFileInfo result{};
-  result.name = file_name;
-  result.position = position;
-  result.size = file_info.uncompressed_size;
-  return result;
-}
 }  // namespace
 
 /* static */
@@ -238,47 +172,15 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer(
 
 absl::Status ModelMetadataExtractor::ExtractAssociatedFiles(
     const char* buffer_data, size_t buffer_size) {
-  // Create in-memory read-only zip file.
-  ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
-  // Open zip.
-  unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
-  if (zf == nullptr) {
+  auto status =
+      ExtractFilesfromZipFile(buffer_data, buffer_size, &associated_files_);
+  if (!status.ok() &&
+      absl::StrContains(status.message(), "Unable to open zip archive.")) {
     // It's OK if it fails: this means there are no associated files with this
     // model.
     return absl::OkStatus();
   }
-  // Get number of files.
-  unz_global_info global_info;
-  if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
-    return CreateStatusWithPayload(
-        StatusCode::kUnknown, "Unable to get zip archive info.",
-        MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
-  }
-
-  // Browse through files in archive.
-  if (global_info.number_entry > 0) {
-    int error = unzGoToFirstFile(zf);
-    while (error == UNZ_OK) {
-      ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
-      // Store result in map.
-      associated_files_[zip_file_info.name] = absl::string_view(
-          buffer_data + zip_file_info.position, zip_file_info.size);
-      error = unzGoToNextFile(zf);
-    }
-    if (error != UNZ_END_OF_LIST_OF_FILE) {
-      return CreateStatusWithPayload(
-          StatusCode::kUnknown,
-          "Unable to read associated file in zip archive.",
-          MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
-    }
-  }
-  // Close zip.
-  if (unzClose(zf) != UNZ_OK) {
-    return CreateStatusWithPayload(
-        StatusCode::kUnknown, "Unable to close zip archive.",
-        MediaPipeTasksStatus::kMetadataAssociatedFileZipError);
-  }
-  return absl::OkStatus();
+  return status;
 }
 
 absl::StatusOr<absl::string_view> ModelMetadataExtractor::GetAssociatedFile(
@@ -24,3 +24,20 @@ cc_library(
         "@zlib//:zlib_minizip",
     ],
 )
+
+cc_library(
+    name = "zip_utils",
+    srcs = ["zip_utils.cc"],
+    hdrs = ["zip_utils.h"],
+    deps = [
+        ":zip_readonly_mem_file",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
+        "@com_google_absl//absl/cleanup",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@zlib//:zlib_minizip",
+    ],
+)
mediapipe/tasks/cc/metadata/utils/zip_utils.cc (new file, 175 lines)

@@ -0,0 +1,175 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "contrib/minizip/ioapi.h"
#include "contrib/minizip/unzip.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h"

namespace mediapipe {
namespace tasks {
namespace metadata {

namespace {

using ::absl::StatusCode;

// Wrapper function around calls to unzip to avoid repeating conversion logic
// from error code to Status.
absl::Status UnzipErrorToStatus(int error) {
  if (error != UNZ_OK) {
    return CreateStatusWithPayload(StatusCode::kUnknown,
                                   "Unable to read the file in zip archive.",
                                   MediaPipeTasksStatus::kFileZipError);
  }
  return absl::OkStatus();
}

// Stores a file name, position in zip buffer and size.
struct ZipFileInfo {
  std::string name;
  ZPOS64_T position;
  ZPOS64_T size;
};

// Returns the ZipFileInfo corresponding to the current file in the provided
// unzFile object.
absl::StatusOr<ZipFileInfo> GetCurrentZipFileInfo(const unzFile& zf) {
  // Open file in raw mode, as data is expected to be uncompressed.
  int method;
  MP_RETURN_IF_ERROR(UnzipErrorToStatus(
      unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1)));
  absl::Cleanup unzipper_closer = [zf]() {
    auto status = UnzipErrorToStatus(unzCloseCurrentFile(zf));
    if (!status.ok()) {
      LOG(ERROR) << "Failed to close the current zip file: " << status;
    }
  };
  if (method != Z_NO_COMPRESSION) {
    return CreateStatusWithPayload(StatusCode::kUnknown,
                                   "Expected uncompressed zip archive.",
                                   MediaPipeTasksStatus::kFileZipError);
  }

  // Get file info a first time to get filename size.
  unz_file_info64 file_info;
  MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
      zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0,
      /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
      /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));

  // Second call to get file name.
  auto file_name_size = file_info.size_filename;
  char* c_file_name = (char*)malloc(file_name_size);
  MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64(
      zf, &file_info, c_file_name, file_name_size,
      /*extraField=*/nullptr, /*extraFieldBufferSize=*/0,
      /*szComment=*/nullptr, /*szCommentBufferSize=*/0)));
  std::string file_name = std::string(c_file_name, file_name_size);
  free(c_file_name);

  // Get position in file.
  auto position = unzGetCurrentFileZStreamPos64(zf);
  if (position == 0) {
    return CreateStatusWithPayload(StatusCode::kUnknown,
                                   "Unable to read file in zip archive.",
                                   MediaPipeTasksStatus::kFileZipError);
  }

  // Perform the cleanup manually for error propagation.
  std::move(unzipper_closer).Cancel();
  // Close file and return.
  MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf)));

  ZipFileInfo result{};
  result.name = file_name;
  result.position = position;
  result.size = file_info.uncompressed_size;
  return result;
}

}  // namespace

absl::Status ExtractFilesfromZipFile(
    const char* buffer_data, const size_t buffer_size,
    absl::flat_hash_map<std::string, absl::string_view>* files) {
  // Create in-memory read-only zip file.
  ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size);
  // Open zip.
  unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def());
  if (zf == nullptr) {
    return CreateStatusWithPayload(StatusCode::kUnknown,
                                   "Unable to open zip archive.",
                                   MediaPipeTasksStatus::kFileZipError);
  }
  absl::Cleanup unzipper_closer = [zf]() {
    if (unzClose(zf) != UNZ_OK) {
      LOG(ERROR) << "Unable to close zip archive.";
    }
  };
  // Get number of files.
  unz_global_info global_info;
  if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) {
    return CreateStatusWithPayload(StatusCode::kUnknown,
                                   "Unable to get zip archive info.",
                                   MediaPipeTasksStatus::kFileZipError);
  }

  // Browse through files in archive.
  if (global_info.number_entry > 0) {
    int error = unzGoToFirstFile(zf);
    while (error == UNZ_OK) {
      ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf));
      // Store result in map.
      (*files)[zip_file_info.name] = absl::string_view(
          buffer_data + zip_file_info.position, zip_file_info.size);
      error = unzGoToNextFile(zf);
    }
    if (error != UNZ_END_OF_LIST_OF_FILE) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          "Unable to read associated file in zip archive.",
          MediaPipeTasksStatus::kFileZipError);
    }
  }
  // Perform the cleanup manually for error propagation.
  std::move(unzipper_closer).Cancel();
  // Close zip.
  if (unzClose(zf) != UNZ_OK) {
    return CreateStatusWithPayload(StatusCode::kUnknown,
                                   "Unable to close zip archive.",
                                   MediaPipeTasksStatus::kFileZipError);
  }
  return absl::OkStatus();
}

void SetExternalFile(const std::string_view& file_content,
                     core::proto::ExternalFile* model_file) {
  auto pointer = reinterpret_cast<uint64_t>(file_content.data());

  model_file->mutable_file_pointer_meta()->set_pointer(pointer);
  model_file->mutable_file_pointer_meta()->set_length(file_content.length());
}

}  // namespace metadata
}  // namespace tasks
}  // namespace mediapipe
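A note on the pattern above: absl::Cleanup guards the unzip handle so that early error returns still close it, and std::move(...).Cancel() dismisses the guard right before the explicit close whose status the function wants to propagate. A minimal illustration of the idiom (the resource names are hypothetical):

#include "absl/cleanup/cleanup.h"
#include "absl/status/status.h"

absl::Status UseResource() {
  Handle* h = OpenHandle();  // hypothetical resource
  // Runs CloseHandle(h) automatically on any early return below.
  absl::Cleanup closer = [h] { CloseHandle(h); };
  if (!DoWork(h)) return absl::InternalError("work failed");
  // Dismiss the guard so the close status can be returned explicitly.
  std::move(closer).Cancel();
  return CloseWithStatus(h);
}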
mediapipe/tasks/cc/metadata/utils/zip_utils.h (new file, 47 lines)

@@ -0,0 +1,47 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_
#define MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_

#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"

namespace mediapipe {
namespace tasks {
namespace metadata {

// Extracts files from the zip file.
// Input: pointer and length of the zip file in memory.
// Outputs: a map with the filename as key and a pointer to the file contents
// as value. The file contents returned by this function are only guaranteed to
// stay valid while buffer_data is alive.
absl::Status ExtractFilesfromZipFile(
    const char* buffer_data, const size_t buffer_size,
    absl::flat_hash_map<std::string, absl::string_view>* files);

// Sets file_pointer_meta in ExternalFile so that it points to the location of
// the file in memory given by file_content.
void SetExternalFile(const std::string_view& file_content,
                     core::proto::ExternalFile* model_file);

}  // namespace metadata
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_
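A hedged usage sketch of ExtractFilesfromZipFile (the buffer source and the "labels.txt" entry are illustrative): the returned string_views alias the input buffer, so the buffer must outlive the map.

std::string model_buffer = /* bytes of a model file already in memory */ "";
absl::flat_hash_map<std::string, absl::string_view> files;
absl::Status status = mediapipe::tasks::metadata::ExtractFilesfromZipFile(
    model_buffer.data(), model_buffer.size(), &files);
if (status.ok() && files.contains("labels.txt")) {
  absl::string_view labels = files["labels.txt"];  // zero-copy view
}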
@@ -44,8 +44,12 @@ cc_library(
    name = "hand_gesture_recognizer_graph",
    srcs = ["hand_gesture_recognizer_graph.cc"],
    deps = [
        "//mediapipe/calculators/core:begin_loop_calculator",
        "//mediapipe/calculators/core:concatenate_vector_calculator",
        "//mediapipe/calculators/core:end_loop_calculator",
        "//mediapipe/calculators/core:get_vector_item_calculator",
        "//mediapipe/calculators/tensor:tensor_converter_calculator",
        "//mediapipe/calculators/tensor:tensors_to_classification_calculator",
        "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
@@ -55,7 +59,6 @@ cc_library(
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
@@ -67,10 +70,81 @@
        "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
        "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_library(
    name = "gesture_recognizer_graph",
    srcs = ["gesture_recognizer_graph.cc"],
    deps = [
        ":hand_gesture_recognizer_graph",
        "//mediapipe/calculators/core:vector_indices_calculator",
        "//mediapipe/calculators/image:image_properties_calculator",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph",
        "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_library(
    name = "gesture_recognizer",
    srcs = ["gesture_recognizer.cc"],
    hdrs = ["gesture_recognizer.h"],
    deps = [
        ":gesture_recognizer_graph",
        ":hand_gesture_recognizer_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:task_runner",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
        "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
        "//mediapipe/tasks/cc/vision/core:running_mode",
        "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
    ],
)
@ -0,0 +1,282 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/

#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h"

#include <memory>
#include <type_traits>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {

namespace {

using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision::
    gesture_recognizer::proto::GestureRecognizerGraphOptions;

using ::mediapipe::tasks::components::containers::GestureRecognitionResult;

constexpr char kHandGestureSubgraphTypeName[] =
    "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph";

constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandGesturesStreamName[] = "hand_gestures";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kHandednessStreamName[] = "handedness";
constexpr char kHandLandmarksTag[] = "LANDMARKS";
constexpr char kHandLandmarksStreamName[] = "landmarks";
constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;

// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running in
// the live stream mode, a "FlowLimiterCalculator" will be added to limit the
// number of frames in flight.
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<GestureRecognizerGraphOptionsProto> options,
    bool enable_flow_limiting) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName);
  subgraph.GetOptions<GestureRecognizerGraphOptionsProto>().Swap(
      options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
  subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >>
      graph.Out(kHandGesturesTag);
  subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>
      graph.Out(kHandednessTag);
  subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >>
      graph.Out(kHandLandmarksTag);
  subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >>
      graph.Out(kHandWorldLandmarksTag);
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
  if (enable_flow_limiting) {
    return tasks::core::AddFlowLimiterCalculator(graph, subgraph, {kImageTag},
                                                 kHandGesturesTag);
  }
  graph.In(kImageTag) >> subgraph.In(kImageTag);
  return graph.GetConfig();
}
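
// For reference, a sketch of the config the function above produces in the
// image mode (no flow limiting). This is illustrative only; node options and
// exact stream ordering are elided, and the stream names come from the
// constants defined above:
//
//   input_stream: "IMAGE:image_in"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   node {
//     calculator:
//       "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
//     input_stream: "IMAGE:image_in"
//     output_stream: "HAND_GESTURES:hand_gestures"
//     output_stream: "HANDEDNESS:handedness"
//     output_stream: "LANDMARKS:landmarks"
//     output_stream: "WORLD_LANDMARKS:world_landmarks"
//     output_stream: "IMAGE:image_out"
//   }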

// Converts the user-facing GestureRecognizerOptions struct to the internal
// GestureRecognizerGraphOptions proto.
std::unique_ptr<GestureRecognizerGraphOptionsProto>
ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) {
  auto options_proto = std::make_unique<GestureRecognizerGraphOptionsProto>();

  bool use_stream_mode = options->running_mode != core::RunningMode::IMAGE;

  // TODO: remove these workarounds for base options of subgraphs.
  // Configure hand detector options.
  auto base_options_proto_for_hand_detector =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(options->base_options_for_hand_detector)));
  base_options_proto_for_hand_detector->set_use_stream_mode(use_stream_mode);
  auto* hand_detector_graph_options =
      options_proto->mutable_hand_landmarker_graph_options()
          ->mutable_hand_detector_graph_options();
  hand_detector_graph_options->mutable_base_options()->Swap(
      base_options_proto_for_hand_detector.get());
  hand_detector_graph_options->set_num_hands(options->num_hands);
  hand_detector_graph_options->set_min_detection_confidence(
      options->min_hand_detection_confidence);

  // Configure hand landmark detector options.
  auto base_options_proto_for_hand_landmarker =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(options->base_options_for_hand_landmarker)));
  base_options_proto_for_hand_landmarker->set_use_stream_mode(use_stream_mode);
  auto* hand_landmarks_detector_graph_options =
      options_proto->mutable_hand_landmarker_graph_options()
          ->mutable_hand_landmarks_detector_graph_options();
  hand_landmarks_detector_graph_options->mutable_base_options()->Swap(
      base_options_proto_for_hand_landmarker.get());
  hand_landmarks_detector_graph_options->set_min_detection_confidence(
      options->min_hand_presence_confidence);

  auto* hand_landmarker_graph_options =
      options_proto->mutable_hand_landmarker_graph_options();
  hand_landmarker_graph_options->set_min_tracking_confidence(
      options->min_tracking_confidence);

  // Configure hand gesture recognizer options.
  auto base_options_proto_for_gesture_recognizer =
      std::make_unique<tasks::core::proto::BaseOptions>(
          tasks::core::ConvertBaseOptionsToProto(
              &(options->base_options_for_gesture_recognizer)));
  base_options_proto_for_gesture_recognizer->set_use_stream_mode(
      use_stream_mode);
  auto* hand_gesture_recognizer_graph_options =
      options_proto->mutable_hand_gesture_recognizer_graph_options();
  hand_gesture_recognizer_graph_options->mutable_base_options()->Swap(
      base_options_proto_for_gesture_recognizer.get());
  if (options->min_gesture_confidence >= 0) {
    hand_gesture_recognizer_graph_options->mutable_classifier_options()
        ->set_score_threshold(options->min_gesture_confidence);
  }
  return options_proto;
}

}  // namespace

absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
    std::unique_ptr<GestureRecognizerOptions> options) {
  auto options_proto = ConvertGestureRecognizerGraphOptionsProto(options.get());
  tasks::core::PacketsCallback packets_callback = nullptr;
  if (options->result_callback) {
    auto result_callback = options->result_callback;
    packets_callback = [=](absl::StatusOr<tasks::core::PacketMap>
                               status_or_packets) {
      if (!status_or_packets.ok()) {
        Image image;
        result_callback(status_or_packets.status(), image,
                        Timestamp::Unset().Value());
        return;
      }
      if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
        return;
      }
      Packet gesture_packet =
          status_or_packets.value()[kHandGesturesStreamName];
      Packet handedness_packet =
          status_or_packets.value()[kHandednessStreamName];
      Packet hand_landmarks_packet =
          status_or_packets.value()[kHandLandmarksStreamName];
      Packet hand_world_landmarks_packet =
          status_or_packets.value()[kHandWorldLandmarksStreamName];
      Packet image_packet = status_or_packets.value()[kImageOutStreamName];
      result_callback(
          {{gesture_packet.Get<std::vector<ClassificationList>>(),
            handedness_packet.Get<std::vector<ClassificationList>>(),
            hand_landmarks_packet.Get<std::vector<NormalizedLandmarkList>>(),
            hand_world_landmarks_packet.Get<std::vector<LandmarkList>>()}},
          image_packet.Get<Image>(),
          gesture_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
    };
  }
  return core::VisionTaskApiFactory::Create<GestureRecognizer,
                                            GestureRecognizerGraphOptionsProto>(
      CreateGraphConfig(
          std::move(options_proto),
          options->running_mode == core::RunningMode::LIVE_STREAM),
      std::move(options->base_options.op_resolver), options->running_mode,
      std::move(packets_callback));
}

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
    mediapipe::Image image) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(auto output_packets,
                   ProcessImageData({{kImageInStreamName,
                                      MakePacket<Image>(std::move(image))}}));
  return {
      {/* gestures= */ {output_packets[kHandGesturesStreamName]
                            .Get<std::vector<ClassificationList>>()},
       /* handedness= */
       {output_packets[kHandednessStreamName]
            .Get<std::vector<mediapipe::ClassificationList>>()},
       /* hand_landmarks= */
       {output_packets[kHandLandmarksStreamName]
            .Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
       /* hand_world_landmarks= */
       {output_packets[kHandWorldLandmarksStreamName]
            .Get<std::vector<mediapipe::LandmarkList>>()}},
  };
}

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
    mediapipe::Image image, int64 timestamp_ms) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  return {
      {/* gestures= */ {output_packets[kHandGesturesStreamName]
                            .Get<std::vector<ClassificationList>>()},
       /* handedness= */
       {output_packets[kHandednessStreamName]
            .Get<std::vector<mediapipe::ClassificationList>>()},
       /* hand_landmarks= */
       {output_packets[kHandLandmarksStreamName]
            .Get<std::vector<mediapipe::NormalizedLandmarkList>>()},
       /* hand_world_landmarks= */
       {output_packets[kHandWorldLandmarksStreamName]
            .Get<std::vector<mediapipe::LandmarkList>>()}},
  };
}

absl::Status GestureRecognizer::RecognizeAsync(mediapipe::Image image,
                                               int64 timestamp_ms) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}

}  // namespace gesture_recognizer
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

@@ -0,0 +1,172 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_H_
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_H_

#include <functional>
#include <memory>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {

struct GestureRecognizerOptions {
  // Base options for configuring the Task library, such as specifying the
  // TfLite model file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // TODO: remove these. Temporary solutions before bundle asset is
  // ready.
  tasks::core::BaseOptions base_options_for_hand_landmarker;
  tasks::core::BaseOptions base_options_for_hand_detector;
  tasks::core::BaseOptions base_options_for_gesture_recognizer;

  // The running mode of the task. Defaults to the image mode.
  // GestureRecognizer has three running modes:
  // 1) The image mode for recognizing hand gestures on single image inputs.
  // 2) The video mode for recognizing hand gestures on the decoded frames of
  //    a video.
  // 3) The live stream mode for recognizing hand gestures on a live stream of
  //    input data, such as from a camera. In this mode, the "result_callback"
  //    below must be specified to receive the recognition results
  //    asynchronously.
  core::RunningMode running_mode = core::RunningMode::IMAGE;

  // The maximum number of hands that can be detected by the
  // GestureRecognizer.
  int num_hands = 1;

  // The minimum confidence score for the hand detection to be considered
  // successful.
  float min_hand_detection_confidence = 0.5;

  // The minimum confidence score of hand presence in the hand landmark
  // detection.
  float min_hand_presence_confidence = 0.5;

  // The minimum confidence score for the hand tracking to be considered
  // successful.
  float min_tracking_confidence = 0.5;

  // The minimum confidence score for a gesture to be considered recognized
  // successfully. If < 0, the gesture confidence thresholds in the model
  // metadata are used.
  // TODO: Note this option is subject to change, after the scoring
  // merging calculator is implemented.
  float min_gesture_confidence = -1;

  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM.
  std::function<void(
      absl::StatusOr<components::containers::GestureRecognitionResult>,
      const Image&, int64)>
      result_callback = nullptr;
};

// Performs hand gesture recognition on the given image.
//
// TODO: add the link to DevSite.
// This API expects a pre-trained hand gesture model asset bundle, or a
// custom one created using Model Maker. See <link to the DevSite
// documentation page>.
//
// Inputs:
//   Image
//     - The image that gesture recognition runs on.
// Outputs:
//   GestureRecognitionResult
//     - The hand gesture recognition results.
class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
 public:
  using BaseVisionTaskApi::BaseVisionTaskApi;

  // Creates a GestureRecognizer from a GestureRecognizerOptions to process
  // image data or streaming data. The gesture recognizer can be created with
  // one of the following three running modes:
  // 1) Image mode for recognizing gestures on single image inputs. Users
  //    provide a mediapipe::Image to the `Recognize` method, and will receive
  //    the recognized hand gesture results as the return value.
  // 2) Video mode for recognizing gestures on the decoded frames of a video.
  // 3) Live stream mode for recognizing gestures on a live stream of the
  //    input data, such as from a camera. Users call `RecognizeAsync` to push
  //    the image data into the GestureRecognizer; the recognized results,
  //    along with the input timestamp and the image that the gesture
  //    recognizer runs on, will be available in the result callback when the
  //    gesture recognizer finishes the work.
  static absl::StatusOr<std::unique_ptr<GestureRecognizer>> Create(
      std::unique_ptr<GestureRecognizerOptions> options);

  // Performs hand gesture recognition on the given image.
  // Only use this method when the GestureRecognizer is created with the image
  // running mode.
  //
  // image - mediapipe::Image
  //   Image to perform hand gesture recognition on.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: describe how the input image will be preprocessed once
  // YUV support is implemented.
  absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
      Image image);

  // Performs gesture recognition on the provided video frame.
  // Only use this method when the GestureRecognizer is created with the video
  // running mode.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input
  // timestamps must be monotonically increasing.
  absl::StatusOr<components::containers::GestureRecognitionResult>
  RecognizeForVideo(Image image, int64 timestamp_ms);

  // Sends live image data to perform gesture recognition, and the results
  // will be available via the "result_callback" provided in the
  // GestureRecognizerOptions. Only use this method when the GestureRecognizer
  // is created with the live stream running mode.
  //
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide a timestamp (in milliseconds) to indicate when the input image is
  // sent to the gesture recognizer. The input timestamps must be
  // monotonically increasing.
  //
  // The "result_callback" provides
  //   - A vector of GestureRecognitionResult, each of which is the recognized
  //     result for an input frame.
  //   - The const reference to the corresponding input image that the gesture
  //     recognizer runs on. Note that the const reference to the image will
  //     no longer be valid when the callback returns. To access the image
  //     data outside of the callback, callers need to make a copy of the
  //     image.
  //   - The input timestamp in milliseconds.
  absl::Status RecognizeAsync(Image image, int64 timestamp_ms);

  // Shuts down the GestureRecognizer when all the work is done.
  absl::Status Close() { return runner_->Close(); }
};
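
// A minimal usage sketch for the image running mode (illustrative only; the
// model asset path value and the `model_asset_path` field name are
// assumptions, not part of this header, and error handling is reduced to
// status propagation):
//
//   auto options = std::make_unique<GestureRecognizerOptions>();
//   options->base_options.model_asset_path = "/path/to/model";  // assumed
//   options->running_mode = core::RunningMode::IMAGE;
//   ASSIGN_OR_RETURN(std::unique_ptr<GestureRecognizer> recognizer,
//                    GestureRecognizer::Create(std::move(options)));
//   ASSIGN_OR_RETURN(auto result, recognizer->Recognize(std::move(image)));
//   // result.gestures[i] holds the sorted gesture classifications of hand i.
//   MP_RETURN_IF_ERROR(recognizer->Close());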

}  // namespace gesture_recognizer
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_GESTURE_RECOGNIZER_H_

@@ -0,0 +1,211 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <type_traits>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"

namespace mediapipe {
namespace tasks {
namespace vision {
namespace gesture_recognizer {

namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
    GestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
    HandGestureRecognizerGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
    HandLandmarkerGraphOptions;

constexpr char kImageTag[] = "IMAGE";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";

struct GestureRecognizerOutputs {
  Source<std::vector<ClassificationList>> gesture;
  Source<std::vector<ClassificationList>> handedness;
  Source<std::vector<NormalizedLandmarkList>> hand_landmarks;
  Source<std::vector<LandmarkList>> hand_world_landmarks;
  Source<Image> image;
};

}  // namespace

// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
// performs hand gesture recognition.
//
// Inputs:
//   IMAGE - Image
//     Image to perform hand gesture recognition on.
//
// Outputs:
//   HAND_GESTURES - std::vector<ClassificationList>
//     Recognized hand gestures in sorted order such that the winning label is
//     the first item in the list.
//   LANDMARKS - std::vector<NormalizedLandmarkList>
//     Detected hand landmarks.
//   WORLD_LANDMARKS - std::vector<LandmarkList>
//     Detected hand landmarks in world coordinates.
//   HAND_RECT_NEXT_FRAME - std::vector<NormalizedRect>
//     The predicted Rect enclosing the hand RoI for landmark detection on the
//     next frame.
//   HANDEDNESS - std::vector<ClassificationList>
//     Classification of handedness.
//   IMAGE - mediapipe::Image
//     The image that gesture recognizer runs on and has the pixel data stored
//     on the target storage (CPU vs GPU).
//
// Example:
// node {
//   calculator:
//     "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
//   input_stream: "IMAGE:image_in"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   output_stream: "LANDMARKS:hand_landmarks"
//   output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
//   output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
//   output_stream: "HANDEDNESS:handedness"
//   output_stream: "IMAGE:image_out"
//   options {
//     [mediapipe.tasks.vision.gesture_recognizer.proto.GestureRecognizerGraphOptions.ext]
//     {
//       base_options {
//         model_asset {
//           file_name: "hand_gesture.tflite"
//         }
//       }
//       hand_landmark_detector_options {
//         base_options {
//           model_asset {
//             file_name: "hand_landmark.tflite"
//           }
//         }
//       }
//     }
//   }
// }
class GestureRecognizerGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    Graph graph;
    ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
                     BuildGestureRecognizerGraph(
                         *sc->MutableOptions<GestureRecognizerGraphOptions>(),
                         graph[Input<Image>(kImageTag)], graph));
    hand_gesture_recognition_output.gesture >>
        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
    hand_gesture_recognition_output.handedness >>
        graph[Output<std::vector<ClassificationList>>(kHandednessTag)];
    hand_gesture_recognition_output.hand_landmarks >>
        graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
    hand_gesture_recognition_output.hand_world_landmarks >>
        graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
    hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)];
    return graph.GetConfig();
  }

 private:
  absl::StatusOr<GestureRecognizerOutputs> BuildGestureRecognizerGraph(
      GestureRecognizerGraphOptions& graph_options, Source<Image> image_in,
      Graph& graph) {
    auto& image_property = graph.AddNode("ImagePropertiesCalculator");
    image_in >> image_property.In("IMAGE");
    auto image_size = image_property.Out("SIZE");

    // Hand landmarker graph.
    auto& hand_landmarker_graph = graph.AddNode(
        "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
    auto& hand_landmarker_graph_options =
        hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
    hand_landmarker_graph_options.Swap(
        graph_options.mutable_hand_landmarker_graph_options());

    image_in >> hand_landmarker_graph.In(kImageTag);
    auto hand_landmarks =
        hand_landmarker_graph[Output<std::vector<NormalizedLandmarkList>>(
            kLandmarksTag)];
    auto hand_world_landmarks =
        hand_landmarker_graph[Output<std::vector<LandmarkList>>(
            kWorldLandmarksTag)];
    auto handedness =
        hand_landmarker_graph[Output<std::vector<ClassificationList>>(
            kHandednessTag)];

    auto& vector_indices =
        graph.AddNode("NormalizedLandmarkListVectorIndicesCalculator");
    hand_landmarks >> vector_indices.In("VECTOR");
    auto hand_landmarks_id = vector_indices.Out("INDICES");

    // Hand gesture recognizer subgraph.
    auto& hand_gesture_subgraph = graph.AddNode(
        "mediapipe.tasks.vision.gesture_recognizer."
        "MultipleHandGestureRecognizerGraph");
    hand_gesture_subgraph.GetOptions<HandGestureRecognizerGraphOptions>().Swap(
        graph_options.mutable_hand_gesture_recognizer_graph_options());
    hand_landmarks >> hand_gesture_subgraph.In(kLandmarksTag);
    hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag);
    handedness >> hand_gesture_subgraph.In(kHandednessTag);
    image_size >> hand_gesture_subgraph.In(kImageSizeTag);
    hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag);
    auto hand_gestures =
        hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
            kHandGesturesTag)];

    return {{.gesture = hand_gestures,
             .handedness = handedness,
             .hand_landmarks = hand_landmarks,
             .hand_world_landmarks = hand_world_landmarks,
             .image = hand_landmarker_graph[Output<Image>(kImageTag)]}};
  }
};

// clang-format off
REGISTER_MEDIAPIPE_GRAPH(
  ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerGraph);  // NOLINT
// clang-format on

}  // namespace gesture_recognizer
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

@@ -27,7 +27,6 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
@ -36,7 +35,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -50,7 +48,8 @@ using ::mediapipe::api2::Input;
|
|||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::processors::
|
||||
ConfigureTensorsToClassificationCalculator;
|
||||
using ::mediapipe::tasks::vision::gesture_recognizer::proto::
|
||||
HandGestureRecognizerGraphOptions;
|
||||
|
||||
|
@ -95,15 +94,14 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
|
|||
// The size of image from which the landmarks detected from.
|
||||
//
|
||||
// Outputs:
|
||||
// HAND_GESTURES - ClassificationResult
|
||||
// HAND_GESTURES - ClassificationList
|
||||
// Recognized hand gestures with sorted order such that the winning label is
|
||||
// the first item in the list.
|
||||
//
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator:
|
||||
// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph"
|
||||
// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerGraph"
|
||||
// input_stream: "HANDEDNESS:handedness"
|
||||
// input_stream: "LANDMARKS:landmarks"
|
||||
// input_stream: "WORLD_LANDMARKS:world_landmarks"
|
||||
|
@ -136,12 +134,12 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
|
||||
graph[Input<LandmarkList>(kWorldLandmarksTag)],
|
||||
graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
|
||||
hand_gestures >> graph[Output<ClassificationResult>(kHandGesturesTag)];
|
||||
hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<Source<ClassificationResult>> BuildGestureRecognizerGraph(
|
||||
absl::StatusOr<Source<ClassificationList>> BuildGestureRecognizerGraph(
|
||||
const HandGestureRecognizerGraphOptions& graph_options,
|
||||
const core::ModelResources& model_resources,
|
||||
Source<ClassificationList> handedness,
|
||||
|
@ -201,25 +199,24 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
auto concatenated_tensors = concatenate_tensor_vector.Out("");
|
||||
|
||||
// Inference for static hand gesture recognition.
|
||||
// TODO add embedding step.
|
||||
auto& inference = AddInference(
|
||||
model_resources, graph_options.base_options().acceleration(), graph);
|
||||
concatenated_tensors >> inference.In(kTensorsTag);
|
||||
auto inference_output_tensors = inference.Out(kTensorsTag);
|
||||
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, graph_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference_output_tensors >> postprocessing.In(kTensorsTag);
|
||||
auto classification_result =
|
||||
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
|
||||
|
||||
return classification_result;
|
||||
auto& tensors_to_classification =
|
||||
graph.AddNode("TensorsToClassificationCalculator");
|
||||
MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator(
|
||||
graph_options.classifier_options(),
|
||||
*model_resources.GetMetadataExtractor(), 0,
|
||||
&tensors_to_classification.GetOptions<
|
||||
mediapipe::TensorsToClassificationCalculatorOptions>()));
|
||||
inference_output_tensors >> tensors_to_classification.In(kTensorsTag);
|
||||
auto classification_list =
|
||||
tensors_to_classification[Output<ClassificationList>(
|
||||
"CLASSIFICATIONS")];
|
||||
return classification_list;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -247,9 +244,9 @@ REGISTER_MEDIAPIPE_GRAPH(
|
|||
// index corresponding to the same hand if the graph runs multiple times.
|
||||
//
|
||||
// Outputs:
|
||||
// HAND_GESTURES - std::vector<ClassificationResult>
|
||||
// HAND_GESTURES - std::vector<ClassificationList>
|
||||
// A vector of recognized hand gestures. Each vector element is the
|
||||
// ClassificationResult of the hand in input vector.
|
||||
// ClassificationList of the hand in input vector.
|
||||
//
|
||||
//
|
||||
// Example:
|
||||
|
@ -288,12 +285,12 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
graph[Input<std::pair<int, int>>(kImageSizeTag)],
|
||||
graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
|
||||
multi_hand_gestures >>
|
||||
graph[Output<std::vector<ClassificationResult>>(kHandGesturesTag)];
|
||||
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<Source<std::vector<ClassificationResult>>>
|
||||
absl::StatusOr<Source<std::vector<ClassificationList>>>
|
||||
BuildMultiGestureRecognizerSubraph(
|
||||
const HandGestureRecognizerGraphOptions& graph_options,
|
||||
Source<std::vector<ClassificationList>> multi_handedness,
|
||||
|
@ -346,12 +343,13 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
|
||||
auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);
|
||||
|
||||
auto& end_loop_classification_results =
|
||||
graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator");
|
||||
batch_end >> end_loop_classification_results.In(kBatchEndTag);
|
||||
hand_gestures >> end_loop_classification_results.In(kItemTag);
|
||||
auto multi_hand_gestures = end_loop_classification_results
|
||||
[Output<std::vector<ClassificationResult>>(kIterableTag)];
|
||||
auto& end_loop_classification_lists =
|
||||
graph.AddNode("EndLoopClassificationListCalculator");
|
||||
batch_end >> end_loop_classification_lists.In(kBatchEndTag);
|
||||
hand_gestures >> end_loop_classification_lists.In(kItemTag);
|
||||
auto multi_hand_gestures =
|
||||
end_loop_classification_lists[Output<std::vector<ClassificationList>>(
|
||||
kIterableTag)];
|
||||
|
||||
return multi_hand_gestures;
|
||||
}
|
||||
|
|
|
@ -31,8 +31,8 @@ mediapipe_proto_library(
)

mediapipe_proto_library(
    name = "hand_gesture_recognizer_graph_options_proto",
    srcs = ["hand_gesture_recognizer_graph_options.proto"],
    name = "gesture_classifier_graph_options_proto",
    srcs = ["gesture_classifier_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",

@@ -40,3 +40,28 @@ mediapipe_proto_library(
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)

mediapipe_proto_library(
    name = "hand_gesture_recognizer_graph_options_proto",
    srcs = ["hand_gesture_recognizer_graph_options.proto"],
    deps = [
        ":gesture_classifier_graph_options_proto",
        ":gesture_embedder_graph_options_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)

mediapipe_proto_library(
    name = "gesture_recognizer_graph_options_proto",
    srcs = ["gesture_recognizer_graph_options.proto"],
    deps = [
        ":hand_gesture_recognizer_graph_options_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_proto",
    ],
)

@@ -0,0 +1,33 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.vision.gesture_recognizer.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

message GestureClassifierGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional GestureClassifierGraphOptions ext = 478825465;
  }
  // Base options for configuring the gesture classifier subgraph, such as
  // specifying the TfLite model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for configuring the gesture classifier behavior, such as the
  // score threshold.
  optional components.processors.proto.ClassifierOptions classifier_options = 2;
}

@@ -0,0 +1,43 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.vision.gesture_recognizer.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto";
import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto";

option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer";
option java_outer_classname = "GestureRecognizerGraphOptionsProto";

message GestureRecognizerGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional GestureRecognizerGraphOptions ext = 479097054;
  }
  // Base options for configuring the gesture recognizer graph, such as
  // specifying the TfLite model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for configuring the hand landmarker graph.
  optional hand_landmarker.proto.HandLandmarkerGraphOptions
      hand_landmarker_graph_options = 2;

  // Options for configuring the hand gesture recognizer graph.
  optional HandGestureRecognizerGraphOptions
      hand_gesture_recognizer_graph_options = 3;
}

@@ -20,6 +20,8 @@ package mediapipe.tasks.vision.gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
|
||||
import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
|
||||
|
||||
message HandGestureRecognizerGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
|
@ -29,11 +31,18 @@ message HandGestureRecognizerGraphOptions {
|
|||
// specifying the TfLite model file with metadata, accelerator options, etc.
|
||||
optional core.proto.BaseOptions base_options = 1;
|
||||
|
||||
// Options for configuring the gesture classifier behavior, such as score
|
||||
// threshold, number of results, etc.
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
// Options for GestureEmbedder.
|
||||
optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2;
|
||||
|
||||
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
|
||||
// considered tracked successfully
|
||||
optional float min_tracking_confidence = 3 [default = 0.0];
|
||||
// Options for GestureClassifier of default gestures.
|
||||
optional GestureClassifierGraphOptions
|
||||
canned_gesture_classifier_graph_options = 3;
|
||||
|
||||
// Options for GestureClassifier of custom gestures.
|
||||
optional GestureClassifierGraphOptions
|
||||
custom_gesture_classifier_graph_options = 4;
|
||||
|
||||
// TODO: remove these. Temporary solutions before bundle asset is
|
||||
// ready.
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 5;
|
||||
}
|
||||
|
|
|
@ -80,6 +80,7 @@ cc_library(
|
|||
"//mediapipe/calculators/core:gate_calculator_cc_proto",
|
||||
"//mediapipe/calculators/core:pass_through_calculator",
|
||||
"//mediapipe/calculators/core:previous_loopback_calculator",
|
||||
"//mediapipe/calculators/image:image_properties_calculator",
|
||||
"//mediapipe/calculators/util:collection_has_min_size_calculator",
|
||||
"//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
|
@ -98,6 +99,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto",
|
||||
],
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||
|
||||
package(default_visibility = [
|
||||
"//mediapipe/app/xeno:__subpackages__",
|
||||
"//mediapipe/tasks:internal",
|
||||
])
|
||||
|
||||
|
@ -46,4 +45,26 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# TODO: Enable this test
|
||||
cc_library(
|
||||
name = "hand_landmarks_deduplication_calculator",
|
||||
srcs = ["hand_landmarks_deduplication_calculator.cc"],
|
||||
hdrs = ["hand_landmarks_deduplication_calculator.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:landmarks_detection",
|
||||
"//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder",
|
||||
"//mediapipe/tasks/cc/vision/utils:landmarks_utils",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

namespace mediapipe::api2 {
namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::Bound;
using ::mediapipe::tasks::vision::utils::CalculateIOU;
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;

// Euclidean distance between two normalized landmarks, in pixel space.
float Distance(const NormalizedLandmark& lm_a, const NormalizedLandmark& lm_b,
               int width, int height) {
  return std::sqrt(std::pow((lm_a.x() - lm_b.x()) * width, 2) +
                   std::pow((lm_a.y() - lm_b.y()) * height, 2));
}

// Pairwise distances between corresponding landmarks of two lists of equal
// size.
absl::StatusOr<std::vector<float>> Distances(const NormalizedLandmarkList& a,
                                             const NormalizedLandmarkList& b,
                                             int width, int height) {
  const int num = a.landmark_size();
  RET_CHECK_EQ(b.landmark_size(), num);
  std::vector<float> distances;
  distances.reserve(num);
  for (int i = 0; i < num; ++i) {
    const NormalizedLandmark& lm_a = a.landmark(i);
    const NormalizedLandmark& lm_b = b.landmark(i);
    distances.push_back(Distance(lm_a, lm_b, width, height));
  }
  return distances;
}

// Calculates a baseline distance of a hand that can be used as a relative
// measure when calculating hand-to-hand similarity.
//
// Calculated as the maximum of distances: 0->5, 5->17, 17->0, where the 0, 5,
// and 17 key points are depicted below:
//
//               /Middle/
//                  |
//      /Index/     |     /Ring/
//         |        |        |    /Pinky/
//         V        V        V       |
//                                   V
//        [8]     [12]     [16]
//         |        |        |      [20]
//         |        |        |        |
// /Thumb/ |        |        |        |
//    |   [7]     [11]     [15]     [19]
//    V    |        |        |        |
//         |        |        |        |
//   [4]   |        |        |        |
//    |   [6]     [10]     [14]     [18]
//    |    |        |        |        |
//    |    |        |        |        |
//   [3]   |        |        |        |
//    |   [5]----[9]-----[13]-----[17]
//     .   |                        |
//      \  .                        |
//       \ /                        |
//       [2]                        |
//         \                        |
//          \                       |
//           \                      |
//           [1]                    .
//             \                   /
//              \                 /
//               ._____[0]_____.
//
//                      ^
//                      |
//                   /Wrist/
absl::StatusOr<float> HandBaselineDistance(
    const NormalizedLandmarkList& landmarks, int width, int height) {
  RET_CHECK_EQ(landmarks.landmark_size(), 21);  // Num of hand landmarks.
  constexpr int kWrist = 0;
  constexpr int kIndexFingerMcp = 5;
  constexpr int kPinkyMcp = 17;
  float distance = Distance(landmarks.landmark(kWrist),
                            landmarks.landmark(kIndexFingerMcp), width, height);
  distance = std::max(distance,
                      Distance(landmarks.landmark(kIndexFingerMcp),
                               landmarks.landmark(kPinkyMcp), width, height));
  distance =
      std::max(distance, Distance(landmarks.landmark(kPinkyMcp),
                                  landmarks.landmark(kWrist), width, height));
  return distance;
}

Bound CalculateBound(const NormalizedLandmarkList& list) {
  constexpr float kMinInitialValue = std::numeric_limits<float>::max();
  constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();

  // Compute the min and max landmark values (they will form the bounding
  // box).
  float bounding_box_left = kMinInitialValue;
  float bounding_box_top = kMinInitialValue;
  float bounding_box_right = kMaxInitialValue;
  float bounding_box_bottom = kMaxInitialValue;
  for (const auto& landmark : list.landmark()) {
    bounding_box_left = std::min(bounding_box_left, landmark.x());
    bounding_box_top = std::min(bounding_box_top, landmark.y());
    bounding_box_right = std::max(bounding_box_right, landmark.x());
    bounding_box_bottom = std::max(bounding_box_bottom, landmark.y());
  }

  // Populate the normalized, non-rotated hand bounding box.
  return {.left = bounding_box_left,
          .top = bounding_box_top,
          .right = bounding_box_right,
          .bottom = bounding_box_bottom};
}

// Uses IoU and the distance of some corresponding hand landmarks to detect
// duplicate / similar hands. The IoU and distance thresholds, and the number
// of landmarks to match, were found experimentally. Evaluated:
// - by manually comparing side by side, before and after deduplication is
//   applied
// - by generating a gesture dataset and checking select frames in the
//   baseline and "deduplicated" datasets
// - by confirming that gesture training is better with deduplication enabled,
//   using the selected thresholds
class HandDuplicatesFinder : public DuplicatesFinder {
 public:
  explicit HandDuplicatesFinder(bool start_from_the_end)
      : start_from_the_end_(start_from_the_end) {}

  absl::StatusOr<absl::flat_hash_set<int>> FindDuplicates(
      const std::vector<NormalizedLandmarkList>& multi_landmarks,
      int input_width, int input_height) override {
    absl::flat_hash_set<int> retained_indices;
    absl::flat_hash_set<int> suppressed_indices;

    const int num = multi_landmarks.size();
    std::vector<float> baseline_distances;
    baseline_distances.reserve(num);
    std::vector<Bound> bounds;
    bounds.reserve(num);
    for (const NormalizedLandmarkList& list : multi_landmarks) {
      ASSIGN_OR_RETURN(const float baseline_distance,
                       HandBaselineDistance(list, input_width, input_height));
      baseline_distances.push_back(baseline_distance);
      bounds.push_back(CalculateBound(list));
    }

    for (int index = 0; index < num; ++index) {
      const int i = start_from_the_end_ ? num - index - 1 : index;
      const float stable_distance_i = baseline_distances[i];
      bool suppressed = false;
      for (int j : retained_indices) {
        const float stable_distance_j = baseline_distances[j];

        constexpr float kAllowedBaselineDistanceRatio = 0.2f;
        const float distance_threshold =
            std::max(stable_distance_i, stable_distance_j) *
            kAllowedBaselineDistanceRatio;

        ASSIGN_OR_RETURN(const std::vector<float> distances,
                         Distances(multi_landmarks[i], multi_landmarks[j],
                                   input_width, input_height));
        const int num_matched_landmarks = absl::c_count_if(
            distances,
            [&](float distance) { return distance < distance_threshold; });

        const float iou = CalculateIOU(bounds[i], bounds[j]);

        constexpr int kNumMatchedLandmarksToSuppressHand = 10;  // out of 21
        constexpr float kMinIouThresholdToSuppressHand = 0.2f;
        if (num_matched_landmarks >= kNumMatchedLandmarksToSuppressHand &&
            iou > kMinIouThresholdToSuppressHand) {
          suppressed = true;
          break;
        }
      }

      if (suppressed) {
        suppressed_indices.insert(i);
      } else {
        retained_indices.insert(i);
      }
    }
    return suppressed_indices;
  }

 private:
  const bool start_from_the_end_;
};
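
// A usage sketch for the finder above (illustrative only; the image size
// values are placeholder assumptions). FindDuplicates() returns the indices
// of hands to suppress: with start_from_the_end == false, hands earlier in
// the input vector are retained and later overlapping hands are marked as
// duplicates; with true, the scan runs backwards and later hands win.
//
//   std::unique_ptr<DuplicatesFinder> finder =
//       CreateHandDuplicatesFinder(/*start_from_the_end=*/false);
//   ASSIGN_OR_RETURN(absl::flat_hash_set<int> indices_to_remove,
//                    finder->FindDuplicates(multi_landmarks,
//                                           /*input_width=*/640,
//                                           /*input_height=*/480));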

template <typename InputPortT>
absl::StatusOr<absl::optional<typename InputPortT::PayloadT>>
VerifyNumAndMaybeInitOutput(const InputPortT& port, CalculatorContext* cc,
                            int num_expected_size) {
  absl::optional<typename InputPortT::PayloadT> output;
  if (port(cc).IsConnected() && !port(cc).IsEmpty()) {
    RET_CHECK_EQ(port(cc).Get().size(), num_expected_size);
    typename InputPortT::PayloadT result;
    return {{result}};
  }
  return {absl::nullopt};
}
}  // namespace

std::unique_ptr<DuplicatesFinder> CreateHandDuplicatesFinder(
    bool start_from_the_end) {
  return absl::make_unique<HandDuplicatesFinder>(start_from_the_end);
}

absl::Status HandLandmarksDeduplicationCalculator::Process(
    mediapipe::CalculatorContext* cc) {
  if (kInLandmarks(cc).IsEmpty()) return absl::OkStatus();
  if (kInSize(cc).IsEmpty()) return absl::OkStatus();

  const std::vector<NormalizedLandmarkList>& in_landmarks = *kInLandmarks(cc);
  const std::pair<int, int>& image_size = *kInSize(cc);

  std::unique_ptr<DuplicatesFinder> duplicates_finder =
      CreateHandDuplicatesFinder(/*start_from_the_end=*/false);
  ASSIGN_OR_RETURN(absl::flat_hash_set<int> indices_to_remove,
                   duplicates_finder->FindDuplicates(
                       in_landmarks, image_size.first, image_size.second));

  if (indices_to_remove.empty()) {
    kOutLandmarks(cc).Send(kInLandmarks(cc));
    kOutRois(cc).Send(kInRois(cc));
    kOutWorldLandmarks(cc).Send(kInWorldLandmarks(cc));
    kOutClassifications(cc).Send(kInClassifications(cc));
  } else {
    std::vector<NormalizedLandmarkList> out_landmarks;
    const int num = in_landmarks.size();

    ASSIGN_OR_RETURN(absl::optional<std::vector<NormalizedRect>> out_rois,
                     VerifyNumAndMaybeInitOutput(kInRois, cc, num));
    ASSIGN_OR_RETURN(
        absl::optional<std::vector<LandmarkList>> out_world_landmarks,
        VerifyNumAndMaybeInitOutput(kInWorldLandmarks, cc, num));
    ASSIGN_OR_RETURN(
        absl::optional<std::vector<ClassificationList>> out_classifications,
        VerifyNumAndMaybeInitOutput(kInClassifications, cc, num));

    for (int i = 0; i < num; ++i) {
      if (indices_to_remove.find(i) != indices_to_remove.end()) continue;

      out_landmarks.push_back(in_landmarks[i]);
      if (out_rois) {
        out_rois->push_back(kInRois(cc).Get()[i]);
      }
      if (out_world_landmarks) {
        out_world_landmarks->push_back(kInWorldLandmarks(cc).Get()[i]);
      }
      if (out_classifications) {
        out_classifications->push_back(kInClassifications(cc).Get()[i]);
      }
    }

    if (!out_landmarks.empty()) {
      kOutLandmarks(cc).Send(std::move(out_landmarks));
    }
    if (out_rois && !out_rois->empty()) {
      kOutRois(cc).Send(std::move(out_rois.value()));
    }
    if (out_world_landmarks && !out_world_landmarks->empty()) {
      kOutWorldLandmarks(cc).Send(std::move(out_world_landmarks.value()));
    }
    if (out_classifications && !out_classifications->empty()) {
      kOutClassifications(cc).Send(std::move(out_classifications.value()));
    }
  }
  return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(HandLandmarksDeduplicationCalculator);

}  // namespace mediapipe::api2

@@ -0,0 +1,97 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_

#include <memory>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"

namespace mediapipe::api2 {

// Creates a DuplicatesFinder dedicated to finding hand duplications.
std::unique_ptr<tasks::vision::utils::DuplicatesFinder>
CreateHandDuplicatesFinder(bool start_from_the_end = false);

// Filters duplicate hand landmarks by finding overlapping hands.
// Inputs:
//   MULTI_LANDMARKS - std::vector<NormalizedLandmarkList>
//     The hand landmarks to be filtered.
//   MULTI_ROIS - std::vector<NormalizedRect>
//     The regions, each of which encloses the landmarks of a single hand.
//   MULTI_WORLD_LANDMARKS - std::vector<LandmarkList>
//     The hand landmarks to be filtered, in world coordinates.
//   MULTI_CLASSIFICATIONS - std::vector<ClassificationList>
//     The handedness of the hands.
//   IMAGE_SIZE - std::pair<int, int>
//     The size of the image on which the hand landmarks are detected.
//
// Outputs:
//   MULTI_LANDMARKS - std::vector<NormalizedLandmarkList>
//     The hand landmarks with duplicates removed.
//   MULTI_ROIS - std::vector<NormalizedRect>
//     The regions, each of which encloses the landmarks of a single hand,
//     with duplicate hands removed.
//   MULTI_WORLD_LANDMARKS - std::vector<LandmarkList>
//     The hand landmarks with duplicates removed, in world coordinates.
//   MULTI_CLASSIFICATIONS - std::vector<ClassificationList>
//     The handedness of the hands with duplicate hands removed.
//
// Example:
// node {
//   calculator: "HandLandmarksDeduplicationCalculator"
//   input_stream: "MULTI_LANDMARKS:landmarks_in"
//   input_stream: "MULTI_ROIS:rois_in"
//   input_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_in"
//   input_stream: "MULTI_CLASSIFICATIONS:handedness_in"
//   input_stream: "IMAGE_SIZE:image_size"
//   output_stream: "MULTI_LANDMARKS:landmarks_out"
//   output_stream: "MULTI_ROIS:rois_out"
//   output_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_out"
//   output_stream: "MULTI_CLASSIFICATIONS:handedness_out"
// }
class HandLandmarksDeduplicationCalculator : public Node {
 public:
  constexpr static Input<std::vector<mediapipe::NormalizedLandmarkList>>
      kInLandmarks{"MULTI_LANDMARKS"};
  constexpr static Input<std::vector<mediapipe::NormalizedRect>>::Optional
|
||||
kInRois{"MULTI_ROIS"};
|
||||
constexpr static Input<std::vector<mediapipe::LandmarkList>>::Optional
|
||||
kInWorldLandmarks{"MULTI_WORLD_LANDMARKS"};
|
||||
constexpr static Input<std::vector<mediapipe::ClassificationList>>::Optional
|
||||
kInClassifications{"MULTI_CLASSIFICATIONS"};
|
||||
constexpr static Input<std::pair<int, int>> kInSize{"IMAGE_SIZE"};
|
||||
|
||||
constexpr static Output<std::vector<mediapipe::NormalizedLandmarkList>>
|
||||
kOutLandmarks{"MULTI_LANDMARKS"};
|
||||
constexpr static Output<std::vector<mediapipe::NormalizedRect>>::Optional
|
||||
kOutRois{"MULTI_ROIS"};
|
||||
constexpr static Output<std::vector<mediapipe::LandmarkList>>::Optional
|
||||
kOutWorldLandmarks{"MULTI_WORLD_LANDMARKS"};
|
||||
constexpr static Output<std::vector<mediapipe::ClassificationList>>::Optional
|
||||
kOutClassifications{"MULTI_CLASSIFICATIONS"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInLandmarks, kInRois, kInWorldLandmarks,
|
||||
kInClassifications, kInSize, kOutLandmarks, kOutRois,
|
||||
kOutWorldLandmarks, kOutClassifications);
|
||||
absl::Status Process(mediapipe::CalculatorContext* cc) override;
|
||||
};
|
||||
|
||||
} // namespace mediapipe::api2
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_
|
|
@@ -247,11 +247,37 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
    image_in >> hand_landmarks_detector_graph.In("IMAGE");
    clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT");

    auto landmarks = hand_landmarks_detector_graph.Out(kLandmarksTag);
    auto world_landmarks =
        hand_landmarks_detector_graph.Out(kWorldLandmarksTag);
    auto hand_rects_for_next_frame =
        hand_landmarks_detector_graph[Output<std::vector<NormalizedRect>>(
            kHandRectNextFrameTag)];
        hand_landmarks_detector_graph.Out(kHandRectNextFrameTag);
    auto handedness = hand_landmarks_detector_graph.Out(kHandednessTag);

    auto& image_property = graph.AddNode("ImagePropertiesCalculator");
    image_in >> image_property.In("IMAGE");
    auto image_size = image_property.Out("SIZE");

    auto& deduplicate = graph.AddNode("HandLandmarksDeduplicationCalculator");
    landmarks >> deduplicate.In("MULTI_LANDMARKS");
    world_landmarks >> deduplicate.In("MULTI_WORLD_LANDMARKS");
    hand_rects_for_next_frame >> deduplicate.In("MULTI_ROIS");
    handedness >> deduplicate.In("MULTI_CLASSIFICATIONS");
    image_size >> deduplicate.In("IMAGE_SIZE");

    auto filtered_landmarks =
        deduplicate[Output<std::vector<NormalizedLandmarkList>>(
            "MULTI_LANDMARKS")];
    auto filtered_world_landmarks =
        deduplicate[Output<std::vector<LandmarkList>>("MULTI_WORLD_LANDMARKS")];
    auto filtered_hand_rects_for_next_frame =
        deduplicate[Output<std::vector<NormalizedRect>>("MULTI_ROIS")];
    auto filtered_handedness =
        deduplicate[Output<std::vector<ClassificationList>>(
            "MULTI_CLASSIFICATIONS")];

    // Back edge.
    hand_rects_for_next_frame >> previous_loopback.In("LOOP");
    filtered_hand_rects_for_next_frame >> previous_loopback.In("LOOP");

    // TODO: Replace PassThroughCalculator with a calculator that
    // converts the pixel data to be stored on the target storage (CPU vs GPU).

@@ -259,14 +285,10 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
    image_in >> pass_through.In("");

    return {{
        /* landmark_lists= */ hand_landmarks_detector_graph
            [Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
        /* world_landmark_lists= */
        hand_landmarks_detector_graph[Output<std::vector<LandmarkList>>(
            kWorldLandmarksTag)],
        /* hand_rects_next_frame= */ hand_rects_for_next_frame,
        hand_landmarks_detector_graph[Output<std::vector<ClassificationList>>(
            kHandednessTag)],
        /* landmark_lists= */ filtered_landmarks,
        /* world_landmark_lists= */ filtered_world_landmarks,
        /* hand_rects_next_frame= */ filtered_hand_rects_for_next_frame,
        /* handedness= */ filtered_handedness,
        /* palm_rects= */
        hand_detector[Output<std::vector<NormalizedRect>>(kPalmRectsTag)],
        /* palm_detections */
@@ -208,7 +208,7 @@ TEST_F(CreateTest, FailsWithMissingModel) {
  EXPECT_THAT(
      image_classifier.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name' or 'file_descriptor_meta'."));
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
@@ -21,6 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto";
option java_outer_classname = "ImageClassifierGraphOptionsProto";

message ImageClassifierGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional ImageClassifierGraphOptions ext = 456383383;
@@ -140,7 +140,7 @@ TEST_F(CreateTest, FailsWithMissingModel) {
  EXPECT_THAT(
      image_embedder.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name' or 'file_descriptor_meta'."));
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
@@ -191,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  EXPECT_THAT(
      segmenter_or.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name' or 'file_descriptor_meta'."));
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
@@ -208,7 +208,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) {
  EXPECT_THAT(
      object_detector.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name' or 'file_descriptor_meta'."));
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(object_detector.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
@@ -79,3 +79,30 @@ cc_library(
        "@stblib//:stb_image",
    ],
)

cc_library(
    name = "landmarks_duplicates_finder",
    hdrs = ["landmarks_duplicates_finder.h"],
    deps = [
        "//mediapipe/framework/formats:landmark_cc_proto",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status:statusor",
    ],
)

cc_library(
    name = "landmarks_utils",
    srcs = ["landmarks_utils.cc"],
    hdrs = ["landmarks_utils.h"],
    deps = ["//mediapipe/tasks/cc/components/containers:landmarks_detection"],
)

cc_test(
    name = "landmarks_utils_test",
    srcs = ["landmarks_utils_test.cc"],
    deps = [
        ":landmarks_utils",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/components/containers:landmarks_detection",
    ],
)
@@ -0,0 +1,40 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_

#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/landmark.pb.h"

namespace mediapipe::tasks::vision::utils {

class DuplicatesFinder {
 public:
  virtual ~DuplicatesFinder() = default;
  // Returns the indices of the landmark lists to remove so that
  // @multi_landmarks contains only landmark lists that are sufficiently
  // distinct from one another (as defined by the implementation).
  virtual absl::StatusOr<absl::flat_hash_set<int>> FindDuplicates(
      const std::vector<mediapipe::NormalizedLandmarkList>& multi_landmarks,
      int input_width, int input_height) = 0;
};

}  // namespace mediapipe::tasks::vision::utils

#endif  // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
mediapipe/tasks/cc/vision/utils/landmarks_utils.cc (new file, 48 lines)
@@ -0,0 +1,48 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

#include <algorithm>
#include <vector>

namespace mediapipe::tasks::vision::utils {

using ::mediapipe::tasks::components::containers::Bound;

float CalculateArea(const Bound& bound) {
  return (bound.right - bound.left) * (bound.bottom - bound.top);
}

float CalculateIntersectionArea(const Bound& a, const Bound& b) {
  const float intersection_left = std::max<float>(a.left, b.left);
  const float intersection_top = std::max<float>(a.top, b.top);
  const float intersection_right = std::min<float>(a.right, b.right);
  const float intersection_bottom = std::min<float>(a.bottom, b.bottom);

  return std::max<float>(intersection_bottom - intersection_top, 0.0) *
         std::max<float>(intersection_right - intersection_left, 0.0);
}

float CalculateIOU(const Bound& a, const Bound& b) {
  const float area_a = CalculateArea(a);
  const float area_b = CalculateArea(b);
  if (area_a <= 0 || area_b <= 0) return 0.0;

  const float intersection_area = CalculateIntersectionArea(a, b);
  return intersection_area / (area_a + area_b - intersection_area);
}

}  // namespace mediapipe::tasks::vision::utils
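As a quick worked example of CalculateIOU (the numbers match the expectations in landmarks_utils_test.cc below, assuming Bound's field order is {left, top, right, bottom}, consistent with the arithmetic above): for a = {0, 0, 3, 1} and b = {2, 0, 4, 1}, the areas are 3 and 2, the intersection is (3 - 2) * (1 - 0) = 1, so the union is 3 + 2 - 1 = 4 and IoU = 1 / 4 = 0.25.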
mediapipe/tasks/cc/vision/utils/landmarks_utils.h (new file, 41 lines)
@@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_

#include <algorithm>
#include <array>
#include <cmath>
#include <limits>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"

namespace mediapipe::tasks::vision::utils {

// Calculates the intersection over union (IoU) of two bounds.
float CalculateIOU(const components::containers::Bound& a,
                   const components::containers::Bound& b);

// Calculates the area of a face bound.
float CalculateArea(const components::containers::Bound& bound);

// Calculates the intersection area of two face bounds.
float CalculateIntersectionArea(const components::containers::Bound& a,
                                const components::containers::Bound& b);
}  // namespace mediapipe::tasks::vision::utils

#endif  // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc (new file, 41 lines)
@@ -0,0 +1,41 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"

namespace mediapipe::tasks::vision::utils {
namespace {

TEST(LandmarkUtilsTest, CalculateIOU) {
  // No intersection at all.
  EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 2, 3, 3}));
  // No x intersection.
  EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 0, 3, 1}));
  // No y intersection.
  EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {0, 2, 1, 3}));
  // Full intersection.
  EXPECT_EQ(1, CalculateIOU({0, 0, 2, 2}, {0, 0, 2, 2}));

  // Union is 4, intersection is 1.
  EXPECT_EQ(0.25, CalculateIOU({0, 0, 3, 1}, {2, 0, 4, 1}));

  // Same as above, but overlapping along the y-axis.
  EXPECT_EQ(0.25, CalculateIOU({0, 0, 1, 3}, {0, 2, 1, 4}));
}
}  // namespace
}  // namespace mediapipe::tasks::vision::utils
@@ -34,3 +34,32 @@ android_library(
        "@maven//:com_google_guava_guava",
    ],
)

android_library(
    name = "classification_entry",
    srcs = ["ClassificationEntry.java"],
    deps = [
        ":category",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)

android_library(
    name = "classifications",
    srcs = ["Classifications.java"],
    deps = [
        ":classification_entry",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)

android_library(
    name = "landmark",
    srcs = ["Landmark.java"],
    deps = [
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
@@ -0,0 +1,48 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.components.containers;

import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;

/**
 * Represents a list of predicted categories with an optional timestamp. Typically used as a result
 * for classification tasks.
 */
@AutoValue
public abstract class ClassificationEntry {
  /**
   * Creates a {@link ClassificationEntry} instance from a list of {@link Category} and an optional
   * timestamp.
   *
   * @param categories the list of {@link Category} objects that contain the category name, display
   *     name, score and label index.
   * @param timestampMs the {@link long} representing the timestamp for which these categories were
   *     obtained.
   */
  public static ClassificationEntry create(List<Category> categories, long timestampMs) {
    return new AutoValue_ClassificationEntry(Collections.unmodifiableList(categories), timestampMs);
  }

  /** The list of predicted {@link Category} objects, sorted by descending score. */
  public abstract List<Category> categories();

  /**
   * The timestamp (in milliseconds) associated with the classification entry. This is useful for
   * time series use cases, e.g. audio classification.
   */
  public abstract long timestampMs();
}
@@ -0,0 +1,52 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.components.containers;

import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;

/**
 * Represents the list of classifications for a given classifier head. Typically used as a result
 * for classification tasks.
 */
@AutoValue
public abstract class Classifications {

  /**
   * Creates a {@link Classifications} instance.
   *
   * @param entries the list of {@link ClassificationEntry} objects containing the predicted
   *     categories.
   * @param headIndex the index of the classifier head.
   * @param headName the name of the classifier head.
   */
  public static Classifications create(
      List<ClassificationEntry> entries, int headIndex, String headName) {
    return new AutoValue_Classifications(
        Collections.unmodifiableList(entries), headIndex, headName);
  }

  /** A list of {@link ClassificationEntry} objects. */
  public abstract List<ClassificationEntry> entries();

  /**
   * The index of the classifier head these entries refer to. This is useful for multi-head models.
   */
  public abstract int headIndex();

  /** The name of the classifier head, which is the corresponding tensor metadata name. */
  public abstract String headName();
}
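A minimal sketch (not part of this commit) of how these containers nest; the head name "probability" and the category values are illustrative only, and Category.create follows the (score, index, categoryName, displayName) argument order used elsewhere in this change:

package com.google.mediapipe.tasks.components.containers;

import java.util.Arrays;

final class ClassificationsExample {
  static Classifications buildExample() {
    // One predicted category: score 0.95, index 0, name "cat", display name "Cat".
    Category category = Category.create(0.95f, 0, "cat", "Cat");
    // One entry holding the categories predicted at timestamp 0 ms.
    ClassificationEntry entry = ClassificationEntry.create(Arrays.asList(category), 0);
    // Results for classifier head 0, with a hypothetical head name "probability".
    return Classifications.create(Arrays.asList(entry), 0, "probability");
  }
}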
@@ -0,0 +1,66 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.components.containers;

import com.google.auto.value.AutoValue;
import java.util.Objects;

/**
 * Landmark represents a point in 3D space with x, y, z coordinates. If normalized is true, the
 * landmark coordinates are normalized with respect to the image dimensions, and the coordinate
 * values are in the range [0,1]. Otherwise, it represents a point in world coordinates.
 */
@AutoValue
public abstract class Landmark {
  private static final float TOLERANCE = 1e-6f;

  public static Landmark create(float x, float y, float z, boolean normalized) {
    return new AutoValue_Landmark(x, y, z, normalized);
  }

  // The x coordinate of the landmark.
  public abstract float x();

  // The y coordinate of the landmark.
  public abstract float y();

  // The z coordinate of the landmark.
  public abstract float z();

  // Whether this landmark is normalized with respect to the image size.
  public abstract boolean normalized();

  @Override
  public final boolean equals(Object o) {
    if (!(o instanceof Landmark)) {
      return false;
    }
    Landmark other = (Landmark) o;
    return other.normalized() == this.normalized()
        && Math.abs(other.x() - this.x()) < TOLERANCE
        && Math.abs(other.y() - this.y()) < TOLERANCE
        && Math.abs(other.z() - this.z()) < TOLERANCE;
  }

  @Override
  public final int hashCode() {
    return Objects.hash(x(), y(), z(), normalized());
  }

  @Override
  public final String toString() {
    return "<Landmark (x=" + x() + " y=" + y() + " z=" + z() + " normalized=" + normalized() + ")>";
  }
}
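A tiny illustration (not part of this commit) of the tolerance-based equality above: landmarks whose per-axis differences are all below TOLERANCE (1e-6) compare equal:

package com.google.mediapipe.tasks.components.containers;

final class LandmarkEqualityExample {
  public static void main(String[] args) {
    Landmark a = Landmark.create(0.5f, 0.5f, 0.0f, true);
    // Differs from `a` by roughly 1e-7 in x, below the 1e-6 tolerance.
    Landmark b = Landmark.create(0.5f + 1e-7f, 0.5f, 0.0f, true);
    System.out.println(a.equals(b));  // prints "true"
  }
}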
@@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

android_library(
    name = "classifieroptions",
    srcs = ["ClassifierOptions.java"],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    deps = [
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
@@ -0,0 +1,118 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.components.processors;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/** Classifier options shared across MediaPipe Java classification tasks. */
@AutoValue
public abstract class ClassifierOptions {

  /** Builder for {@link ClassifierOptions}. */
  @AutoValue.Builder
  public abstract static class Builder {
    /**
     * Sets the optional locale to use for display names specified through the TFLite Model
     * Metadata, if any.
     */
    public abstract Builder setDisplayNamesLocale(String locale);

    /**
     * Sets the optional maximum number of top-scored classification results to return.
     *
     * <p>If not set, all available results are returned. If set, must be > 0.
     */
    public abstract Builder setMaxResults(Integer maxResults);

    /**
     * Sets the optional score threshold. Results with score below this value are rejected.
     *
     * <p>Overrides the score threshold specified in the TFLite Model Metadata, if any.
     */
    public abstract Builder setScoreThreshold(Float scoreThreshold);

    /**
     * Sets the optional allowlist of category names.
     *
     * <p>If non-empty, detection results whose category name is not in this set will be filtered
     * out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
     * categoryDenylist}.
     */
    public abstract Builder setCategoryAllowlist(List<String> categoryAllowlist);

    /**
     * Sets the optional denylist of category names.
     *
     * <p>If non-empty, detection results whose category name is in this set will be filtered out.
     * Duplicate or unknown category names are ignored. Mutually exclusive with {@code
     * categoryAllowlist}.
     */
    public abstract Builder setCategoryDenylist(List<String> categoryDenylist);

    abstract ClassifierOptions autoBuild();

    /**
     * Validates and builds the {@link ClassifierOptions} instance.
     *
     * @throws IllegalArgumentException if {@code maxResults} is set to a value <= 0.
     */
    public final ClassifierOptions build() {
      ClassifierOptions options = autoBuild();
      if (options.maxResults().isPresent() && options.maxResults().get() <= 0) {
        throw new IllegalArgumentException("If specified, maxResults must be > 0");
      }
      return options;
    }
  }

  public abstract Optional<String> displayNamesLocale();

  public abstract Optional<Integer> maxResults();

  public abstract Optional<Float> scoreThreshold();

  public abstract List<String> categoryAllowlist();

  public abstract List<String> categoryDenylist();

  public static Builder builder() {
    return new AutoValue_ClassifierOptions.Builder()
        .setCategoryAllowlist(Collections.emptyList())
        .setCategoryDenylist(Collections.emptyList());
  }

  /**
   * Converts a {@link ClassifierOptions} object to a {@link
   * ClassifierOptionsProto.ClassifierOptions} protobuf message.
   */
  public ClassifierOptionsProto.ClassifierOptions convertToProto() {
    ClassifierOptionsProto.ClassifierOptions.Builder builder =
        ClassifierOptionsProto.ClassifierOptions.newBuilder();
    displayNamesLocale().ifPresent(builder::setDisplayNamesLocale);
    maxResults().ifPresent(builder::setMaxResults);
    scoreThreshold().ifPresent(builder::setScoreThreshold);
    if (!categoryAllowlist().isEmpty()) {
      builder.addAllCategoryAllowlist(categoryAllowlist());
    }
    if (!categoryDenylist().isEmpty()) {
      builder.addAllCategoryDenylist(categoryDenylist());
    }
    return builder.build();
  }
}
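A minimal usage sketch (not part of this commit; the values are illustrative): build the options and convert them to the proto consumed by the underlying task graph. All methods used here appear in the file above:

package com.google.mediapipe.tasks.components.processors;

import com.google.mediapipe.tasks.components.processors.proto.ClassifierOptionsProto;
import java.util.Arrays;

final class ClassifierOptionsExample {
  static ClassifierOptionsProto.ClassifierOptions buildProto() {
    ClassifierOptions options =
        ClassifierOptions.builder()
            .setMaxResults(3)                                   // keep the 3 top-scored results
            .setScoreThreshold(0.5f)                            // reject scores below 0.5
            .setCategoryAllowlist(Arrays.asList("cat", "dog"))  // keep only these categories
            .build();  // build() validates that maxResults, if set, is > 0
    return options.convertToProto();
  }
}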
@@ -19,8 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
    name = "core",
    srcs = glob(["*.java"]),
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    deps = [
        ":libmediapipe_tasks_vision_jni_lib",
        "//mediapipe/framework/formats:rect_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",

@@ -36,6 +40,7 @@ cc_binary(
    deps = [
        "//mediapipe/calculators/core:flow_limiter_calculator",
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
        "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
    ],
@@ -14,101 +14,247 @@

package com.google.mediapipe.tasks.vision.core;

import android.graphics.RectF;
import com.google.mediapipe.formats.proto.RectProto.NormalizedRect;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
  private static final long MICROSECONDS_PER_MILLISECOND = 1000;
  private final TaskRunner runner;
  private final RunningMode runningMode;
  private final String imageStreamName;
  private final Optional<String> normRectStreamName;

  static {
    System.loadLibrary("mediapipe_tasks_vision_jni");
    ProtoUtil.registerTypeName(NormalizedRect.class, "mediapipe.NormalizedRect");
  }

  /**
   * Constructor to initialize a {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
   * task {@link RunningMode}.
   * Constructor to initialize a {@link BaseVisionTaskApi} only taking images as input.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   * @param imageStreamName the name of the input image stream.
   */
  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
  public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode, String imageStreamName) {
    this.runner = runner;
    this.runningMode = runningMode;
    this.imageStreamName = imageStreamName;
    this.normRectStreamName = Optional.empty();
  }

  /**
   * Constructor to initialize a {@link BaseVisionTaskApi} taking images and normalized rects as
   * input.
   *
   * @param runner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   * @param imageStreamName the name of the input image stream.
   * @param normRectStreamName the name of the input normalized rect stream.
   */
  public BaseVisionTaskApi(
      TaskRunner runner,
      RunningMode runningMode,
      String imageStreamName,
      String normRectStreamName) {
    this.runner = runner;
    this.runningMode = runningMode;
    this.imageStreamName = imageStreamName;
    this.normRectStreamName = Optional.of(normRectStreamName);
  }

  /**
   * A synchronous method to process single image inputs. The call blocks the current thread until
   * a failure status or a successful result is returned.
   *
   * @param imageStreamName the image input stream name.
   * @param image a MediaPipe {@link Image} object for processing.
   * @throws MediaPipeException if the task is not in the image mode.
   * @throws MediaPipeException if the task is not in the image mode or requires a normalized rect
   *     input.
   */
  protected TaskResult processImageData(String imageStreamName, Image image) {
  protected TaskResult processImageData(Image image) {
    if (runningMode != RunningMode.IMAGE) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the image mode. Current running mode:"
              + runningMode.name());
    }
    if (normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task expects a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    return runner.process(inputPackets);
  }

  /**
   * A synchronous method to process single image inputs. The call blocks the current thread until
   * a failure status or a successful result is returned.
   *
   * @param image a MediaPipe {@link Image} object for processing.
   * @param roi a {@link RectF} defining the region-of-interest to process in the image.
   *     Coordinates are expected to be specified as normalized values in [0,1].
   * @throws MediaPipeException if the task is not in the image mode or doesn't require a
   *     normalized rect.
   */
  protected TaskResult processImageData(Image image, RectF roi) {
    if (runningMode != RunningMode.IMAGE) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the image mode. Current running mode:"
              + runningMode.name());
    }
    if (!normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task doesn't expect a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    inputPackets.put(
        normRectStreamName.get(),
        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
    return runner.process(inputPackets);
  }

  /**
   * A synchronous method to process continuous video frames. The call blocks the current thread
   * until a failure status or a successful result is returned.
   *
   * @param imageStreamName the image input stream name.
   * @param image a MediaPipe {@link Image} object for processing.
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the video mode.
   * @throws MediaPipeException if the task is not in the video mode or requires a normalized rect
   *     input.
   */
  protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
  protected TaskResult processVideoData(Image image, long timestampMs) {
    if (runningMode != RunningMode.VIDEO) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the video mode. Current running mode:"
              + runningMode.name());
    }
    if (normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task expects a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /**
   * A synchronous method to process continuous video frames. The call blocks the current thread
   * until a failure status or a successful result is returned.
   *
   * @param image a MediaPipe {@link Image} object for processing.
   * @param roi a {@link RectF} defining the region-of-interest to process in the image.
   *     Coordinates are expected to be specified as normalized values in [0,1].
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the video mode or doesn't require a
   *     normalized rect.
   */
  protected TaskResult processVideoData(Image image, RectF roi, long timestampMs) {
    if (runningMode != RunningMode.VIDEO) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the video mode. Current running mode:"
              + runningMode.name());
    }
    if (!normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task doesn't expect a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    inputPackets.put(
        normRectStreamName.get(),
        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
    return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /**
   * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
   * available in the user-defined result listener.
   *
   * @param imageStreamName the image input stream name.
   * @param image a MediaPipe {@link Image} object for processing.
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the video mode.
   * @throws MediaPipeException if the task is not in the live stream mode or requires a
   *     normalized rect input.
   */
  protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
  protected void sendLiveStreamData(Image image, long timestampMs) {
    if (runningMode != RunningMode.LIVE_STREAM) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the live stream mode. Current running mode:"
              + runningMode.name());
    }
    if (normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task expects a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /**
   * An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
   * available in the user-defined result listener.
   *
   * @param image a MediaPipe {@link Image} object for processing.
   * @param roi a {@link RectF} defining the region-of-interest to process in the image.
   *     Coordinates are expected to be specified as normalized values in [0,1].
   * @param timestampMs the corresponding timestamp of the input image in milliseconds.
   * @throws MediaPipeException if the task is not in the live stream mode or doesn't require a
   *     normalized rect.
   */
  protected void sendLiveStreamData(Image image, RectF roi, long timestampMs) {
    if (runningMode != RunningMode.LIVE_STREAM) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task is not initialized with the live stream mode. Current running mode:"
              + runningMode.name());
    }
    if (!normRectStreamName.isPresent()) {
      throw new MediaPipeException(
          MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
          "Task doesn't expect a normalized rect as input.");
    }
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
    inputPackets.put(
        normRectStreamName.get(),
        runner.getPacketCreator().createProto(convertToNormalizedRect(roi)));
    runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
  }

  /** Closes and cleans up the MediaPipe vision task. */
  @Override
  public void close() {
    runner.close();
  }

  /** Converts a {@link RectF} object into a {@link NormalizedRect} protobuf message. */
  private static NormalizedRect convertToNormalizedRect(RectF rect) {
    return NormalizedRect.newBuilder()
        .setXCenter(rect.centerX())
        .setYCenter(rect.centerY())
        .setWidth(rect.width())
        .setHeight(rect.height())
        .build();
  }
}
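A minimal subclass sketch (hypothetical; the class name and stream names are illustrative, not from this commit) showing the intended pattern — a concrete task picks its running mode and stream names, then delegates to the protected process methods:

package com.google.mediapipe.tasks.vision.core;

import android.graphics.RectF;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;

final class MyRegionTask extends BaseVisionTaskApi {
  MyRegionTask(TaskRunner runner) {
    // This task consumes an image stream plus a normalized rect stream.
    super(runner, RunningMode.IMAGE, "image_in", "norm_rect_in");
  }

  TaskResult run(Image image, RectF roi) {
    // Blocks until the graph produces a result for this image + ROI.
    return processImageData(image, roi);
  }
}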
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.gesturerecognizer">

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

</manifest>
@@ -0,0 +1,40 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

android_library(
    name = "gesturerecognizer",
    srcs = [
        "GestureRecognitionResult.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = ":AndroidManifest.xml",
    deps = [
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/framework/formats:classification_java_proto_lite",
        "//mediapipe/framework/formats:landmark_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
@@ -0,0 +1,128 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.gesturerecognizer;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Represents the gesture recognition results generated by {@link GestureRecognizer}. */
@AutoValue
public abstract class GestureRecognitionResult implements TaskResult {

  /**
   * Creates a {@link GestureRecognitionResult} instance from the lists of landmarks, handedness,
   * and gestures protobuf messages.
   *
   * @param landmarksProto a List of {@link NormalizedLandmarkList}
   * @param worldLandmarksProto a List of {@link LandmarkList}
   * @param handednessesProto a List of {@link ClassificationList}
   * @param gesturesProto a List of {@link ClassificationList}
   * @param timestampMs a timestamp for this result.
   */
  static GestureRecognitionResult create(
      List<NormalizedLandmarkList> landmarksProto,
      List<LandmarkList> worldLandmarksProto,
      List<ClassificationList> handednessesProto,
      List<ClassificationList> gesturesProto,
      long timestampMs) {
    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
        new ArrayList<>();
    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
        new ArrayList<>();
    List<List<Category>> multiHandHandednesses = new ArrayList<>();
    List<List<Category>> multiHandGestures = new ArrayList<>();
    for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
      List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
          new ArrayList<>();
      multiHandLandmarks.add(handLandmarks);
      for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
        handLandmarks.add(
            com.google.mediapipe.tasks.components.containers.Landmark.create(
                handLandmarkProto.getX(),
                handLandmarkProto.getY(),
                handLandmarkProto.getZ(),
                true));
      }
    }
    for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
      List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks =
          new ArrayList<>();
      multiHandWorldLandmarks.add(handWorldLandmarks);
      for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
        handWorldLandmarks.add(
            com.google.mediapipe.tasks.components.containers.Landmark.create(
                handWorldLandmarkProto.getX(),
                handWorldLandmarkProto.getY(),
                handWorldLandmarkProto.getZ(),
                false));
      }
    }
    for (ClassificationList handednessProto : handednessesProto) {
      List<Category> handedness = new ArrayList<>();
      multiHandHandednesses.add(handedness);
      for (Classification classification : handednessProto.getClassificationList()) {
        handedness.add(
            Category.create(
                classification.getScore(),
                classification.getIndex(),
                classification.getLabel(),
                classification.getDisplayName()));
      }
    }
    for (ClassificationList gestureProto : gesturesProto) {
      List<Category> gestures = new ArrayList<>();
      multiHandGestures.add(gestures);
      for (Classification classification : gestureProto.getClassificationList()) {
        gestures.add(
            Category.create(
                classification.getScore(),
                classification.getIndex(),
                classification.getLabel(),
                classification.getDisplayName()));
      }
    }
    return new AutoValue_GestureRecognitionResult(
        timestampMs,
        Collections.unmodifiableList(multiHandLandmarks),
        Collections.unmodifiableList(multiHandWorldLandmarks),
        Collections.unmodifiableList(multiHandHandednesses),
        Collections.unmodifiableList(multiHandGestures));
  }

  @Override
  public abstract long timestampMs();

  /** Hand landmarks of detected hands. */
  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();

  /** Hand landmarks in world coordinates of detected hands. */
  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
      worldLandmarks();

  /** Handedness of detected hands. */
  public abstract List<List<Category>> handednesses();

  /** Recognized hand gestures of detected hands. */
  public abstract List<List<Category>> gestures();
}
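A small consumption sketch (not part of this commit): the four accessors are parallel lists with one inner list per detected hand, so a caller can zip them by index. `result` stands for a GestureRecognitionResult returned by the task:

List<List<Category>> gestures = result.gestures();
List<List<Category>> handednesses = result.handednesses();
for (int hand = 0; hand < gestures.size(); ++hand) {
  // Entries at the same index describe the same detected hand.
  List<Category> handGestures = gestures.get(hand);
  List<Category> handHandedness = handednesses.get(hand);
  List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
      result.landmarks().get(hand);
}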
@@ -38,7 +38,8 @@ public abstract class ObjectDetectionResult implements TaskResult {
   * Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
   * messages.
   *
   * @param detectionList a list of {@link Detection} protobuf messages.
   * @param detectionList a list of {@link DetectionOuterClass.Detection} protobuf messages.
   * @param timestampMs a timestamp for this result.
   */
  static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
    List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
@ -155,7 +155,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param detectorOptions a {@link ObjectDetectorOptions} instance.
|
||||
* @param detectorOptions an {@link ObjectDetectorOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
|
||||
*/
|
||||
public static ObjectDetector createFromOptions(
|
||||
|
@ -192,7 +192,6 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
.setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
detectorOptions.errorListener().ifPresent(runner::setErrorListener);
|
||||
return new ObjectDetector(runner, detectorOptions.runningMode());
|
||||
}
|
||||
|
||||
|
@ -204,7 +203,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode);
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@@ -221,7 +220,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detect(Image inputImage) {
    return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
    return (ObjectDetectionResult) processImageData(inputImage);
  }

  /**
@@ -242,8 +241,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   * @throws MediaPipeException if there is an internal error.
   */
  public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
    return (ObjectDetectionResult)
        processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
    return (ObjectDetectionResult) processVideoData(inputImage, inputTimestampMs);
  }

  /**
@@ -265,7 +263,7 @@ public final class ObjectDetector extends BaseVisionTaskApi {
   * @throws MediaPipeException if there is an internal error.
   */
  public void detectAsync(Image inputImage, long inputTimestampMs) {
    sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
    sendLiveStreamData(inputImage, inputTimestampMs);
  }

  /** Options for setting up an {@link ObjectDetector}. */
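The detect, detectForVideo, and detectAsync hunks above share one refactor: the image-input stream name is no longer passed at every call site, because the constructor now hands IMAGE_IN_STREAM_NAME to BaseVisionTaskApi once. A minimal sketch of the resulting call shapes, one per running mode; `detector` and `image` are assumed to be created elsewhere, and the Image import path is an assumption not shown in this diff:

import com.google.mediapipe.framework.image.Image;  // import path assumed

final class DetectorCallSites {
  // Call shapes after this change; model loading and image acquisition are
  // assumed to happen elsewhere and are not part of this diff.
  static void runAllModes(ObjectDetector detector, Image image) {
    // IMAGE mode: synchronous one-shot detection.
    ObjectDetectionResult imageResult = detector.detect(image);

    // VIDEO mode: the caller supplies a monotonically increasing timestamp.
    ObjectDetectionResult videoResult = detector.detectForVideo(image, /* inputTimestampMs= */ 33);

    // LIVE_STREAM mode: returns immediately; results are delivered to the
    // ResultListener configured in ObjectDetectorOptions.
    detector.detectAsync(image, /* inputTimestampMs= */ 66);
  }
}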
@@ -275,12 +273,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
    /** Builder for {@link ObjectDetectorOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the base options for the object detector task. */
      /** Sets the {@link BaseOptions} for the object detector task. */
      public abstract Builder setBaseOptions(BaseOptions value);

      /**
       * Sets the running mode for the object detector task. Default to the image mode. Object
       * detector has three modes:
       * Sets the {@link RunningMode} for the object detector task. Default to the image mode.
       * Object detector has three modes:
       *
       * <ul>
       *   <li>IMAGE: The mode for detecting objects on single image inputs.
@@ -293,8 +291,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
      public abstract Builder setRunningMode(RunningMode value);

      /**
       * Sets the locale to use for display names specified through the TFLite Model Metadata, if
       * any. Defaults to English.
       * Sets the optional locale to use for display names specified through the TFLite Model
       * Metadata, if any.
       */
      public abstract Builder setDisplayNamesLocale(String value);

@@ -331,12 +329,12 @@ public final class ObjectDetector extends BaseVisionTaskApi {
      public abstract Builder setCategoryDenylist(List<String> value);

      /**
       * Sets the result listener to receive the detection results asynchronously when the object
       * detector is in the live stream mode.
       * Sets the {@link ResultListener} to receive the detection results asynchronously when the
       * object detector is in the live stream mode.
       */
      public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);

      /** Sets an optional error listener. */
      /** Sets an optional {@link ErrorListener}. */
      public abstract Builder setErrorListener(ErrorListener value);

      abstract ObjectDetectorOptions autoBuild();

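Taken together, the builder hunks imply wiring like the following sketch; the builder()/build() entry points, model path, and listener bodies are assumptions, since this diff only shows the abstract setters:

// Sketch only: "model.tflite", the log tag, and the lambda shapes are
// illustrative placeholders, not taken from this commit.
ObjectDetectorOptions options =
    ObjectDetectorOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
        .setRunningMode(RunningMode.LIVE_STREAM)
        .setDisplayNamesLocale("en")
        .setResultListener(
            (result, image) ->
                Log.d("ObjectDetector", "detections: " + result.detections().size()))
        .setErrorListener(e -> Log.e("ObjectDetector", "error: " + e))
        .build();
ObjectDetector detector = ObjectDetector.createFromOptions(context, options);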
@@ -11,7 +11,7 @@
        android:targetSdkVersion="30" />

    <application
        android:label="facedetectiontest"
        android:label="objectdetectortest"
        android:name="android.support.multidex.MultiDexApplication"
        android:taskAffinity="">
        <uses-library android:name="android.test.runner" />
@@ -1,45 +0,0 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test util for MediaPipe Tasks."""

import os

from absl import flags

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import image_frame as image_frame_module

FLAGS = flags.FLAGS
_Image = image_module.Image
_ImageFormat = image_frame_module.ImageFormat
_RGB_CHANNELS = 3


def test_srcdir():
  """Returns the path where to look for test data files."""
  if "test_srcdir" in flags.FLAGS:
    return flags.FLAGS["test_srcdir"].value
  elif "TEST_SRCDIR" in os.environ:
    return os.environ["TEST_SRCDIR"]
  else:
    raise RuntimeError("Missing TEST_SRCDIR environment.")


def get_test_data_path(file_or_dirname: str) -> str:
  """Returns full test data path."""
  for (directory, subdirs, files) in os.walk(test_srcdir()):
    for f in subdirs + files:
      if f.endswith(file_or_dirname):
        return os.path.join(directory, f)
  raise ValueError("No %s in test directory" % file_or_dirname)
@@ -119,7 +119,7 @@ class ObjectDetectorTest(parameterized.TestCase):
    with self.assertRaisesRegex(
        ValueError,
        r"ExternalFile must specify at least one of 'file_content', "
        r"'file_name' or 'file_descriptor_meta'."):
        r"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."):
      base_options = _BaseOptions(model_asset_path='')
      options = _ObjectDetectorOptions(base_options=base_options)
      _ObjectDetector.create_from_options(options)
mediapipe/tasks/testdata/text/BUILD
@@ -27,6 +27,7 @@ mediapipe_files(srcs = [
    "albert_with_metadata.tflite",
    "bert_text_classifier.tflite",
    "mobilebert_with_metadata.tflite",
    "test_model_text_classifier_bool_output.tflite",
    "test_model_text_classifier_with_regex_tokenizer.tflite",
])

mediapipe/tasks/testdata/vision/BUILD
@@ -47,6 +47,7 @@ mediapipe_files(srcs = [
    "mozart_square.jpg",
    "multi_objects.jpg",
    "palm_detection_full.tflite",
    "pointing_up.jpg",
    "right_hands.jpg",
    "segmentation_golden_rotation0.png",
    "segmentation_input_rotation0.jpg",

@@ -54,6 +55,7 @@ mediapipe_files(srcs = [
    "selfie_segm_128_128_3_expected_mask.jpg",
    "selfie_segm_144_256_3.tflite",
    "selfie_segm_144_256_3_expected_mask.jpg",
    "thumb_up.jpg",
])

exports_files(

@@ -79,11 +81,13 @@ filegroup(
        "left_hands.jpg",
        "mozart_square.jpg",
        "multi_objects.jpg",
        "pointing_up.jpg",
        "right_hands.jpg",
        "segmentation_golden_rotation0.png",
        "segmentation_input_rotation0.jpg",
        "selfie_segm_128_128_3_expected_mask.jpg",
        "selfie_segm_144_256_3_expected_mask.jpg",
        "thumb_up.jpg",
    ],
    visibility = [
        "//mediapipe/python:__subpackages__",
@@ -8,216 +8,216 @@ classifications {

landmarks {
  landmark {
    x: 0.4749803
    y: 0.76872
    z: 9.286178e-08
    x: 0.47923622
    y: 0.7426044
    z: 2.3221878e-07
  }
  landmark {
    x: 0.5466898
    y: 0.6706463
    z: -0.03454024
    x: 0.5403745
    y: 0.66178805
    z: -0.044572093
  }
  landmark {
    x: 0.5890165
    y: 0.5604909
    z: -0.055142127
    x: 0.5774534
    y: 0.5608346
    z: -0.07581605
  }
  landmark {
    x: 0.52780133
    y: 0.49855334
    z: -0.07846409
    x: 0.52648556
    y: 0.50247055
    z: -0.105467044
  }
  landmark {
    x: 0.44487286
    y: 0.49801928
    z: -0.10188004
    x: 0.44289914
    y: 0.49489295
    z: -0.13422011
  }
  landmark {
    x: 0.47572923
    y: 0.44477755
    z: -0.028345175
    x: 0.4728853
    y: 0.43925008
    z: -0.058122505
  }
  landmark {
    x: 0.48013464
    y: 0.32467923
    z: -0.06513901
    x: 0.4803168
    y: 0.32889345
    z: -0.101187326
  }
  landmark {
    x: 0.48351905
    y: 0.25804192
    z: -0.086756624
    x: 0.48436823
    y: 0.25876504
    z: -0.12840955
  }
  landmark {
    x: 0.47760454
    y: 0.19289327
    z: -0.10468461
    x: 0.47388697
    y: 0.19592366
    z: -0.15085006
  }
  landmark {
    x: 0.3993108
    y: 0.47566867
    z: -0.040357687
    x: 0.39129356
    y: 0.47211456
    z: -0.06835801
  }
  landmark {
    x: 0.42361537
    y: 0.42491958
    z: -0.103545874
    x: 0.41798547
    y: 0.42218646
    z: -0.12954563
  }
  landmark {
    x: 0.46059948
    y: 0.51723665
    z: -0.1214961
    x: 0.45758423
    y: 0.5232461
    z: -0.14131334
  }
  landmark {
    x: 0.4580545
    y: 0.55640894
    z: -0.12272568
    x: 0.45100626
    y: 0.5554065
    z: -0.13883406
  }
  landmark {
    x: 0.34109607
    y: 0.5184511
    z: -0.056422118
    x: 0.33133638
    y: 0.51777464
    z: -0.08227023
  }
  landmark {
    x: 0.36177525
    y: 0.48427337
    z: -0.12584248
    x: 0.35698116
    y: 0.48688585
    z: -0.14713185
  }
  landmark {
    x: 0.40706652
    y: 0.5700621
    z: -0.11658718
    x: 0.40754414
    y: 0.57370347
    z: -0.12981415
  }
  landmark {
    x: 0.40535083
    y: 0.6000496
    z: -0.09520916
    x: 0.40011865
    y: 0.5930706
    z: -0.10554546
  }
  landmark {
    x: 0.2872031
    y: 0.57303333
    z: -0.074813806
    x: 0.2783401
    y: 0.5735568
    z: -0.09971398
  }
  landmark {
    x: 0.30961618
    y: 0.533245
    z: -0.114366606
    x: 0.30884498
    y: 0.5394487
    z: -0.14033116
  }
  landmark {
    x: 0.35510173
    y: 0.5838698
    z: -0.096521005
    x: 0.35470563
    y: 0.5917965
    z: -0.11820527
  }
  landmark {
    x: 0.36053744
    y: 0.608682
    z: -0.07574715
    x: 0.34865493
    y: 0.61057556
    z: -0.09509217
  }
}

world_landmarks {
  landmark {
    x: 0.018890835
    y: 0.09005852
    z: 0.031907097
    x: 0.016918864
    y: 0.08634466
    z: 0.035783045
  }
  landmark {
    x: 0.04198891
    y: 0.061256267
    z: 0.017695501
    x: 0.04193685
    y: 0.056667875
    z: 0.019453367
  }
  landmark {
    x: 0.05044507
    y: 0.033841074
    z: 0.0015051212
    x: 0.050382353
    y: 0.031786427
    z: 0.0023380776
  }
  landmark {
    x: 0.039822325
    y: 0.0073827556
    z: -0.02168335
    x: 0.043284662
    y: 0.008976387
    z: -0.02496663
  }
  landmark {
    x: 0.012921701
    y: 0.0025111444
    z: -0.033813436
    x: 0.016010094
    y: 0.004991216
    z: -0.036876947
  }
  landmark {
    x: 0.023851154
    y: -0.011495698
    z: 0.0066048754
    x: 0.02450771
    y: -0.013496464
    z: 0.0041254223
  }
  landmark {
    x: 0.023206754
    y: -0.042496294
    z: -0.0026847485
    x: 0.024783865
    y: -0.041331705
    z: -0.0028748964
  }
  landmark {
    x: 0.02298078
    y: -0.062678955
    z: -0.013068148
    x: 0.025917178
    y: -0.06191107
    z: -0.010242647
  }
  landmark {
    x: 0.021972645
    y: -0.08151748
    z: -0.03677687
    x: 0.023101516
    y: -0.07967696
    z: -0.03152665
  }
  landmark {
    x: -0.00016964211
    y: -0.005549716
    z: 0.0058569373
    x: 0.0006629339
    y: -0.0060150283
    z: 0.004906766
  }
  landmark {
    x: 0.0075052455
    y: -0.020031122
    z: -0.027775772
    x: 0.0077093104
    y: -0.017035034
    z: -0.029702934
  }
  landmark {
    x: 0.017835317
    y: 0.004899453
    z: -0.037390795
    x: 0.017517095
    y: 0.008997183
    z: -0.03692814
  }
  landmark {
    x: 0.016913192
    y: 0.018281722
    z: -0.019302163
    x: 0.0145079205
    y: 0.017461296
    z: -0.011290487
  }
  landmark {
    x: -0.018799124
    y: 0.0053577404
    z: -0.0040608873
    x: -0.018095909
    y: 0.006112392
    z: -0.0027157406
  }
  landmark {
    x: -0.00747582
    y: 0.0019600953
    z: -0.034023333
    x: -0.010212201
    y: 0.0052777785
    z: -0.034659054
  }
  landmark {
    x: 0.0035368819
    y: 0.025736088
    z: -0.03452471
    x: 0.0043836404
    y: 0.028383566
    z: -0.03296758
  }
  landmark {
    x: 0.0080153765
    y: 0.039885145
    z: -0.013341276
    x: 0.003886811
    y: 0.036054
    z: -0.0074628904
  }
  landmark {
    x: -0.029628165
    y: 0.028607829
    z: -0.011377414
    x: -0.03178849
    y: 0.029854178
    z: -0.008874044
  }
  landmark {
    x: -0.023356002
    y: 0.017514031
    z: -0.029408533
    x: -0.02403016
    y: 0.021497255
    z: -0.027618393
  }
  landmark {
    x: -0.008503268
    y: 0.027560957
    z: -0.035641473
    x: -0.008522437
    y: 0.031886857
    z: -0.032367583
  }
  landmark {
    x: -0.0070180474
    y: 0.039056484
    z: -0.023629948
    x: -0.012865841
    y: 0.038687646
    z: -0.017172804
  }
}

@@ -8,216 +8,216 @@ classifications {

landmarks {
  landmark {
    x: 0.6065784
    y: 0.7356081
    z: -5.2289305e-08
    x: 0.6387502
    y: 0.67134184
    z: -3.4044612e-07
  }
  landmark {
    x: 0.6349347
    y: 0.5735343
    z: -0.047243003
    x: 0.634891
    y: 0.53670025
    z: -0.06968865
  }
  landmark {
    x: 0.5788341
    y: 0.42688707
    z: -0.036071796
    x: 0.5746676
    y: 0.41283816
    z: -0.09383486
  }
  landmark {
    x: 0.51322824
    y: 0.3153786
    z: -0.021018881
    x: 0.49967948
    y: 0.32550922
    z: -0.10799447
  }
  landmark {
    x: 0.49179295
    y: 0.25291175
    z: 0.0061425082
    x: 0.47362617
    y: 0.25102285
    z: -0.10590933
  }
  landmark {
    x: 0.49944243
    y: 0.45409226
    z: 0.06513325
    x: 0.40749234
    y: 0.47130388
    z: -0.04694611
  }
  landmark {
    x: 0.3822241
    y: 0.45645967
    z: 0.045028925
    x: 0.3372087
    y: 0.46742308
    z: -0.0997342
  }
  landmark {
    x: 0.4427338
    y: 0.49150866
    z: 0.024395633
    x: 0.4418445
    y: 0.50960016
    z: -0.111206524
  }
  landmark {
    x: 0.5015556
    y: 0.4798539
    z: 0.014423937
    x: 0.48056933
    y: 0.5187666
    z: -0.11022365
  }
  landmark {
    x: 0.46654877
    y: 0.5420721
    z: 0.08380699
    x: 0.39218128
    y: 0.5495232
    z: -0.028925514
  }
  landmark {
    x: 0.3540949
    y: 0.545657
    z: 0.056201216
    x: 0.34047198
    y: 0.55610204
    z: -0.08213869
  }
  landmark {
    x: 0.43828446
    y: 0.5723222
    z: 0.03073385
    x: 0.46152583
    y: 0.58310646
    z: -0.08393028
  }
  landmark {
    x: 0.4894746
    y: 0.54662794
    z: 0.016284892
    x: 0.47058716
    y: 0.56413835
    z: -0.078857616
  }
  landmark {
    x: 0.44287524
    y: 0.6153337
    z: 0.0878331
    x: 0.39237642
    y: 0.61864823
    z: -0.022026168
  }
  landmark {
    x: 0.3531985
    y: 0.6305228
    z: 0.048528627
    x: 0.34304678
    y: 0.62800515
    z: -0.08132204
  }
  landmark {
    x: 0.42727134
    y: 0.64344436
    z: 0.027383275
    x: 0.45004016
    y: 0.64300805
    z: -0.06211204
  }
  landmark {
    x: 0.46999624
    y: 0.61115295
    z: 0.021795912
    x: 0.4640005
    y: 0.6221539
    z: -0.038953774
  }
  landmark {
    x: 0.43323213
    y: 0.6734935
    z: 0.087731235
    x: 0.39231628
    y: 0.68187976
    z: -0.020164328
  }
  landmark {
    x: 0.3772134
    y: 0.69590896
    z: 0.07259013
    x: 0.35785866
    y: 0.6985842
    z: -0.052247807
  }
  landmark {
    x: 0.42301077
    y: 0.70083475
    z: 0.06279105
    x: 0.42698768
    y: 0.69892275
    z: -0.037642766
  }
  landmark {
    x: 0.45672464
    y: 0.6844607
    z: 0.059202813
    x: 0.44422707
    y: 0.6876204
    z: -0.02034688
  }
}

world_landmarks {
  landmark {
    x: 0.047059614
    y: 0.04719348
    z: 0.03951376
    x: 0.06753889
    y: 0.031051591
    z: 0.05541924
  }
  landmark {
    x: 0.050449535
    y: 0.012183173
    z: 0.016567508
    x: 0.06327636
    y: -0.003913434
    z: 0.02125023
  }
  landmark {
    x: 0.04375921
    y: -0.020305036
    z: 0.012189768
    x: 0.05469646
    y: -0.038668767
    z: 0.01118496
  }
  landmark {
    x: 0.022525383
    y: -0.04830697
    z: 0.008714083
    x: 0.03557241
    y: -0.06865983
    z: 0.0029562893
  }
  landmark {
    x: 0.011789754
    y: -0.06952699
    z: 0.0029319536
    x: 0.019069858
    y: -0.08740239
    z: 0.007222481
  }
  landmark {
    x: 0.009532374
    y: -0.019510617
    z: 0.0015609035
    x: 0.0044852756
    y: -0.02772763
    z: -0.004234833
  }
  landmark {
    x: -0.007894232
    y: -0.022080563
    z: -0.014592148
    x: -0.0031203926
    y: -0.024173645
    z: -0.033932913
  }
  landmark {
    x: -0.002826123
    y: -0.019949362
    z: -0.009392118
    x: 0.0080217365
    y: -0.018939625
    z: -0.032623816
  }
  landmark {
    x: 0.009066351
    y: -0.016403511
    z: 0.005516675
    x: 0.025537387
    y: -0.014517117
    z: -0.004398854
  }
  landmark {
    x: -0.0031000748
    y: -0.003971943
    z: 0.004851345
    x: -0.004470923
    y: -0.0040212176
    z: 0.0025033879
  }
  landmark {
    x: -0.016852753
    y: -0.009905987
    z: -0.016275175
    x: -0.010845158
    y: -0.0031857258
    z: -0.036282137
  }
  landmark {
    x: -0.006703893
    y: -0.0026965735
    z: -0.015606856
    x: 0.016729971
    y: 0.0028876318
    z: -0.036264844
  }
  landmark {
    x: 0.007890566
    y: -0.010418876
    z: 0.0050479355
    x: 0.019928008
    y: -0.0032422952
    z: 0.004380459
  }
  landmark {
    x: -0.007842411
    y: 0.011552694
    z: -0.0005755241
    x: -0.005686749
    y: 0.017101247
    z: 0.0036791638
  }
  landmark {
    x: -0.021125216
    y: 0.009268615
    z: -0.017993882
    x: -0.010514952
    y: 0.017355483
    z: -0.02882688
  }
  landmark {
    x: -0.006585305
    y: 0.013378072
    z: -0.01709412
    x: 0.014503509
    y: 0.019414417
    z: -0.026207235
  }
  landmark {
    x: 0.008140431
    y: 0.008364402
    z: -0.0051898304
    x: 0.0211232
    y: 0.014327417
    z: 0.0011467658
  }
  landmark {
    x: -0.01082343
    y: 0.03213215
    z: -0.00069864903
    x: 0.0011399705
    y: 0.043651186
    z: 0.0068390737
  }
  landmark {
    x: -0.0199164
    y: 0.028296603
    z: -0.01447433
    x: -0.010388309
    y: 0.03904784
    z: -0.015677728
  }
  landmark {
    x: -0.00960456
    y: 0.026734762
    z: -0.019243335
    x: 0.006957108
    y: 0.03613425
    z: -0.028704688
  }
  landmark {
    x: 0.0040425956
    y: 0.025051914
    z: -0.014775545
    x: 0.012793289
    y: 0.03930679
    z: -0.012465539
  }
}

third_party/external_files.bzl
@@ -432,8 +432,8 @@ def external_files():

    http_file(
        name = "com_google_mediapipe_pointing_up_landmarks_pbtxt",
        sha256 = "1255b6ba17b4ef7a9b3ce92c0a139e74fbcec272dc251b049b2f06732f9fed83",
        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1662650664573638"],
        sha256 = "a3cd7f088a9e997dbb8f00d91dbf3faaacbdb262c8f2fde3c07a9d0656488065",
        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1665174976408451"],
    )

    http_file(

@@ -562,6 +562,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_add_op.tflite?generation=1661875950076192"],
    )

    http_file(
        name = "com_google_mediapipe_test_model_text_classifier_bool_output_tflite",
        sha256 = "09877ac6d718d78da6380e21fe8179854909d116632d6d770c12f8a51792e310",
        urls = ["https://storage.googleapis.com/mediapipe-assets/test_model_text_classifier_bool_output.tflite?generation=1664904110313163"],
    )

    http_file(
        name = "com_google_mediapipe_test_model_text_classifier_with_regex_tokenizer_tflite",
        sha256 = "cb12618d084b813cb7b90ceb39c9fe4b18dae4de9880b912cdcd4b577cd65b4f",

@@ -588,8 +594,8 @@ def external_files():

    http_file(
        name = "com_google_mediapipe_thumb_up_landmarks_pbtxt",
        sha256 = "bf1913df6ac7cc14b492c10411c827832839985c057b112789e04ce7c1fdd0fa",
        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1662650669387278"],
        sha256 = "b129ae0536be4e25d6cdee74aabe9dedf1bcfe87430a40b68be4079db3a4d926",
        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1665174979747784"],
    )

    http_file(