Merge branch 'master' into image-embedder-python

This commit is contained in:
Kinar R 2022-11-10 16:04:28 +05:30 committed by GitHub
commit 0a6e21c212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 34103 additions and 306 deletions

View File

@ -546,3 +546,6 @@ rules_proto_toolchains()
load("//third_party:external_files.bzl", "external_files") load("//third_party:external_files.bzl", "external_files")
external_files() external_files()
load("//third_party:wasm_files.bzl", "wasm_files")
wasm_files()

View File

@ -200,3 +200,38 @@ cc_test(
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
], ],
) )
cc_library(
name = "embedding_aggregation_calculator",
srcs = ["embedding_aggregation_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"@com_google_absl//absl/status",
],
alwayslink = 1,
)
cc_test(
name = "embedding_aggregation_calculator_test",
srcs = ["embedding_aggregation_calculator_test.cc"],
deps = [
":embedding_aggregation_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:output_stream_poller",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)

View File

@ -0,0 +1,132 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <unordered_map>
#include <vector>
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
// Aggregates EmbeddingResult packets into a vector of timestamped
// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
// needed.
//
// Inputs:
// EMBEDDINGS: EmbeddingResult
// The EmbeddingResult packets to aggregate.
// TIMESTAMPS: std::vector<Timestamp> @Optional.
// The collection of timestamps that this calculator should aggregate. This
// stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output
// will contain the aggregated results. Otherwise as no timestamp
// aggregation is required the EMBEDDINGS output is used to pass the inputs
// EmbeddingResults unchanged.
//
// Outputs:
// EMBEDDINGS: EmbeddingResult @Optional
// The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
// input is not connected, as it signals that timestamp aggregation is not
// required.
// TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
// The embedding results aggregated by timestamp. Must be connected if the
// TIMESTAMPS input is connected as it signals that timestamp aggregation is
// required.
//
// Example without timestamp aggregation (pass-through):
// node {
// calculator: "EmbeddingAggregationCalculator"
// input_stream: "EMBEDDINGS:embeddings_in"
// output_stream: "EMBEDDINGS:embeddings_out"
// }
//
// Example with timestamp aggregation:
// node {
// calculator: "EmbeddingAggregationCalculator"
// input_stream: "EMBEDDINGS:embeddings_in"
// input_stream: "TIMESTAMPS:timestamps_in"
// output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
// }
class EmbeddingAggregationCalculator : public Node {
public:
static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
"TIMESTAMPS"};
static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
"EMBEDDINGS"};
static constexpr Output<std::vector<EmbeddingResult>>::Optional
kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
kTimestampedEmbeddingsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc);
absl::Status Process(CalculatorContext* cc);
private:
bool time_aggregation_enabled_;
std::unordered_map<int64, EmbeddingResult> cached_embeddings_;
};
absl::Status EmbeddingAggregationCalculator::UpdateContract(
CalculatorContract* cc) {
if (kTimestampsIn(cc).IsConnected()) {
RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
} else {
RET_CHECK(kEmbeddingsOut(cc).IsConnected());
}
return absl::OkStatus();
}
absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
time_aggregation_enabled_ = kTimestampsIn(cc).IsConnected();
return absl::OkStatus();
}
absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
if (time_aggregation_enabled_) {
cached_embeddings_[cc->InputTimestamp().Value()] =
std::move(*kEmbeddingsIn(cc));
if (kTimestampsIn(cc).IsEmpty()) {
return absl::OkStatus();
}
auto timestamps = kTimestampsIn(cc).Get();
std::vector<EmbeddingResult> results;
results.reserve(timestamps.size());
for (const auto& timestamp : timestamps) {
auto& result = cached_embeddings_[timestamp.Value()];
result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) /
1000);
results.push_back(std::move(result));
cached_embeddings_.erase(timestamp.Value());
}
kTimestampedEmbeddingsOut(cc).Send(std::move(results));
} else {
kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc));
}
RET_CHECK(cached_embeddings_.empty());
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,158 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <optional>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace {
using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::testing::Pointwise;
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kEmbeddingsInName[] = "embeddings_in";
constexpr char kEmbeddingsOutName[] = "embeddings_out";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps_in";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
Graph graph;
auto& calculator = graph.AddNode("EmbeddingAggregationCalculator");
graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >>
calculator.In(kEmbeddingsTag);
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
calculator.In(kTimestampsTag);
calculator.Out(kTimestampedEmbeddingsTag)
.SetName(kTimestampedEmbeddingsName) >>
graph[Output<std::vector<EmbeddingResult>>(
kTimestampedEmbeddingsTag)];
} else {
calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >>
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
}
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
if (connect_timestamps) {
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kTimestampedEmbeddingsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kEmbeddingsOutName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
absl::Status Send(
const EmbeddingResult& embeddings, int timestamp = 0,
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kEmbeddingsInName, MakePacket<EmbeddingResult>(std::move(embeddings))
.At(Timestamp(timestamp))));
if (aggregation_timestamps.has_value()) {
auto packet = std::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
template <typename T>
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<T>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
};
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) {
EmbeddingResult embedding = ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { head_index: 0 })pb");
MP_ASSERT_OK_AND_ASSIGN(auto poller,
BuildGraph(/*connect_timestamps=*/false));
MP_ASSERT_OK(Send(embedding));
MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<EmbeddingResult>(poller));
EXPECT_THAT(result, EqualsProto(embedding));
}
TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) {
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(R"pb(embeddings {
head_index: 0
})pb")));
MP_ASSERT_OK(Send(
ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { head_index: 1 })pb"),
/*timestamp=*/1000,
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
MP_ASSERT_OK_AND_ASSIGN(auto results,
GetResult<std::vector<EmbeddingResult>>(poller));
EXPECT_THAT(results,
Pointwise(EqualsProto(), {ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { head_index: 0 }
timestamp_ms: 0)pb"),
ParseTextProtoOrDie<EmbeddingResult>(
R"pb(embeddings { head_index: 1 }
timestamp_ms: 1)pb")}));
}
} // namespace
} // namespace mediapipe

View File

@ -30,15 +30,6 @@ cc_library(
], ],
) )
cc_library(
name = "hand_landmarks_detection_result",
hdrs = ["hand_landmarks_detection_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library( cc_library(
name = "category", name = "category",
srcs = ["category.cc"], srcs = ["category.cc"],

View File

@ -82,6 +82,7 @@ cc_library(
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
"//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",

View File

@ -56,6 +56,14 @@ using TensorsSource =
constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS"; constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Struct holding the different output streams produced by the graph.
struct EmbeddingPostprocessingOutputStreams {
Source<EmbeddingResult> embeddings;
Source<std::vector<EmbeddingResult>> timestamped_embeddings;
};
// Identifies whether or not the model has quantized outputs, and performs // Identifies whether or not the model has quantized outputs, and performs
// sanity checks. // sanity checks.
@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing(
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
// output is used for results. Otherwise as no timestamp aggregation is
// required the EMBEDDINGS output is used for results.
//
// Outputs: // Outputs:
// EMBEDDING_RESULT - EmbeddingResult // EMBEDDINGS - EmbeddingResult @Optional
// The output EmbeddingResult. // The embedding results aggregated by head. Must be connected if the
// TIMESTAMPS input is not connected, as it signals that timestamp
// aggregation is not required.
// TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
// The embedding result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
// //
// The recommended way of using this graph is through the GraphBuilder API using // The recommended way of using this graph is through the GraphBuilder API using
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more // the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
// details. // details.
//
// TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation.
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph { class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
public: public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig( absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override { mediapipe::SubgraphContext* sc) override {
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto embedding_result_out, auto output_streams,
BuildEmbeddingPostprocessing( BuildEmbeddingPostprocessing(
sc->Options<proto::EmbeddingPostprocessingGraphOptions>(), sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph)); graph[Input<std::vector<Tensor>>(kTensorsTag)],
embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)]; graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
output_streams.timestamped_embeddings >>
graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
return graph.GetConfig(); return graph.GetConfig();
} }
@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
// //
// options: the on-device EmbeddingPostprocessingGraphOptions // options: the on-device EmbeddingPostprocessingGraphOptions
// tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess. // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that should be used to aggregate embedding results.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing( absl::StatusOr<EmbeddingPostprocessingOutputStreams>
BuildEmbeddingPostprocessing(
const proto::EmbeddingPostprocessingGraphOptions options, const proto::EmbeddingPostprocessingGraphOptions options,
Source<std::vector<Tensor>> tensors_in, Graph& graph) { Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
// If output tensors are quantized, they must be dequantized first. // If output tensors are quantized, they must be dequantized first.
TensorsSource dequantized_tensors(&tensors_in); TensorsSource dequantized_tensors(&tensors_in);
if (options.has_quantized_outputs()) { if (options.has_quantized_outputs()) {
@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
.GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>() .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
.CopyFrom(options.tensors_to_embeddings_options()); .CopyFrom(options.tensors_to_embeddings_options());
dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag); dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
// Adds EmbeddingAggregationCalculator.
GenericNode& aggregation_node =
graph.AddNode("EmbeddingAggregationCalculator");
tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >>
aggregation_node.In(kEmbeddingsTag);
timestamps_in >> aggregation_node.In(kTimestampsTag);
// Connects outputs.
return EmbeddingPostprocessingOutputStreams{
/*embeddings=*/aggregation_node[Output<EmbeddingResult>(
kEmbeddingsTag)],
/*timestamped_embeddings=*/aggregation_node
[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]};
} }
}; };
REGISTER_MEDIAPIPE_GRAPH( REGISTER_MEDIAPIPE_GRAPH(

View File

@ -44,12 +44,20 @@ namespace processors {
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator, to convert into // The output tensors of an InferenceCalculator, to convert into
// EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8. // EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// The collection of the timestamps that this calculator should aggregate.
// This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
// output is used for results. Otherwise as no timestamp aggregation is
// required the EMBEDDINGS output is used for results.
// Outputs: // Outputs:
// EMBEDDINGS - EmbeddingResult // EMBEDDINGS - EmbeddingResult @Optional
// The output EmbeddingResult. // The embedding results aggregated by head. Must be connected if the
// // TIMESTAMPS input is not connected, as it signals that timestamp
// TODO: add support for additional optional "TIMESTAMPS" input for // aggregation is not required.
// embeddings aggregation. // TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
// The embedding result aggregated by timestamp, then by head. Must be
// connected if the TIMESTAMPS input is connected, as it signals that
// timestamp aggregation is required.
absl::Status ConfigureEmbeddingPostprocessing( absl::Status ConfigureEmbeddingPostprocessing(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const proto::EmbedderOptions& embedder_options, const proto::EmbedderOptions& embedder_options,

View File

@ -20,11 +20,20 @@ limitations under the License.
#include "absl/flags/flag.h" #include "absl/flags/flag.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/graph_runner.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
@ -37,7 +46,12 @@ namespace components {
namespace processors { namespace processors {
namespace { namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] =
"vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
constexpr char kTestModelResourcesTag[] = "test_model_resources"; constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024;
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kEmbeddingsName[] = "embeddings";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings";
// Helper function to get ModelResources. // Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel( absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
has_quantized_outputs: false)pb"))); has_quantized_outputs: false)pb")));
} }
// TODO: add E2E Postprocessing tests once timestamp aggregation is class PostprocessingTest : public tflite_shims::testing::Test {
// supported. protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const proto::EmbedderOptions& options,
bool connect_timestamps = false) {
ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name));
Graph graph;
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors."
"EmbeddingPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
*model_resources, options,
&postprocessing
.GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag);
if (connect_timestamps) {
graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
kTimestampsName) >>
postprocessing.In(kTimestampsTag);
postprocessing.Out(kTimestampedEmbeddingsTag)
.SetName(kTimestampedEmbeddingsName) >>
graph[Output<std::vector<EmbeddingResult>>(
kTimestampedEmbeddingsTag)];
} else {
postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
graph[Output<EmbeddingResult>(kEmbeddingsTag)];
}
MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
if (connect_timestamps) {
ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
kTimestampedEmbeddingsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
ASSIGN_OR_RETURN(auto poller,
calculator_graph_.AddOutputStreamPoller(kEmbeddingsName));
MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
return poller;
}
template <typename T>
void AddTensor(
const std::vector<T>& tensor, const Tensor::ElementType& element_type,
const Tensor::QuantizationParameters& quantization_parameters = {}) {
tensors_->emplace_back(element_type,
Tensor::Shape{1, static_cast<int>(tensor.size())},
quantization_parameters);
auto view = tensors_->back().GetCpuWriteView();
T* buffer = view.buffer<T>();
std::copy(tensor.begin(), tensor.end(), buffer);
}
absl::Status Run(
std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
int timestamp = 0) {
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
// Reset tensors for future calls.
tensors_ = absl::make_unique<std::vector<Tensor>>();
if (aggregation_timestamps.has_value()) {
auto packet = absl::make_unique<std::vector<Timestamp>>();
for (const auto& timestamp : *aggregation_timestamps) {
packet->emplace_back(Timestamp(timestamp));
}
MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
}
return absl::OkStatus();
}
template <typename T>
absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
Packet packet;
if (!poller.Next(&packet)) {
return absl::InternalError("Unable to get output packet");
}
auto result = packet.Get<T>();
MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
return result;
}
private:
CalculatorGraph calculator_graph_;
std::unique_ptr<std::vector<Tensor>> tensors_ =
absl::make_unique<std::vector<Tensor>>();
};
TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
// Build graph.
proto::EmbedderOptions options;
MP_ASSERT_OK_AND_ASSIGN(auto poller,
BuildGraph(kMobileNetV3Embedder, options));
// Build input tensor.
std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
tensor[0] = 1.0;
// Send tensor and get results.
AddTensor(tensor, Tensor::ElementType::kFloat32);
MP_ASSERT_OK(Run());
MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));
// Validate results.
EXPECT_FALSE(results.has_timestamp_ms());
EXPECT_EQ(results.embeddings_size(), 1);
EXPECT_EQ(results.embeddings(0).head_index(), 0);
EXPECT_EQ(results.embeddings(0).head_name(), "feature");
EXPECT_EQ(results.embeddings(0).float_embedding().values_size(),
kMobileNetV3EmbedderEmbeddingSize);
EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(0), 1.0);
for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(i), 0.0);
}
}
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
// Build graph.
proto::EmbedderOptions options;
MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,
/*connect_timestamps=*/true));
// Build input tensors.
std::vector<float> tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0);
tensor_0[0] = 1.0;
std::vector<float> tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0);
tensor_1[0] = 2.0;
// Send tensors and get results.
AddTensor(tensor_0, Tensor::ElementType::kFloat32);
MP_ASSERT_OK(Run());
AddTensor(tensor_1, Tensor::ElementType::kFloat32);
MP_ASSERT_OK(Run(
/*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
/*timestamp=*/1000));
MP_ASSERT_OK_AND_ASSIGN(auto results,
GetResult<std::vector<EmbeddingResult>>(poller));
// Validate results.
EXPECT_EQ(results.size(), 2);
// First timestamp.
EXPECT_EQ(results[0].timestamp_ms(), 0);
EXPECT_EQ(results[0].embeddings(0).head_index(), 0);
EXPECT_EQ(results[0].embeddings(0).head_name(), "feature");
EXPECT_EQ(results[0].embeddings(0).float_embedding().values_size(),
kMobileNetV3EmbedderEmbeddingSize);
EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(0), 1.0);
for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(i), 0.0);
}
// Second timestamp.
EXPECT_EQ(results[1].timestamp_ms(), 1);
EXPECT_EQ(results[1].embeddings(0).head_index(), 0);
EXPECT_EQ(results[1].embeddings(0).head_name(), "feature");
EXPECT_EQ(results[1].embeddings(0).float_embedding().values_size(),
kMobileNetV3EmbedderEmbeddingSize);
EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(0), 2.0);
for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(i), 0.0);
}
}
} // namespace } // namespace
} // namespace processors } // namespace processors

View File

@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions {
// Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32). // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
optional bool has_quantized_outputs = 2; optional bool has_quantized_outputs = 2;
// TODO: add options to control whether timestamp aggregation
// should be used or not.
} }

View File

@ -110,12 +110,22 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "hand_landmarker_result",
hdrs = ["hand_landmarker_result.h"],
deps = [
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
],
)
cc_library( cc_library(
name = "hand_landmarker", name = "hand_landmarker",
srcs = ["hand_landmarker.cc"], srcs = ["hand_landmarker.cc"],
hdrs = ["hand_landmarker.h"], hdrs = ["hand_landmarker.h"],
deps = [ deps = [
":hand_landmarker_graph", ":hand_landmarker_graph",
":hand_landmarker_result",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
@ -124,7 +134,6 @@ cc_library(
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
"//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/base_task_api.h"
@ -34,6 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
@ -47,8 +47,6 @@ namespace {
using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
hand_landmarker::proto::HandLandmarkerGraphOptions; hand_landmarker::proto::HandLandmarkerGraphOptions;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
constexpr char kHandLandmarkerGraphTypeName[] = constexpr char kHandLandmarkerGraphTypeName[] =
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"; "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
@ -145,7 +143,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
Packet empty_packet = Packet empty_packet =
status_or_packets.value()[kHandLandmarksStreamName]; status_or_packets.value()[kHandLandmarksStreamName];
result_callback( result_callback(
{HandLandmarksDetectionResult()}, image_packet.Get<Image>(), {HandLandmarkerResult()}, image_packet.Get<Image>(),
empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
return; return;
} }
@ -173,7 +171,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
std::move(packets_callback)); std::move(packets_callback));
} }
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect( absl::StatusOr<HandLandmarkerResult> HandLandmarker::Detect(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -192,7 +190,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
{kNormRectStreamName, {kNormRectStreamName,
MakePacket<NormalizedRect>(std::move(norm_rect))}})); MakePacket<NormalizedRect>(std::move(norm_rect))}}));
if (output_packets[kHandLandmarksStreamName].IsEmpty()) { if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarksDetectionResult()}; return {HandLandmarkerResult()};
} }
return {{/* handedness= */ return {{/* handedness= */
{output_packets[kHandednessStreamName] {output_packets[kHandednessStreamName]
@ -205,7 +203,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
.Get<std::vector<mediapipe::LandmarkList>>()}}}; .Get<std::vector<mediapipe::LandmarkList>>()}}};
} }
absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo( absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
mediapipe::Image image, int64 timestamp_ms, mediapipe::Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options) { std::optional<core::ImageProcessingOptions> image_processing_options) {
if (image.UsesGpu()) { if (image.UsesGpu()) {
@ -227,7 +225,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
MakePacket<NormalizedRect>(std::move(norm_rect)) MakePacket<NormalizedRect>(std::move(norm_rect))
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
if (output_packets[kHandLandmarksStreamName].IsEmpty()) { if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
return {HandLandmarksDetectionResult()}; return {HandLandmarkerResult()};
} }
return { return {
{/* handedness= */ {/* handedness= */

View File

@ -24,12 +24,12 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
@ -70,9 +70,7 @@ struct HandLandmarkerOptions {
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void( std::function<void(absl::StatusOr<HandLandmarkerResult>, const Image&, int64)>
absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -92,7 +90,7 @@ struct HandLandmarkerOptions {
// 'y_center', 'width' and 'height' fields is NOT supported and will // 'y_center', 'width' and 'height' fields is NOT supported and will
// result in an invalid argument error being returned. // result in an invalid argument error being returned.
// Outputs: // Outputs:
// HandLandmarksDetectionResult // HandLandmarkerResult
// - The hand landmarks detection results. // - The hand landmarks detection results.
class HandLandmarker : tasks::vision::core::BaseVisionTaskApi { class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
public: public:
@ -129,7 +127,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: Describes how the input image will be preprocessed // TODO: Describes how the input image will be preprocessed
// after the yuv support is implemented. // after the yuv support is implemented.
absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect( absl::StatusOr<HandLandmarkerResult> Detect(
Image image, Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
std::nullopt); std::nullopt);
@ -147,10 +145,10 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<components::containers::HandLandmarksDetectionResult> absl::StatusOr<HandLandmarkerResult> DetectForVideo(
DetectForVideo(Image image, int64 timestamp_ms, Image image, int64 timestamp_ms,
std::optional<core::ImageProcessingOptions> std::optional<core::ImageProcessingOptions> image_processing_options =
image_processing_options = std::nullopt); std::nullopt);
// Sends live image data to perform hand landmarks detection, and the results // Sends live image data to perform hand landmarks detection, and the results
// will be available via the "result_callback" provided in the // will be available via the "result_callback" provided in the
@ -169,7 +167,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
// invalid argument error being returned. // invalid argument error being returned.
// //
// The "result_callback" provides // The "result_callback" provides
// - A vector of HandLandmarksDetectionResult, each is the detected results // - A vector of HandLandmarkerResult, each is the detected results
// for a input frame. // for a input frame.
// - The const reference to the corresponding input image that the hand // - The const reference to the corresponding input image that the hand
// landmarker runs on. Note that the const reference to the image will no // landmarker runs on. Note that the const reference to the image will no

View File

@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ #ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ #define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace vision {
namespace containers { namespace hand_landmarker {
// The hand landmarks detection result from HandLandmarker, where each vector // The hand landmarks detection result from HandLandmarker, where each vector
// element represents a single hand detected in the image. // element represents a single hand detected in the image.
struct HandLandmarksDetectionResult { struct HandLandmarkerResult {
// Classification of handedness. // Classification of handedness.
std::vector<mediapipe::ClassificationList> handedness; std::vector<mediapipe::ClassificationList> handedness;
// Detected hand landmarks in normalized image coordinates. // Detected hand landmarks in normalized image coordinates.
@ -35,9 +35,9 @@ struct HandLandmarksDetectionResult {
std::vector<mediapipe::LandmarkList> hand_world_landmarks; std::vector<mediapipe::LandmarkList> hand_world_landmarks;
}; };
} // namespace containers } // namespace hand_landmarker
} // namespace components } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_ #endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_

View File

@ -32,12 +32,12 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -50,7 +50,6 @@ namespace {
using ::file::Defaults; using ::file::Defaults;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult; using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
@ -95,9 +94,9 @@ LandmarksDetectionResult GetLandmarksDetectionResult(
return result; return result;
} }
HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult( HandLandmarkerResult GetExpectedHandLandmarkerResult(
const std::vector<absl::string_view>& landmarks_file_names) { const std::vector<absl::string_view>& landmarks_file_names) {
HandLandmarksDetectionResult expected_results; HandLandmarkerResult expected_results;
for (const auto& file_name : landmarks_file_names) { for (const auto& file_name : landmarks_file_names) {
const auto landmarks_detection_result = const auto landmarks_detection_result =
GetLandmarksDetectionResult(file_name); GetLandmarksDetectionResult(file_name);
@ -109,9 +108,9 @@ HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
return expected_results; return expected_results;
} }
void ExpectHandLandmarksDetectionResultsCorrect( void ExpectHandLandmarkerResultsCorrect(
const HandLandmarksDetectionResult& actual_results, const HandLandmarkerResult& actual_results,
const HandLandmarksDetectionResult& expected_results) { const HandLandmarkerResult& expected_results) {
const auto& actual_landmarks = actual_results.hand_landmarks; const auto& actual_landmarks = actual_results.hand_landmarks;
const auto& actual_handedness = actual_results.handedness; const auto& actual_handedness = actual_results.handedness;
@ -145,7 +144,7 @@ struct TestParams {
// clockwise. // clockwise.
int rotation; int rotation;
// Expected results from the hand landmarker model output. // Expected results from the hand landmarker model output.
HandLandmarksDetectionResult expected_results; HandLandmarkerResult expected_results;
}; };
class ImageModeTest : public testing::TestWithParam<TestParams> {}; class ImageModeTest : public testing::TestWithParam<TestParams> {};
@ -213,7 +212,7 @@ TEST_P(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
HandLandmarksDetectionResult hand_landmarker_results; HandLandmarkerResult hand_landmarker_results;
if (GetParam().rotation != 0) { if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation; image_processing_options.rotation_degrees = GetParam().rotation;
@ -224,8 +223,8 @@ TEST_P(ImageModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results, MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
hand_landmarker->Detect(image)); hand_landmarker->Detect(image));
} }
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results, ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
GetParam().expected_results); GetParam().expected_results);
MP_ASSERT_OK(hand_landmarker->Close()); MP_ASSERT_OK(hand_landmarker->Close());
} }
@ -237,8 +236,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
{kThumbUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUp", /* test_name= */ "LandmarksPointingUp",
@ -246,8 +244,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
{kPointingUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUpRotated", /* test_name= */ "LandmarksPointingUpRotated",
@ -255,7 +252,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90, /* rotation= */ -90,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult(
{kPointingUpRotatedLandmarksFilename}), {kPointingUpRotatedLandmarksFilename}),
}, },
TestParams{ TestParams{
@ -315,7 +312,7 @@ TEST_P(VideoModeTest, Succeeds) {
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
const auto expected_results = GetParam().expected_results; const auto expected_results = GetParam().expected_results;
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
HandLandmarksDetectionResult hand_landmarker_results; HandLandmarkerResult hand_landmarker_results;
if (GetParam().rotation != 0) { if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation; image_processing_options.rotation_degrees = GetParam().rotation;
@ -326,8 +323,8 @@ TEST_P(VideoModeTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results, MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
hand_landmarker->DetectForVideo(image, i)); hand_landmarker->DetectForVideo(image, i));
} }
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results, ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
expected_results); expected_results);
} }
MP_ASSERT_OK(hand_landmarker->Close()); MP_ASSERT_OK(hand_landmarker->Close());
} }
@ -340,8 +337,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
{kThumbUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUp", /* test_name= */ "LandmarksPointingUp",
@ -349,8 +345,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
{kPointingUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUpRotated", /* test_name= */ "LandmarksPointingUpRotated",
@ -358,7 +353,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90, /* rotation= */ -90,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult(
{kPointingUpRotatedLandmarksFilename}), {kPointingUpRotatedLandmarksFilename}),
}, },
TestParams{ TestParams{
@ -383,9 +378,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset); JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
options->result_callback = options->result_callback = [](absl::StatusOr<HandLandmarkerResult> results,
[](absl::StatusOr<HandLandmarksDetectionResult> results, const Image& image, int64 timestamp_ms) {};
const Image& image, int64 timestamp_ms) {};
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
@ -416,23 +410,23 @@ TEST_P(LiveStreamModeTest, Succeeds) {
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, GetParam().test_model_file); JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
options->running_mode = core::RunningMode::LIVE_STREAM; options->running_mode = core::RunningMode::LIVE_STREAM;
std::vector<HandLandmarksDetectionResult> hand_landmarker_results; std::vector<HandLandmarkerResult> hand_landmarker_results;
std::vector<std::pair<int, int>> image_sizes; std::vector<std::pair<int, int>> image_sizes;
std::vector<int64> timestamps; std::vector<int64> timestamps;
options->result_callback = options->result_callback = [&hand_landmarker_results, &image_sizes,
[&hand_landmarker_results, &image_sizes, &timestamps]( &timestamps](
absl::StatusOr<HandLandmarksDetectionResult> results, absl::StatusOr<HandLandmarkerResult> results,
const Image& image, int64 timestamp_ms) { const Image& image, int64 timestamp_ms) {
MP_ASSERT_OK(results.status()); MP_ASSERT_OK(results.status());
hand_landmarker_results.push_back(std::move(results.value())); hand_landmarker_results.push_back(std::move(results.value()));
image_sizes.push_back({image.width(), image.height()}); image_sizes.push_back({image.width(), image.height()});
timestamps.push_back(timestamp_ms); timestamps.push_back(timestamp_ms);
}; };
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
HandLandmarker::Create(std::move(options))); HandLandmarker::Create(std::move(options)));
for (int i = 0; i < iterations; ++i) { for (int i = 0; i < iterations; ++i) {
HandLandmarksDetectionResult hand_landmarker_results; HandLandmarkerResult hand_landmarker_results;
if (GetParam().rotation != 0) { if (GetParam().rotation != 0) {
ImageProcessingOptions image_processing_options; ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = GetParam().rotation; image_processing_options.rotation_degrees = GetParam().rotation;
@ -450,8 +444,8 @@ TEST_P(LiveStreamModeTest, Succeeds) {
const auto expected_results = GetParam().expected_results; const auto expected_results = GetParam().expected_results;
for (int i = 0; i < hand_landmarker_results.size(); ++i) { for (int i = 0; i < hand_landmarker_results.size(); ++i) {
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i], ExpectHandLandmarkerResultsCorrect(hand_landmarker_results[i],
expected_results); expected_results);
} }
for (const auto& image_size : image_sizes) { for (const auto& image_size : image_sizes) {
EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.first, image.width());
@ -472,8 +466,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
{kThumbUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUp", /* test_name= */ "LandmarksPointingUp",
@ -481,8 +474,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ 0, /* rotation= */ 0,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
{kPointingUpLandmarksFilename}),
}, },
TestParams{ TestParams{
/* test_name= */ "LandmarksPointingUpRotated", /* test_name= */ "LandmarksPointingUpRotated",
@ -490,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
/* test_model_file= */ kHandLandmarkerBundleAsset, /* test_model_file= */ kHandLandmarkerBundleAsset,
/* rotation= */ -90, /* rotation= */ -90,
/* expected_results = */ /* expected_results = */
GetExpectedHandLandmarksDetectionResult( GetExpectedHandLandmarkerResult(
{kPointingUpRotatedLandmarksFilename}), {kPointingUpRotatedLandmarksFilename}),
}, },
TestParams{ TestParams{

View File

@ -142,6 +142,36 @@ android_library(
], ],
) )
android_library(
name = "handlandmarker",
srcs = [
"handlandmarker/HandLandmarker.java",
"handlandmarker/HandLandmarkerResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "handlandmarker/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:classification_java_proto_lite",
"//mediapipe/framework/formats:landmark_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar") load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")
mediapipe_tasks_vision_aar( mediapipe_tasks_vision_aar(

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.handlandmarker">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,501 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.handlandmarker;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarkerGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs hand landmarks detection on images.
*
* <p>This API expects a pre-trained hand landmarks model asset bundle. See <TODO link
* to the DevSite documentation page>.
*
* <ul>
* <li>Input image {@link MPImage}
* <ul>
* <li>The image that hand landmarks detection runs on.
* </ul>
* <li>Output HandLandmarkerResult {@link HandLandmarkerResult}
* <ul>
* <li>A HandLandmarkerResult containing hand landmarks.
* </ul>
* </ul>
*/
public final class HandLandmarker extends BaseVisionTaskApi {
private static final String TAG = HandLandmarker.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"LANDMARKS:hand_landmarks",
"WORLD_LANDMARKS:world_hand_landmarks",
"HANDEDNESS:handedness",
"IMAGE:image_out"));
private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
private static final int HANDEDNESS_OUT_STREAM_INDEX = 2;
private static final int IMAGE_OUT_STREAM_INDEX = 3;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
/**
* Creates a {@link HandLandmarker} instance from a model file and the default {@link
* HandLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the hand landmarks model with metadata in the assets.
* @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
*/
public static HandLandmarker createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link HandLandmarker} instance from a model file and the default {@link
* HandLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the hand landmarks model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
*/
public static HandLandmarker createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates a {@link HandLandmarker} instance from a model buffer and the default {@link
* HandLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
* model.
* @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
*/
public static HandLandmarker createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link HandLandmarker} instance from a {@link HandLandmarkerOptions}.
*
* @param context an Android {@link Context}.
* @param landmarkerOptions a {@link HandLandmarkerOptions} instance.
* @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
*/
public static HandLandmarker createFromOptions(
Context context, HandLandmarkerOptions landmarkerOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<HandLandmarkerResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<HandLandmarkerResult, MPImage>() {
@Override
public HandLandmarkerResult convertToTaskResult(List<Packet> packets) {
// If there is no hands detected in the image, just returns empty lists.
if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
return HandLandmarkerResult.create(
new ArrayList<>(),
new ArrayList<>(),
new ArrayList<>(),
packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp());
}
return HandLandmarkerResult.create(
PacketGetter.getProtoVector(
packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
PacketGetter.getProtoVector(
packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
PacketGetter.getProtoVector(
packets.get(HANDEDNESS_OUT_STREAM_INDEX), ClassificationList.parser()),
packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp());
}
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
landmarkerOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<HandLandmarkerOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(landmarkerOptions)
.setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
return new HandLandmarker(runner, landmarkerOptions.runningMode());
}
/**
* Constructor to initialize an {@link HandLandmarker} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private HandLandmarker(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
}
/**
* Performs hand landmarks detection on the provided single image with default image processing
* options, i.e. without any rotation applied. Only use this method when the {@link
* HandLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java doc
* for input image format.
*
* <p>{@link HandLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public HandLandmarkerResult detect(MPImage image) {
return detect(image, ImageProcessingOptions.builder().build());
}
/**
* Performs hand landmarks detection on the provided single image. Only use this method when the
* {@link HandLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java
* doc for input image format.
*
* <p>{@link HandLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public HandLandmarkerResult detect(
MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions);
return (HandLandmarkerResult) processImageData(image, imageProcessingOptions);
}
/**
* Performs hand landmarks detection on the provided video frame with default image processing
* options, i.e. without any rotation applied. Only use this method when the {@link
* HandLandmarker} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link HandLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public HandLandmarkerResult detectForVideo(MPImage image, long timestampMs) {
return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Performs hand landmarks detection on the provided video frame. Only use this method when the
* {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link HandLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public HandLandmarkerResult detectForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
return (HandLandmarkerResult)
processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
* Sends live image data to perform hand landmarks detection with default image processing
* options, i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link HandLandmarkerOptions}. Only use this method when the
* {@link HandLandmarker } is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the hand landmarker. The input timestamps must be monotonically increasing.
*
* <p>{@link HandLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(MPImage image, long timestampMs) {
detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Sends live image data to perform hand landmarks detection, and the results will be available
* via the {@link ResultListener} provided in the {@link HandLandmarkerOptions}. Only use this
* method when the {@link HandLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the hand landmarker. The input timestamps must be monotonically increasing.
*
* <p>{@link HandLandmarker} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
}
/** Options for setting up an {@link HandLandmarker}. */
@AutoValue
public abstract static class HandLandmarkerOptions extends TaskOptions {
/** Builder for {@link HandLandmarkerOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the hand landmarker task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the running mode for the hand landmarker task. Default to the image mode. Hand
* landmarker has three modes:
*
* <ul>
* <li>IMAGE: The mode for detecting hand landmarks on single image inputs.
* <li>VIDEO: The mode for detecting hand landmarks on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for for detecting hand landmarks on a live stream of input
* data, such as from camera. In this mode, {@code setResultListener} must be called to
* set up a listener to receive the detection results asynchronously.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode value);
/** Sets the maximum number of hands can be detected by the HandLandmarker. */
public abstract Builder setNumHands(Integer value);
/** Sets minimum confidence score for the hand detection to be considered successful */
public abstract Builder setMinHandDetectionConfidence(Float value);
/** Sets minimum confidence score of hand presence score in the hand landmark detection. */
public abstract Builder setMinHandPresenceConfidence(Float value);
/** Sets the minimum confidence score for the hand tracking to be considered successful. */
public abstract Builder setMinTrackingConfidence(Float value);
/**
* Sets the result listener to receive the detection results asynchronously when the hand
* landmarker is in the live stream mode.
*/
public abstract Builder setResultListener(
ResultListener<HandLandmarkerResult, MPImage> value);
/** Sets an optional error listener. */
public abstract Builder setErrorListener(ErrorListener value);
abstract HandLandmarkerOptions autoBuild();
/**
* Validates and builds the {@link HandLandmarkerOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the hand landmarker is
* in the live stream mode.
*/
public final HandLandmarkerOptions build() {
HandLandmarkerOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The hand landmarker is in the live stream mode, a user-defined result listener"
+ " must be provided in HandLandmarkerOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The hand landmarker is in the image or the video mode, a user-defined result"
+ " listener shouldn't be provided in HandLandmarkerOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<Integer> numHands();
abstract Optional<Float> minHandDetectionConfidence();
abstract Optional<Float> minHandPresenceConfidence();
abstract Optional<Float> minTrackingConfidence();
abstract Optional<ResultListener<HandLandmarkerResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_HandLandmarker_HandLandmarkerOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setNumHands(1)
.setMinHandDetectionConfidence(0.5f)
.setMinHandPresenceConfidence(0.5f)
.setMinTrackingConfidence(0.5f);
}
/** Converts a {@link HandLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder taskOptionsBuilder =
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
.build());
// Setup HandDetectorGraphOptions.
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder
handDetectorGraphOptionsBuilder =
HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder();
numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands);
minHandDetectionConfidence()
.ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence);
// Setup HandLandmarkerGraphOptions.
HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder
handLandmarksDetectorGraphOptionsBuilder =
HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder();
minHandPresenceConfidence()
.ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence);
minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence);
taskOptionsBuilder
.setHandDetectorGraphOptions(handDetectorGraphOptionsBuilder.build())
.setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptionsBuilder.build());
return CalculatorOptions.newBuilder()
.setExtension(
HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
/**
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
* region-of-interest.
*/
private static void validateImageProcessingOptions(
ImageProcessingOptions imageProcessingOptions) {
if (imageProcessingOptions.regionOfInterest().isPresent()) {
throw new IllegalArgumentException("HandLandmarker doesn't support region-of-interest.");
}
}
}

View File

@ -0,0 +1,109 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.handlandmarker;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the hand landmarks deection results generated by {@link HandLandmarker}. */
@AutoValue
public abstract class HandLandmarkerResult implements TaskResult {
/**
* Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and
* handedness protobuf messages.
*
* @param landmarksProto a List of {@link NormalizedLandmarkList}
* @param worldLandmarksProto a List of {@link LandmarkList}
* @param handednessesProto a List of {@link ClassificationList}
*/
static HandLandmarkerResult create(
List<NormalizedLandmarkList> landmarksProto,
List<LandmarkList> worldLandmarksProto,
List<ClassificationList> handednessesProto,
long timestampMs) {
List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
new ArrayList<>();
List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
new ArrayList<>();
List<List<Category>> multiHandHandednesses = new ArrayList<>();
for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
new ArrayList<>();
multiHandLandmarks.add(handLandmarks);
for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
handLandmarks.add(
com.google.mediapipe.tasks.components.containers.Landmark.create(
handLandmarkProto.getX(),
handLandmarkProto.getY(),
handLandmarkProto.getZ(),
true));
}
}
for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks =
new ArrayList<>();
multiHandWorldLandmarks.add(handWorldLandmarks);
for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
handWorldLandmarks.add(
com.google.mediapipe.tasks.components.containers.Landmark.create(
handWorldLandmarkProto.getX(),
handWorldLandmarkProto.getY(),
handWorldLandmarkProto.getZ(),
false));
}
}
for (ClassificationList handednessProto : handednessesProto) {
List<Category> handedness = new ArrayList<>();
multiHandHandednesses.add(handedness);
for (Classification classification : handednessProto.getClassificationList()) {
handedness.add(
Category.create(
classification.getScore(),
classification.getIndex(),
classification.getLabel(),
classification.getDisplayName()));
}
}
return new AutoValue_HandLandmarkerResult(
timestampMs,
Collections.unmodifiableList(multiHandLandmarks),
Collections.unmodifiableList(multiHandWorldLandmarks),
Collections.unmodifiableList(multiHandHandednesses));
}
@Override
public abstract long timestampMs();
/** Hand landmarks of detected hands. */
public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();
/** Hand landmarks in world coordniates of detected hands. */
public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
worldLandmarks();
/** Handedness of detected hands. */
public abstract List<List<Category>> handednesses();
}

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.handlandmarkertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="handlandmarkertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.handlandmarkertest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,424 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.handlandmarker;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.truth.Correspondence;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarker.HandLandmarkerOptions;
import java.io.InputStream;
import java.util.Arrays;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link HandLandmarker}. */
@RunWith(Suite.class)
@SuiteClasses({HandLandmarkerTest.General.class, HandLandmarkerTest.RunningModeTest.class})
public class HandLandmarkerTest {
private static final String HAND_LANDMARKER_BUNDLE_ASSET_FILE = "hand_landmarker.task";
private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
private static final String POINTING_UP_ROTATED_LANDMARKS = "pointing_up_rotated_landmarks.pb";
private static final String TAG = "Hand Landmarker Test";
private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
private static final int IMAGE_WIDTH = 382;
private static final int IMAGE_HEIGHT = 406;
@RunWith(AndroidJUnit4.class)
public static final class General extends HandLandmarkerTest {
@Test
public void detect_successWithValidModels() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
HandLandmarkerResult actualResult =
handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE));
HandLandmarkerResult expectedResult =
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void detect_successWithEmptyResult() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
HandLandmarkerResult actualResult =
handLandmarker.detect(getImageFromAsset(NO_HANDS_IMAGE));
assertThat(actualResult.landmarks()).isEmpty();
assertThat(actualResult.worldLandmarks()).isEmpty();
assertThat(actualResult.handednesses()).isEmpty();
}
@Test
public void detect_successWithNumHands() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(2)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
HandLandmarkerResult actualResult =
handLandmarker.detect(getImageFromAsset(TWO_HANDS_IMAGE));
assertThat(actualResult.handednesses()).hasSize(2);
}
@Test
public void recognize_successWithRotation() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRotationDegrees(-90).build();
HandLandmarkerResult actualResult =
handLandmarker.detect(
getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
HandLandmarkerResult expectedResult =
getExpectedHandLandmarkerResult(POINTING_UP_ROTATED_LANDMARKS);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_failsWithRegionOfInterest() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setNumHands(1)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ImageProcessingOptions imageProcessingOptions =
ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions));
assertThat(exception)
.hasMessageThat()
.contains("HandLandmarker doesn't support region-of-interest");
}
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends HandLandmarkerTest {
@Test
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(mode)
.setResultListener((HandLandmarkerResults, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener shouldn't be provided");
}
}
}
@Test
public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
.build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void recognize_failsWithCallingWrongApiInImageMode() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
handLandmarker.detectForVideo(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
handLandmarker.detectAsync(getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
handLandmarker.detectAsync(getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((HandLandmarkerResults, inputImage) -> {})
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
handLandmarker.detectForVideo(
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void recognize_successWithImageMode() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
HandLandmarkerResult actualResult =
handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE));
HandLandmarkerResult expectedResult =
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
@Test
public void recognize_successWithVideoMode() throws Exception {
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
HandLandmarkerResult expectedResult =
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
for (int i = 0; i < 3; i++) {
HandLandmarkerResult actualResult =
handLandmarker.detectForVideo(getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ i);
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
}
}
@Test
public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
HandLandmarkerResult expectedResult =
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(actualResult, inputImage) -> {
assertActualResultApproximatelyEqualsToExpectedResult(
actualResult, expectedResult);
assertImageSizeIsExpected(inputImage);
})
.build();
try (HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
handLandmarker.detectAsync(image, /*timestampsMs=*/ 1);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> handLandmarker.detectAsync(image, /*timestampsMs=*/ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
@Test
public void recognize_successWithLiveSteamMode() throws Exception {
MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
HandLandmarkerResult expectedResult =
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
HandLandmarkerOptions options =
HandLandmarkerOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(actualResult, inputImage) -> {
assertActualResultApproximatelyEqualsToExpectedResult(
actualResult, expectedResult);
assertImageSizeIsExpected(inputImage);
})
.build();
try (HandLandmarker handLandmarker =
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
handLandmarker.detectAsync(image, /*timestampsMs=*/ i);
}
}
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
private static HandLandmarkerResult getExpectedHandLandmarkerResult(
String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
LandmarksDetectionResult landmarksDetectionResultProto =
LandmarksDetectionResult.parser().parseFrom(istr);
return HandLandmarkerResult.create(
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
Arrays.asList(landmarksDetectionResultProto.getClassifications()),
/*timestampMs=*/ 0);
}
private static void assertActualResultApproximatelyEqualsToExpectedResult(
HandLandmarkerResult actualResult, HandLandmarkerResult expectedResult) {
// Expects to have the same number of hands detected.
assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size());
assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size());
// Actual landmarks match expected landmarks.
assertThat(actualResult.landmarks().get(0))
.comparingElementsUsing(
Correspondence.from(
(Correspondence.BinaryPredicate<Landmark, Landmark>)
(actual, expected) -> {
return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
.compare(actual.x(), expected.x())
&& Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
.compare(actual.y(), expected.y());
},
"landmarks approximately equal to"))
.containsExactlyElementsIn(expectedResult.landmarks().get(0));
// Actual handedness matches expected handedness.
Category actualTopHandedness = actualResult.handednesses().get(0).get(0);
Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0);
assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index());
assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName());
}
private static void assertImageSizeIsExpected(MPImage inputImage) {
assertThat(inputImage).isNotNull();
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
}
}

View File

@ -53,7 +53,13 @@ class LandmarksDetectionResult:
def to_pb2(self) -> _LandmarksDetectionResultProto: def to_pb2(self) -> _LandmarksDetectionResultProto:
"""Generates a LandmarksDetectionResult protobuf object.""" """Generates a LandmarksDetectionResult protobuf object."""
landmarks = _NormalizedLandmarkListProto()
classifications = _ClassificationListProto() classifications = _ClassificationListProto()
world_landmarks = _LandmarkListProto()
for landmark in self.landmarks:
landmarks.landmark.append(landmark.to_pb2())
for category in self.categories: for category in self.categories:
classifications.classification.append( classifications.classification.append(
_ClassificationProto( _ClassificationProto(
@ -63,9 +69,9 @@ class LandmarksDetectionResult:
display_name=category.display_name)) display_name=category.display_name))
return _LandmarksDetectionResultProto( return _LandmarksDetectionResultProto(
landmarks=_NormalizedLandmarkListProto(self.landmarks), landmarks=landmarks,
classifications=classifications, classifications=classifications,
world_landmarks=_LandmarkListProto(self.world_landmarks), world_landmarks=world_landmarks,
rect=self.rect.to_pb2()) rect=self.rect.to_pb2())
@classmethod @classmethod
@ -73,9 +79,11 @@ class LandmarksDetectionResult:
def create_from_pb2( def create_from_pb2(
cls, cls,
pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult': pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult':
"""Creates a `LandmarksDetectionResult` object from the given protobuf object. """Creates a `LandmarksDetectionResult` object from the given protobuf object."""
"""
categories = [] categories = []
landmarks = []
world_landmarks = []
for classification in pb2_obj.classifications.classification: for classification in pb2_obj.classifications.classification:
categories.append( categories.append(
category_module.Category( category_module.Category(
@ -83,14 +91,14 @@ class LandmarksDetectionResult:
index=classification.index, index=classification.index,
category_name=classification.label, category_name=classification.label,
display_name=classification.display_name)) display_name=classification.display_name))
for landmark in pb2_obj.landmarks.landmark:
landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
for landmark in pb2_obj.world_landmarks.landmark:
world_landmarks.append(_Landmark.create_from_pb2(landmark))
return LandmarksDetectionResult( return LandmarksDetectionResult(
landmarks=[ landmarks=landmarks,
_NormalizedLandmark.create_from_pb2(landmark)
for landmark in pb2_obj.landmarks.landmark
],
categories=categories, categories=categories,
world_landmarks=[ world_landmarks=world_landmarks,
_Landmark.create_from_pb2(landmark)
for landmark in pb2_obj.world_landmarks.landmark
],
rect=_NormalizedRect.create_from_pb2(pb2_obj.rect)) rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))

View File

@ -12,9 +12,9 @@ py_library(
srcs = [ srcs = [
"metadata_info.py", "metadata_info.py",
], ],
srcs_version = "PY3",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":writer_utils",
"//mediapipe/tasks/metadata:metadata_schema_py", "//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:schema_py", "//mediapipe/tasks/metadata:schema_py",
], ],

View File

@ -14,12 +14,14 @@
# ============================================================================== # ==============================================================================
"""Helper classes for common model metadata information.""" """Helper classes for common model metadata information."""
import collections
import csv import csv
import os import os
from typing import List, Optional, Type from typing import List, Optional, Type, Union
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
# Min and max values for UINT8 tensors. # Min and max values for UINT8 tensors.
_MIN_UINT8 = 0 _MIN_UINT8 = 0
@ -267,6 +269,86 @@ class RegexTokenizerMd:
return tokenizer return tokenizer
class BertTokenizerMd:
"""A container for the Bert tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
"""
def __init__(self, vocab_file_path: str):
"""Initializes a BertTokenizerMd object.
Args:
vocab_file_path: path to the vocabulary file.
"""
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the Bert tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the Bert tokenizer metadata.
"""
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = _VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
tokenizer.options.vocabFile = [vocab]
return tokenizer
class SentencePieceTokenizerMd:
"""A container for the sentence piece tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
_SP_MODEL_DESCRIPTION = "The sentence piece model file."
_SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
" This file is optional during tokenization, while the sentence piece "
"model is mandatory.")
def __init__(self,
sentence_piece_model_path: str,
vocab_file_path: Optional[str] = None):
"""Initializes a SentencePieceTokenizerMd object.
Args:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
self._sentence_piece_model_path = sentence_piece_model_path
self._vocab_file_path = vocab_file_path
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the sentence piece tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the sentence piece tokenizer metadata.
"""
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = (
_metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
sp_model = _metadata_fb.AssociatedFileT()
sp_model.name = self._sentence_piece_model_path
sp_model.description = self._SP_MODEL_DESCRIPTION
tokenizer.options.sentencePieceModel = [sp_model]
if self._vocab_file_path:
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer.options.vocabFile = [vocab]
return tokenizer
class TensorMd: class TensorMd:
"""A container for common tensor metadata information. """A container for common tensor metadata information.
@ -486,6 +568,145 @@ class InputTextTensorMd(TensorMd):
return tensor_metadata return tensor_metadata
def _get_file_paths(files: List[_metadata_fb.AssociatedFileT]) -> List[str]:
"""Gets file paths from a list of associated files."""
if not files:
return []
return [file.name for file in files]
def _get_tokenizer_associated_files(
tokenizer_options: Optional[
Union[_metadata_fb.BertTokenizerOptionsT,
_metadata_fb.SentencePieceTokenizerOptionsT]]
) -> List[str]:
"""Gets a list of associated files packed in the tokenizer_options.
Args:
tokenizer_options: a tokenizer metadata object. Support the following
tokenizer types:
1. BertTokenizerOptions:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
2. SentencePieceTokenizerOptions:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
Returns:
A list of associated files included in tokenizer_options.
"""
if not tokenizer_options:
return []
if isinstance(tokenizer_options, _metadata_fb.BertTokenizerOptionsT):
return _get_file_paths(tokenizer_options.vocabFile)
elif isinstance(tokenizer_options,
_metadata_fb.SentencePieceTokenizerOptionsT):
return _get_file_paths(tokenizer_options.vocabFile) + _get_file_paths(
tokenizer_options.sentencePieceModel)
else:
return []
class BertInputTensorsMd:
"""A container for the input tensor metadata information of Bert models."""
_IDS_NAME = "ids"
_IDS_DESCRIPTION = "Tokenized ids of the input text."
_MASK_NAME = "mask"
_MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
"tokens.")
_SEGMENT_IDS_NAME = "segment_ids"
_SEGMENT_IDS_DESCRIPTION = (
"0 for the first sequence, 1 for the second sequence if exists.")
def __init__(self,
model_buffer: bytearray,
ids_name: str,
mask_name: str,
segment_name: str,
tokenizer_md: Union[None, BertTokenizerMd,
SentencePieceTokenizerMd] = None):
"""Initializes a BertInputTensorsMd object.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata.
Args:
model_buffer: valid buffer of the model file.
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if exists.
tokenizer_md: information of the tokenizer used to process the input
string, if any. Supported tokenizers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
refer to `InputTensorsMd`.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
[2]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
[3]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
"""
# Verify that tflite_input_names (read from the model) and
# input_name (collected from users) are aligned.
tflite_input_names = writer_utils.get_input_tensor_names(model_buffer)
input_names = [ids_name, mask_name, segment_name]
if collections.Counter(tflite_input_names) != collections.Counter(
input_names):
raise ValueError(
f"The input tensor names ({input_names}) do not match the tensor "
f"names read from the model ({tflite_input_names}).")
ids_md = TensorMd(
name=self._IDS_NAME,
description=self._IDS_DESCRIPTION,
tensor_name=ids_name)
mask_md = TensorMd(
name=self._MASK_NAME,
description=self._MASK_DESCRIPTION,
tensor_name=mask_name)
segment_ids_md = TensorMd(
name=self._SEGMENT_IDS_NAME,
description=self._SEGMENT_IDS_DESCRIPTION,
tensor_name=segment_name)
self._input_md = [ids_md, mask_md, segment_ids_md]
if not isinstance(tokenizer_md,
(type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
raise ValueError(
f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported"
)
self._tokenizer_md = tokenizer_md
def create_input_process_unit_metadata(
self) -> List[_metadata_fb.ProcessUnitT]:
"""Creates the input process unit metadata."""
if self._tokenizer_md:
return [self._tokenizer_md.create_metadata()]
else:
return []
def get_tokenizer_associated_files(self) -> List[str]:
"""Gets the associated files that are packed in the tokenizer."""
if self._tokenizer_md:
return _get_tokenizer_associated_files(
self._tokenizer_md.create_metadata().options)
else:
return []
@property
def input_md(self) -> List[TensorMd]:
return self._input_md
class ClassificationTensorMd(TensorMd): class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information. """A container for the classification tensor metadata information.

View File

@ -19,7 +19,7 @@ import csv
import dataclasses import dataclasses
import os import os
import tempfile import tempfile
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Union
import flatbuffers import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
@ -101,6 +101,34 @@ class RegexTokenizer:
vocab_file_path: str vocab_file_path: str
@dataclasses.dataclass
class BertTokenizer:
"""Parameters of the Bert tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
Attributes:
vocab_file_path: path to the vocabulary file.
"""
vocab_file_path: str
@dataclasses.dataclass
class SentencePieceTokenizer:
"""Parameters of the sentence piece tokenizer tokenizer [1] metadata information.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
Attributes:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
sentence_piece_model_path: str
vocab_file_path: Optional[str] = None
class Labels(object): class Labels(object):
"""Simple container holding classification labels of a particular tensor. """Simple container holding classification labels of a particular tensor.
@ -282,7 +310,9 @@ def _create_metadata_buffer(
model_buffer: bytearray, model_buffer: bytearray,
general_md: Optional[metadata_info.GeneralMd] = None, general_md: Optional[metadata_info.GeneralMd] = None,
input_md: Optional[List[metadata_info.TensorMd]] = None, input_md: Optional[List[metadata_info.TensorMd]] = None,
output_md: Optional[List[metadata_info.TensorMd]] = None) -> bytearray: output_md: Optional[List[metadata_info.TensorMd]] = None,
input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None
) -> bytearray:
"""Creates a buffer of the metadata. """Creates a buffer of the metadata.
Args: Args:
@ -290,7 +320,9 @@ def _create_metadata_buffer(
general_md: general information about the model. general_md: general information about the model.
input_md: metadata information of the input tensors. input_md: metadata information of the input tensors.
output_md: metadata information of the output tensors. output_md: metadata information of the output tensors.
input_process_units: a lists of metadata of the input process units [1].
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
Returns: Returns:
A buffer of the metadata. A buffer of the metadata.
@ -325,6 +357,8 @@ def _create_metadata_buffer(
subgraph_metadata = metadata_fb.SubGraphMetadataT() subgraph_metadata = metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputTensorMetadata = input_metadata subgraph_metadata.inputTensorMetadata = input_metadata
subgraph_metadata.outputTensorMetadata = output_metadata subgraph_metadata.outputTensorMetadata = output_metadata
if input_process_units:
subgraph_metadata.inputProcessUnits = input_process_units
# Create the whole model metadata. # Create the whole model metadata.
if general_md is None: if general_md is None:
@ -366,6 +400,7 @@ class MetadataWriter(object):
self._model_buffer = model_buffer self._model_buffer = model_buffer
self._general_md = None self._general_md = None
self._input_mds = [] self._input_mds = []
self._input_process_units = []
self._output_mds = [] self._output_mds = []
self._associated_files = [] self._associated_files = []
self._temp_folder = tempfile.TemporaryDirectory() self._temp_folder = tempfile.TemporaryDirectory()
@ -416,7 +451,7 @@ class MetadataWriter(object):
description: Description of the input tensor. description: Description of the input tensor.
Returns: Returns:
The MetaWriter instance, can be used for chained operation. The MetadataWriter instance, can be used for chained operation.
[1]: [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
@ -448,7 +483,7 @@ class MetadataWriter(object):
description: Description of the input tensor. description: Description of the input tensor.
Returns: Returns:
The MetaWriter instance, can be used for chained operation. The MetadataWriter instance, can be used for chained operation.
[1]: [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500 https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
@ -462,6 +497,63 @@ class MetadataWriter(object):
self._associated_files.append(regex_tokenizer.vocab_file_path) self._associated_files.append(regex_tokenizer.vocab_file_path)
return self return self
def add_bert_text_input(self, tokenizer: Union[BertTokenizer,
SentencePieceTokenizer],
ids_name: str, mask_name: str,
segment_name: str) -> 'MetadataWriter':
"""Adds an metadata for the text input with bert / sentencepiece tokenizer.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata.
Args:
tokenizer: information of the tokenizer used to process the input string,
if any. Supported tokenziers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2].
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if exists.
Returns:
The MetadataWriter instance, can be used for chained operation.
Raises:
ValueError: if the type tokenizer is not BertTokenizer or
SentencePieceTokenizer.
[1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
"""
if isinstance(tokenizer, BertTokenizer):
tokenizer_md = metadata_info.BertTokenizerMd(
vocab_file_path=tokenizer.vocab_file_path)
elif isinstance(tokenizer, SentencePieceTokenizer):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
sentence_piece_model_path=tokenizer.sentence_piece_model_path,
vocab_file_path=tokenizer.vocab_file_path)
else:
raise ValueError(
f'The type of tokenizer, {type(tokenizer)}, is unsupported')
bert_input_md = metadata_info.BertInputTensorsMd(
self._model_buffer,
ids_name,
mask_name,
segment_name,
tokenizer_md=tokenizer_md)
self._input_mds.extend(bert_input_md.input_md)
self._associated_files.extend(
bert_input_md.get_tokenizer_associated_files())
self._input_process_units.extend(
bert_input_md.create_input_process_unit_metadata())
return self
def add_classification_output( def add_classification_output(
self, self,
labels: Optional[Labels] = None, labels: Optional[Labels] = None,
@ -546,7 +638,8 @@ class MetadataWriter(object):
model_buffer=self._model_buffer, model_buffer=self._model_buffer,
general_md=self._general_md, general_md=self._general_md,
input_md=self._input_mds, input_md=self._input_mds,
output_md=self._output_mds) output_md=self._output_mds,
input_process_units=self._input_process_units)
populator.load_metadata_buffer(metadata_buffer) populator.load_metadata_buffer(metadata_buffer)
if self._associated_files: if self._associated_files:
populator.load_associated_files(self._associated_files) populator.load_associated_files(self._associated_files)

View File

@ -14,11 +14,18 @@
# ============================================================================== # ==============================================================================
"""Writes metadata and label file to the Text classifier models.""" """Writes metadata and label file to the Text classifier models."""
from typing import Union
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
_MODEL_NAME = "TextClassifier" _MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.") _MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")
# The input tensor names of models created by Model Maker.
_DEFAULT_ID_NAME = "serving_default_input_word_ids:0"
_DEFAULT_MASK_NAME = "serving_default_input_mask:0"
_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0"
class MetadataWriter(metadata_writer.MetadataWriterBase): class MetadataWriter(metadata_writer.MetadataWriterBase):
"""MetadataWriter to write the metadata into the text classifier.""" """MetadataWriter to write the metadata into the text classifier."""
@ -62,3 +69,51 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
writer.add_regex_text_input(regex_tokenizer) writer.add_regex_text_input(regex_tokenizer)
writer.add_classification_output(labels) writer.add_classification_output(labels)
return cls(writer) return cls(writer)
@classmethod
def create_for_bert_model(
cls,
model_buffer: bytearray,
tokenizer: Union[metadata_writer.BertTokenizer,
metadata_writer.SentencePieceTokenizer],
labels: metadata_writer.Labels,
ids_name: str = _DEFAULT_ID_NAME,
mask_name: str = _DEFAULT_MASK_NAME,
segment_name: str = _DEFAULT_SEGMENT_ID_NAME,
) -> "MetadataWriter":
"""Creates MetadataWriter for models with {Bert/SentencePiece}Tokenizer.
`ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
in the TFLite schema, which help to determine the tensor order when
populating metadata. The default values come from Model Maker.
Args:
model_buffer: valid buffer of the model file.
tokenizer: information of the tokenizer used to process the input string,
if any. Supported tokenziers are: `BertTokenizer` [1] and
`SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
refer to `create_for_regex_model`.
labels: an instance of Labels helper class used in the output
classification tensor [4].
ids_name: name of the ids tensor, which represents the tokenized ids of
the input text.
mask_name: name of the mask tensor, which represents the mask with `1` for
real tokens and `0` for padding tokens.
segment_name: name of the segment ids tensor, where `0` stands for the
first sequence, and `1` stands for the second sequence if exists. [1]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
[2]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
[3]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
[4]:
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
Returns:
A MetadataWriter object.
"""
writer = metadata_writer.MetadataWriter(model_buffer)
writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
writer.add_bert_text_input(tokenizer, ids_name, mask_name, segment_name)
writer.add_classification_output(labels)
return cls(writer)

View File

@ -367,6 +367,42 @@ class ScoreThresholdingMdTest(absltest.TestCase):
self.assertEqual(metadata_json, expected_json) self.assertEqual(metadata_json, expected_json)
class BertTokenizerMdTest(absltest.TestCase):
_VOCAB_FILE = "vocab.txt"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "bert_tokenizer_meta.json"))
def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
class SentencePieceTokenizerMdTest(absltest.TestCase):
_VOCAB_FILE = "vocab.txt"
_SP_MODEL = "sp.model"
_EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, "sentence_piece_tokenizer_meta.json"))
def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
self._SP_MODEL, self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()
metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
with open(self._EXPECTED_TENSOR_JSON, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def _create_dummy_model_metadata_with_tensor( def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata. # Create a dummy model using the tensor metadata.

View File

@ -21,28 +21,64 @@ from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
_TEST_DIR = "mediapipe/tasks/testdata/metadata/" _TEST_DIR = "mediapipe/tasks/testdata/metadata/"
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite") _REGEX_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR + _LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
"movie_review_labels.txt") "movie_review_labels.txt")
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt") _REGEX_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+" _DELIM_REGEX_PATTERN = r"[^\w\']+"
_JSON_FILE = test_utils.get_test_data_path("movie_review.json") _REGEX_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_BERT_MODEL = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_no_metadata.tflite")
_BERT_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR +
"mobilebert_vocab.txt")
_SP_MODEL_FILE = test_utils.get_test_data_path(_TEST_DIR + "30k-clean.model")
_BERT_JSON_FILE = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_with_bert_tokenizer.json")
_SENTENCE_PIECE_JSON_FILE = test_utils.get_test_data_path(
_TEST_DIR + "bert_text_classifier_with_sentence_piece.json")
class TextClassifierTest(absltest.TestCase): class TextClassifierTest(absltest.TestCase):
def test_write_metadata(self,): def test_write_metadata_for_regex_model(self):
with open(_MODEL, "rb") as f: with open(_REGEX_MODEL, "rb") as f:
model_buffer = f.read() model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_regex_model( writer = text_classifier.MetadataWriter.create_for_regex_model(
model_buffer, model_buffer,
regex_tokenizer=metadata_writer.RegexTokenizer( regex_tokenizer=metadata_writer.RegexTokenizer(
delim_regex_pattern=_DELIM_REGEX_PATTERN, delim_regex_pattern=_DELIM_REGEX_PATTERN,
vocab_file_path=_VOCAB_FILE), vocab_file_path=_REGEX_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE)) labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate() _, metadata_json = writer.populate()
with open(_JSON_FILE, "r") as f: with open(_REGEX_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_for_bert(self):
with open(_BERT_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
model_buffer,
tokenizer=metadata_writer.BertTokenizer(_BERT_VOCAB_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_BERT_JSON_FILE, "r") as f:
expected_json = f.read()
self.assertEqual(metadata_json, expected_json)
def test_write_metadata_for_sentence_piece(self):
with open(_BERT_MODEL, "rb") as f:
model_buffer = f.read()
writer = text_classifier.MetadataWriter.create_for_bert_model(
model_buffer,
tokenizer=metadata_writer.SentencePieceTokenizer(_SP_MODEL_FILE),
labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
_, metadata_json = writer.populate()
with open(_SENTENCE_PIECE_JSON_FILE, "r") as f:
expected_json = f.read() expected_json = f.read()
self.assertEqual(metadata_json, expected_json) self.assertEqual(metadata_json, expected_json)

View File

@ -94,3 +94,26 @@ py_test(
"//mediapipe/tasks/python/vision/core:vision_task_running_mode", "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
], ],
) )
py_test(
name = "hand_landmarker_test",
srcs = ["hand_landmarker_test.py"],
data = [
"//mediapipe/tasks/testdata/vision:test_images",
"//mediapipe/tasks/testdata/vision:test_models",
"//mediapipe/tasks/testdata/vision:test_protos",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
"//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/containers:landmark_detection_result",
"//mediapipe/tasks/python/components/containers:rect",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/test:test_utils",
"//mediapipe/tasks/python/vision:hand_landmarker",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
"@com_google_protobuf//:protobuf_python",
],
)

View File

@ -0,0 +1,428 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for hand landmarker."""
import enum
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import hand_landmarker
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_BaseOptions = base_options_module.BaseOptions
_Rect = rect_module.Rect
_Landmark = landmark_module.Landmark
_NormalizedLandmark = landmark_module.NormalizedLandmark
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_Image = image_module.Image
_HandLandmarker = hand_landmarker.HandLandmarker
_HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
_HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_HAND_LANDMARKER_BUNDLE_ASSET_FILE = 'hand_landmarker.task'
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
_TWO_HANDS_IMAGE = 'right_hands.jpg'
_THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_POINTING_UP_IMAGE = 'pointing_up.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt'
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_LANDMARKS_ERROR_TOLERANCE = 0.03
_HANDEDNESS_MARGIN = 0.05
def _get_expected_hand_landmarker_result(
file_path: str) -> _HandLandmarkerResult:
landmarks_detection_result_file_path = test_utils.get_test_data_path(
file_path)
with open(landmarks_detection_result_file_path, 'rb') as f:
landmarks_detection_result_proto = _LandmarksDetectionResultProto()
# Use this if a .pb file is available.
# landmarks_detection_result_proto.ParseFromString(f.read())
text_format.Parse(f.read(), landmarks_detection_result_proto)
landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
landmarks_detection_result_proto)
return _HandLandmarkerResult(
handedness=[landmarks_detection_result.categories],
hand_landmarks=[landmarks_detection_result.landmarks],
hand_world_landmarks=[landmarks_detection_result.world_landmarks])
class ModelFileType(enum.Enum):
FILE_CONTENT = 1
FILE_NAME = 2
class HandLandmarkerTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.test_image = _Image.create_from_file(
test_utils.get_test_data_path(_THUMB_UP_IMAGE))
self.model_path = test_utils.get_test_data_path(
_HAND_LANDMARKER_BUNDLE_ASSET_FILE)
def _assert_actual_result_approximately_matches_expected_result(
self, actual_result: _HandLandmarkerResult,
expected_result: _HandLandmarkerResult):
# Expects to have the same number of hands detected.
self.assertLen(actual_result.hand_landmarks,
len(expected_result.hand_landmarks))
self.assertLen(actual_result.hand_world_landmarks,
len(expected_result.hand_world_landmarks))
self.assertLen(actual_result.handedness, len(expected_result.handedness))
# Actual landmarks match expected landmarks.
self.assertLen(actual_result.hand_landmarks[0],
len(expected_result.hand_landmarks[0]))
actual_landmarks = actual_result.hand_landmarks[0]
expected_landmarks = expected_result.hand_landmarks[0]
for i, rename_me in enumerate(actual_landmarks):
self.assertAlmostEqual(
rename_me.x,
expected_landmarks[i].x,
delta=_LANDMARKS_ERROR_TOLERANCE)
self.assertAlmostEqual(
rename_me.y,
expected_landmarks[i].y,
delta=_LANDMARKS_ERROR_TOLERANCE)
# Actual handedness matches expected handedness.
actual_top_handedness = actual_result.handedness[0][0]
expected_top_handedness = expected_result.handedness[0][0]
self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
self.assertEqual(actual_top_handedness.category_name,
expected_top_handedness.category_name)
self.assertAlmostEqual(
actual_top_handedness.score,
expected_top_handedness.score,
delta=_HANDEDNESS_MARGIN)
def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully.
with _HandLandmarker.create_from_model_path(self.model_path) as landmarker:
self.assertIsInstance(landmarker, _HandLandmarker)
def test_create_from_options_succeeds_with_valid_model_path(self):
# Creates with options containing model file successfully.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _HandLandmarkerOptions(base_options=base_options)
with _HandLandmarker.create_from_options(options) as landmarker:
self.assertIsInstance(landmarker, _HandLandmarker)
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite')
options = _HandLandmarkerOptions(base_options=base_options)
_HandLandmarker.create_from_options(options)
def test_create_from_options_succeeds_with_valid_model_content(self):
# Creates with options containing model content successfully.
with open(self.model_path, 'rb') as f:
base_options = _BaseOptions(model_asset_buffer=f.read())
options = _HandLandmarkerOptions(base_options=base_options)
landmarker = _HandLandmarker.create_from_options(options)
self.assertIsInstance(landmarker, _HandLandmarker)
@parameterized.parameters(
(ModelFileType.FILE_NAME,
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
(ModelFileType.FILE_CONTENT,
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
def test_detect(self, model_file_type, expected_detection_result):
# Creates hand landmarker.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _HandLandmarkerOptions(base_options=base_options)
landmarker = _HandLandmarker.create_from_options(options)
# Performs hand landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
detection_result, expected_detection_result)
# Closes the hand landmarker explicitly when the hand landmarker is not used
# in a context.
landmarker.close()
@parameterized.parameters(
(ModelFileType.FILE_NAME,
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
(ModelFileType.FILE_CONTENT,
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
def test_detect_in_context(self, model_file_type, expected_detection_result):
# Creates hand landmarker.
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_content = f.read()
base_options = _BaseOptions(model_asset_buffer=model_content)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _HandLandmarkerOptions(base_options=base_options)
with _HandLandmarker.create_from_options(options) as landmarker:
# Performs hand landmarks detection on the input.
detection_result = landmarker.detect(self.test_image)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
detection_result, expected_detection_result)
def test_detect_succeeds_with_num_hands(self):
# Creates hand landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _HandLandmarkerOptions(base_options=base_options, num_hands=2)
with _HandLandmarker.create_from_options(options) as landmarker:
# Load the two hands image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_TWO_HANDS_IMAGE))
# Performs hand landmarks detection on the input.
detection_result = landmarker.detect(test_image)
# Comparing results.
self.assertLen(detection_result.handedness, 2)
def test_detect_succeeds_with_rotation(self):
# Creates hand landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _HandLandmarkerOptions(base_options=base_options)
with _HandLandmarker.create_from_options(options) as landmarker:
# Load the pointing up rotated image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(_POINTING_UP_ROTATED_IMAGE))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
# Performs hand landmarks detection on the input.
detection_result = landmarker.detect(test_image, image_processing_options)
expected_detection_result = _get_expected_hand_landmarker_result(
_POINTING_UP_ROTATED_LANDMARKS)
# Comparing results.
self._assert_actual_result_approximately_matches_expected_result(
detection_result, expected_detection_result)
def test_detect_fails_with_region_of_interest(self):
# Creates hand landmarker.
base_options = _BaseOptions(model_asset_path=self.model_path)
options = _HandLandmarkerOptions(base_options=base_options)
with self.assertRaisesRegex(
ValueError, "This task doesn't support region-of-interest."):
with _HandLandmarker.create_from_options(options) as landmarker:
# Set the `region_of_interest` parameter using `ImageProcessingOptions`.
image_processing_options = _ImageProcessingOptions(
region_of_interest=_Rect(0, 0, 1, 1))
# Attempt to perform hand landmarks detection on the cropped input.
landmarker.detect(self.test_image, image_processing_options)
def test_empty_detection_outputs(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path))
with _HandLandmarker.create_from_options(options) as landmarker:
# Load the image with no hands.
no_hands_test_image = _Image.create_from_file(
test_utils.get_test_data_path(_NO_HANDS_IMAGE))
# Performs hand landmarks detection on the input.
detection_result = landmarker.detect(no_hands_test_image)
self.assertEmpty(detection_result.hand_landmarks)
self.assertEmpty(detection_result.hand_world_landmarks)
self.assertEmpty(detection_result.handedness)
def test_missing_result_callback(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM)
with self.assertRaisesRegex(ValueError,
r'result callback must be provided'):
with _HandLandmarker.create_from_options(options) as unused_landmarker:
pass
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
def test_illegal_result_callback(self, running_mode):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode,
result_callback=mock.MagicMock())
with self.assertRaisesRegex(ValueError,
r'result callback should not be provided'):
with _HandLandmarker.create_from_options(options) as unused_landmarker:
pass
def test_calling_detect_for_video_in_image_mode(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _HandLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_calling_detect_async_in_image_mode(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE)
with _HandLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_calling_detect_in_video_mode(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _HandLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_async_in_video_mode(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _HandLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the live stream mode'):
landmarker.detect_async(self.test_image, 0)
def test_detect_for_video_with_out_of_order_timestamp(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _HandLandmarker.create_from_options(options) as landmarker:
unused_result = landmarker.detect_for_video(self.test_image, 1)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_for_video(self.test_image, 0)
@parameterized.parameters(
(_THUMB_UP_IMAGE, 0,
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
(_POINTING_UP_IMAGE, 0,
_get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
(_POINTING_UP_ROTATED_IMAGE, -90,
_get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
(_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
def test_detect_for_video(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation)
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO)
with _HandLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
result = landmarker.detect_for_video(test_image, timestamp,
image_processing_options)
if result.hand_landmarks and result.hand_world_landmarks and result.handedness:
self._assert_actual_result_approximately_matches_expected_result(
result, expected_result)
else:
self.assertEqual(result, expected_result)
def test_calling_detect_in_live_stream_mode(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _HandLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the image mode'):
landmarker.detect(self.test_image)
def test_calling_detect_for_video_in_live_stream_mode(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _HandLandmarker.create_from_options(options) as landmarker:
with self.assertRaisesRegex(ValueError,
r'not initialized with the video mode'):
landmarker.detect_for_video(self.test_image, 0)
def test_detect_async_calls_with_illegal_timestamp(self):
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock())
with _HandLandmarker.create_from_options(options) as landmarker:
landmarker.detect_async(self.test_image, 100)
with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'):
landmarker.detect_async(self.test_image, 0)
@parameterized.parameters(
(_THUMB_UP_IMAGE, 0,
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
(_POINTING_UP_IMAGE, 0,
_get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
(_POINTING_UP_ROTATED_IMAGE, -90,
_get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
(_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
def test_detect_async_calls(self, image_path, rotation, expected_result):
test_image = _Image.create_from_file(
test_utils.get_test_data_path(image_path))
# Set rotation parameters using ImageProcessingOptions.
image_processing_options = _ImageProcessingOptions(
rotation_degrees=rotation)
observed_timestamp_ms = -1
def check_result(result: _HandLandmarkerResult, output_image: _Image,
timestamp_ms: int):
if result.hand_landmarks and result.hand_world_landmarks and result.handedness:
self._assert_actual_result_approximately_matches_expected_result(
result, expected_result)
else:
self.assertEqual(result, expected_result)
self.assertTrue(
np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _HandLandmarkerOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result)
with _HandLandmarker.create_from_options(options) as landmarker:
for timestamp in range(0, 300, 30):
landmarker.detect_async(test_image, timestamp, image_processing_options)
if __name__ == '__main__':
absltest.main()

View File

@ -79,6 +79,28 @@ py_library(
], ],
) )
py_library(
name = "image_embedder",
srcs = [
"image_embedder.py",
],
deps = [
"//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
"//mediapipe/tasks/python/components/containers:embedding_result",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info",
"//mediapipe/tasks/python/vision/core:base_vision_task_api",
"//mediapipe/tasks/python/vision/core:image_processing_options",
"//mediapipe/tasks/python/vision/core:vision_task_running_mode",
],
)
py_library( py_library(
name = "gesture_recognizer", name = "gesture_recognizer",
srcs = [ srcs = [
@ -104,18 +126,19 @@ py_library(
) )
py_library( py_library(
name = "image_embedder", name = "hand_landmarker",
srcs = [ srcs = [
"image_embedder.py", "hand_landmarker.py",
], ],
deps = [ deps = [
"//mediapipe/framework/formats:classification_py_pb2",
"//mediapipe/framework/formats:landmark_py_pb2",
"//mediapipe/python:_framework_bindings", "//mediapipe/python:_framework_bindings",
"//mediapipe/python:packet_creator", "//mediapipe/python:packet_creator",
"//mediapipe/python:packet_getter", "//mediapipe/python:packet_getter",
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2", "//mediapipe/tasks/python/components/containers:category",
"//mediapipe/tasks/python/components/containers:embedding_result", "//mediapipe/tasks/python/components/containers:landmark",
"//mediapipe/tasks/python/components/processors:classifier_options",
"//mediapipe/tasks/python/core:base_options", "//mediapipe/tasks/python/core:base_options",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
"//mediapipe/tasks/python/core:task_info", "//mediapipe/tasks/python/core:task_info",

View File

@ -16,12 +16,17 @@
import mediapipe.tasks.python.vision.core import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.gesture_recognizer import mediapipe.tasks.python.vision.gesture_recognizer
import mediapipe.tasks.python.vision.hand_landmarker
import mediapipe.tasks.python.vision.image_classifier import mediapipe.tasks.python.vision.image_classifier
import mediapipe.tasks.python.vision.image_segmenter import mediapipe.tasks.python.vision.image_segmenter
import mediapipe.tasks.python.vision.object_detector import mediapipe.tasks.python.vision.object_detector
GestureRecognizer = gesture_recognizer.GestureRecognizer GestureRecognizer = gesture_recognizer.GestureRecognizer
GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
HandLandmarker = hand_landmarker.HandLandmarker
HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
ImageClassifier = image_classifier.ImageClassifier ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions ImageClassifierOptions = image_classifier.ImageClassifierOptions
ImageSegmenter = image_segmenter.ImageSegmenter ImageSegmenter = image_segmenter.ImageSegmenter
@ -33,6 +38,7 @@ RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
# Remove unnecessary modules to avoid duplication in API docs. # Remove unnecessary modules to avoid duplication in API docs.
del core del core
del gesture_recognizer del gesture_recognizer
del hand_landmarker
del image_classifier del image_classifier
del image_segmenter del image_segmenter
del object_detector del object_detector

View File

@ -59,7 +59,7 @@ _GESTURE_DEFAULT_INDEX = -1
@dataclasses.dataclass @dataclasses.dataclass
class GestureRecognitionResult: class GestureRecognizerResult:
"""The gesture recognition result from GestureRecognizer, where each vector element represents a single hand detected in the image. """The gesture recognition result from GestureRecognizer, where each vector element represents a single hand detected in the image.
Attributes: Attributes:
@ -79,8 +79,8 @@ class GestureRecognitionResult:
def _build_recognition_result( def _build_recognition_result(
output_packets: Mapping[str, output_packets: Mapping[str,
packet_module.Packet]) -> GestureRecognitionResult: packet_module.Packet]) -> GestureRecognizerResult:
"""Consturcts a `GestureRecognitionResult` from output packets.""" """Consturcts a `GestureRecognizerResult` from output packets."""
gestures_proto_list = packet_getter.get_proto_list( gestures_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_GESTURE_STREAM_NAME]) output_packets[_HAND_GESTURE_STREAM_NAME])
handedness_proto_list = packet_getter.get_proto_list( handedness_proto_list = packet_getter.get_proto_list(
@ -122,23 +122,25 @@ def _build_recognition_result(
for proto in hand_landmarks_proto_list: for proto in hand_landmarks_proto_list:
hand_landmarks = landmark_pb2.NormalizedLandmarkList() hand_landmarks = landmark_pb2.NormalizedLandmarkList()
hand_landmarks.MergeFrom(proto) hand_landmarks.MergeFrom(proto)
hand_landmarks_results.append([ hand_landmarks_list = []
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark) for hand_landmark in hand_landmarks.landmark:
for hand_landmark in hand_landmarks.landmark hand_landmarks_list.append(
]) landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
hand_landmarks_results.append(hand_landmarks_list)
hand_world_landmarks_results = [] hand_world_landmarks_results = []
for proto in hand_world_landmarks_proto_list: for proto in hand_world_landmarks_proto_list:
hand_world_landmarks = landmark_pb2.LandmarkList() hand_world_landmarks = landmark_pb2.LandmarkList()
hand_world_landmarks.MergeFrom(proto) hand_world_landmarks.MergeFrom(proto)
hand_world_landmarks_results.append([ hand_world_landmarks_list = []
landmark_module.Landmark.create_from_pb2(hand_world_landmark) for hand_world_landmark in hand_world_landmarks.landmark:
for hand_world_landmark in hand_world_landmarks.landmark hand_world_landmarks_list.append(
]) landmark_module.Landmark.create_from_pb2(hand_world_landmark))
hand_world_landmarks_results.append(hand_world_landmarks_list)
return GestureRecognitionResult(gesture_results, handedness_results, return GestureRecognizerResult(gesture_results, handedness_results,
hand_landmarks_results, hand_landmarks_results,
hand_world_landmarks_results) hand_world_landmarks_results)
@dataclasses.dataclass @dataclasses.dataclass
@ -183,7 +185,7 @@ class GestureRecognizerOptions:
custom_gesture_classifier_options: Optional[ custom_gesture_classifier_options: Optional[
_ClassifierOptions] = _ClassifierOptions() _ClassifierOptions] = _ClassifierOptions()
result_callback: Optional[Callable[ result_callback: Optional[Callable[
[GestureRecognitionResult, image_module.Image, int], None]] = None [GestureRecognizerResult, image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _GestureRecognizerGraphOptionsProto: def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
@ -264,7 +266,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME] empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
options.result_callback( options.result_callback(
GestureRecognitionResult([], [], [], []), image, GestureRecognizerResult([], [], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND) empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
return return
@ -299,7 +301,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
self, self,
image: image_module.Image, image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> GestureRecognitionResult: ) -> GestureRecognizerResult:
"""Performs hand gesture recognition on the given image. """Performs hand gesture recognition on the given image.
Only use this method when the GestureRecognizer is created with the image Only use this method when the GestureRecognizer is created with the image
@ -330,7 +332,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
}) })
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
return GestureRecognitionResult([], [], [], []) return GestureRecognizerResult([], [], [], [])
return _build_recognition_result(output_packets) return _build_recognition_result(output_packets)
@ -339,7 +341,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
image: image_module.Image, image: image_module.Image,
timestamp_ms: int, timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None image_processing_options: Optional[_ImageProcessingOptions] = None
) -> GestureRecognitionResult: ) -> GestureRecognizerResult:
"""Performs gesture recognition on the provided video frame. """Performs gesture recognition on the provided video frame.
Only use this method when the GestureRecognizer is created with the video Only use this method when the GestureRecognizer is created with the video
@ -374,7 +376,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
}) })
if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty(): if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
return GestureRecognitionResult([], [], [], []) return GestureRecognizerResult([], [], [], [])
return _build_recognition_result(output_packets) return _build_recognition_result(output_packets)

View File

@ -0,0 +1,379 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe hand landmarker task."""
import dataclasses
from typing import Callable, Mapping, Optional, List
from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module
_BaseOptions = base_options_module.BaseOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo
_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_HANDEDNESS_STREAM_NAME = 'handedness'
_HANDEDNESS_TAG = 'HANDEDNESS'
_HAND_LANDMARKS_STREAM_NAME = 'landmarks'
_HAND_LANDMARKS_TAG = 'LANDMARKS'
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class HandLandmarkerResult:
"""The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image.
Attributes:
handedness: Classification of handedness.
hand_landmarks: Detected hand landmarks in normalized image coordinates.
hand_world_landmarks: Detected hand landmarks in world coordinates.
"""
handedness: List[List[category_module.Category]]
hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
hand_world_landmarks: List[List[landmark_module.Landmark]]
def _build_landmarker_result(
output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult:
"""Constructs a `HandLandmarksDetectionResult` from output packets."""
handedness_proto_list = packet_getter.get_proto_list(
output_packets[_HANDEDNESS_STREAM_NAME])
hand_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_LANDMARKS_STREAM_NAME])
hand_world_landmarks_proto_list = packet_getter.get_proto_list(
output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])
handedness_results = []
for proto in handedness_proto_list:
handedness_categories = []
handedness_classifications = classification_pb2.ClassificationList()
handedness_classifications.MergeFrom(proto)
for handedness in handedness_classifications.classification:
handedness_categories.append(
category_module.Category(
index=handedness.index,
score=handedness.score,
display_name=handedness.display_name,
category_name=handedness.label))
handedness_results.append(handedness_categories)
hand_landmarks_results = []
for proto in hand_landmarks_proto_list:
hand_landmarks = landmark_pb2.NormalizedLandmarkList()
hand_landmarks.MergeFrom(proto)
hand_landmarks_list = []
for hand_landmark in hand_landmarks.landmark:
hand_landmarks_list.append(
landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
hand_landmarks_results.append(hand_landmarks_list)
hand_world_landmarks_results = []
for proto in hand_world_landmarks_proto_list:
hand_world_landmarks = landmark_pb2.LandmarkList()
hand_world_landmarks.MergeFrom(proto)
hand_world_landmarks_list = []
for hand_world_landmark in hand_world_landmarks.landmark:
hand_world_landmarks_list.append(
landmark_module.Landmark.create_from_pb2(hand_world_landmark))
hand_world_landmarks_results.append(hand_world_landmarks_list)
return HandLandmarkerResult(handedness_results, hand_landmarks_results,
hand_world_landmarks_results)
@dataclasses.dataclass
class HandLandmarkerOptions:
"""Options for the hand landmarker task.
Attributes:
base_options: Base options for the hand landmarker task.
running_mode: The running mode of the task. Default to the image mode.
HandLandmarker has three running modes: 1) The image mode for detecting
hand landmarks on single image inputs. 2) The video mode for detecting
hand landmarks on the decoded frames of a video. 3) The live stream mode
for detecting hand landmarks on the live stream of input data, such as
from camera. In this mode, the "result_callback" below must be specified
to receive the detection results asynchronously.
num_hands: The maximum number of hands can be detected by the hand
landmarker.
min_hand_detection_confidence: The minimum confidence score for the hand
detection to be considered successful.
min_hand_presence_confidence: The minimum confidence score of hand presence
score in the hand landmark detection.
min_tracking_confidence: The minimum confidence score for the hand tracking
to be considered successful.
result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode
is set to the live stream mode.
"""
base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE
num_hands: Optional[int] = 1
min_hand_detection_confidence: Optional[float] = 0.5
min_hand_presence_confidence: Optional[float] = 0.5
min_tracking_confidence: Optional[float] = 0.5
result_callback: Optional[Callable[
[HandLandmarkerResult, image_module.Image, int], None]] = None
@doc_controls.do_not_generate_docs
def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
"""Generates an HandLandmarkerGraphOptions protobuf object."""
base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False if self.running_mode == _RunningMode.IMAGE else True
# Initialize the hand landmarker options from base options.
hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
base_options=base_options_proto)
hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
return hand_landmarker_options_proto
class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
"""Class that performs hand landmarks detection on images."""
@classmethod
def create_from_model_path(cls, model_path: str) -> 'HandLandmarker':
"""Creates an `HandLandmarker` object from a TensorFlow Lite model and the default `HandLandmarkerOptions`.
Note that the created `HandLandmarker` instance is in image mode, for
detecting hand landmarks on single image inputs.
Args:
model_path: Path to the model.
Returns:
`HandLandmarker` object that's created from the model file and the
default `HandLandmarkerOptions`.
Raises:
ValueError: If failed to create `HandLandmarker` object from the
provided file such as invalid file path.
RuntimeError: If other types of error occurred.
"""
base_options = _BaseOptions(model_asset_path=model_path)
options = HandLandmarkerOptions(
base_options=base_options, running_mode=_RunningMode.IMAGE)
return cls.create_from_options(options)
@classmethod
def create_from_options(cls,
options: HandLandmarkerOptions) -> 'HandLandmarker':
"""Creates the `HandLandmarker` object from hand landmarker options.
Args:
options: Options for the hand landmarker task.
Returns:
`HandLandmarker` object that's created from `options`.
Raises:
ValueError: If failed to create `HandLandmarker` object from
`HandLandmarkerOptions` such as missing the model.
RuntimeError: If other types of error occurred.
"""
def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
options.result_callback(
HandLandmarkerResult([], [], []), image,
empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
return
hand_landmarks_detection_result = _build_landmarker_result(output_packets)
timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
options.result_callback(hand_landmarks_detection_result, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME,
input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
],
output_streams=[
':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
':'.join([
_HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME
]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
],
task_options=options)
return cls(
task_info.generate_graph_config(
enable_flow_limiting=options.running_mode ==
_RunningMode.LIVE_STREAM), options.running_mode,
packets_callback if options.result_callback else None)
def detect(
self,
image: image_module.Image,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the given image.
Only use this method when the HandLandmarker is created with the image
running mode.
The image can be of any size with format RGB or RGBA.
TODO: Describes how the input image will be preprocessed after the yuv
support is implemented.
Args:
image: MediaPipe Image.
image_processing_options: Options for image processing.
Returns:
The hand landmarks detection results.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If hand landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
output_packets = self._process_image_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2())
})
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
return HandLandmarkerResult([], [], [])
return _build_landmarker_result(output_packets)
def detect_for_video(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> HandLandmarkerResult:
"""Performs hand landmarks detection on the provided video frame.
Only use this method when the HandLandmarker is created with the video
running mode.
Only use this method when the HandLandmarker is created with the video
running mode. It's required to provide the video frame's timestamp (in
milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing.
Returns:
The hand landmarks detection results.
Raises:
ValueError: If any of the input arguments is invalid.
RuntimeError: If hand landmarker detection failed to run.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
output_packets = self._process_video_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})
if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
return HandLandmarkerResult([], [], [])
return _build_landmarker_result(output_packets)
def detect_async(
self,
image: image_module.Image,
timestamp_ms: int,
image_processing_options: Optional[_ImageProcessingOptions] = None
) -> None:
"""Sends live image data to perform hand landmarks detection.
The results will be available via the "result_callback" provided in the
HandLandmarkerOptions. Only use this method when the HandLandmarker is
created with the live stream running mode.
Only use this method when the HandLandmarker is created with the live
stream running mode. The input timestamps should be monotonically increasing
for adjacent calls of this method. This method will return immediately after
the input image is accepted. The results will be available via the
`result_callback` provided in the `HandLandmarkerOptions`. The
`detect_async` method is designed to process live stream data such as
camera input. To lower the overall latency, hand landmarker may drop the
input images if needed. In other words, it's not guaranteed to have output
per input image.
The `result_callback` provides:
- The hand landmarks detection results.
- The input image that the hand landmarker runs on.
- The input timestamp in milliseconds.
Args:
image: MediaPipe Image.
timestamp_ms: The timestamp of the input image in milliseconds.
image_processing_options: Options for image processing.
Raises:
ValueError: If the current input timestamp is smaller than what the
hand landmarker has already processed.
"""
normalized_rect = self.convert_to_normalized_rect(
image_processing_options, roi_allowed=False)
self._send_live_stream_data({
_IMAGE_IN_STREAM_NAME:
packet_creator.create_image(image).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
_NORM_RECT_STREAM_NAME:
packet_creator.create_proto(normalized_rect.to_pb2()).at(
timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
})

View File

@ -23,10 +23,13 @@ package(
) )
mediapipe_files(srcs = [ mediapipe_files(srcs = [
"30k-clean.model",
"bert_text_classifier_no_metadata.tflite",
"mobile_ica_8bit-with-metadata.tflite", "mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite", "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite", "mobile_ica_8bit-without-model-metadata.tflite",
"mobile_object_classifier_v0_2_3-metadata-no-name.tflite", "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
"mobilebert_vocab.txt",
"mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite",
"mobilenet_v2_1.0_224_quant.tflite", "mobilenet_v2_1.0_224_quant.tflite",
"mobilenet_v2_1.0_224_quant_without_metadata.tflite", "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
@ -60,11 +63,17 @@ exports_files([
"movie_review_labels.txt", "movie_review_labels.txt",
"regex_vocab.txt", "regex_vocab.txt",
"movie_review.json", "movie_review.json",
"bert_tokenizer_meta.json",
"bert_text_classifier_with_sentence_piece.json",
"sentence_piece_tokenizer_meta.json",
"bert_text_classifier_with_bert_tokenizer.json",
]) ])
filegroup( filegroup(
name = "model_files", name = "model_files",
srcs = [ srcs = [
"30k-clean.model",
"bert_text_classifier_no_metadata.tflite",
"mobile_ica_8bit-with-metadata.tflite", "mobile_ica_8bit-with-metadata.tflite",
"mobile_ica_8bit-with-unsupported-metadata-version.tflite", "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
"mobile_ica_8bit-without-model-metadata.tflite", "mobile_ica_8bit-without-model-metadata.tflite",
@ -81,6 +90,9 @@ filegroup(
name = "data_files", name = "data_files",
srcs = [ srcs = [
"associated_file_meta.json", "associated_file_meta.json",
"bert_text_classifier_with_bert_tokenizer.json",
"bert_text_classifier_with_sentence_piece.json",
"bert_tokenizer_meta.json",
"bounding_box_tensor_meta.json", "bounding_box_tensor_meta.json",
"classification_tensor_float_meta.json", "classification_tensor_float_meta.json",
"classification_tensor_uint8_meta.json", "classification_tensor_uint8_meta.json",
@ -96,6 +108,7 @@ filegroup(
"input_text_tensor_default_meta.json", "input_text_tensor_default_meta.json",
"input_text_tensor_meta.json", "input_text_tensor_meta.json",
"labels.txt", "labels.txt",
"mobilebert_vocab.txt",
"mobilenet_v2_1.0_224.json", "mobilenet_v2_1.0_224.json",
"mobilenet_v2_1.0_224_quant.json", "mobilenet_v2_1.0_224_quant.json",
"movie_review.json", "movie_review.json",
@ -105,5 +118,6 @@ filegroup(
"score_calibration_file_meta.json", "score_calibration_file_meta.json",
"score_calibration_tensor_meta.json", "score_calibration_tensor_meta.json",
"score_thresholding_meta.json", "score_thresholding_meta.json",
"sentence_piece_tokenizer_meta.json",
], ],
) )

View File

@ -0,0 +1,84 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "mobilebert_vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}

View File

@ -0,0 +1,83 @@
{
"name": "TextClassifier",
"description": "Classify the input text into a set of known categories.",
"subgraph_metadata": [
{
"input_tensor_metadata": [
{
"name": "ids",
"description": "Tokenized ids of the input text.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "segment_ids",
"description": "0 for the first sequence, 1 for the second sequence if exists.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
},
{
"name": "mask",
"description": "Mask with 1 for real tokens and 0 for padding tokens.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
}
}
],
"output_tensor_metadata": [
{
"name": "score",
"description": "Score of the labels respectively.",
"content": {
"content_properties_type": "FeatureProperties",
"content_properties": {
}
},
"stats": {
"max": [
1.0
],
"min": [
0.0
]
},
"associated_files": [
{
"name": "labels.txt",
"description": "Labels for categories that the model can recognize.",
"type": "TENSOR_AXIS_LABELS"
}
]
}
],
"input_process_units": [
{
"options_type": "SentencePieceTokenizerOptions",
"options": {
"sentencePiece_model": [
{
"name": "30k-clean.model",
"description": "The sentence piece model file."
}
]
}
}
]
}
],
"min_parser_version": "1.1.0"
}

View File

@ -0,0 +1,20 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "BertTokenizerOptions",
"options": {
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors.",
"type": "VOCABULARY"
}
]
}
}
]
}
]
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,26 @@
{
"subgraph_metadata": [
{
"input_process_units": [
{
"options_type": "SentencePieceTokenizerOptions",
"options": {
"sentencePiece_model": [
{
"name": "sp.model",
"description": "The sentence piece model file."
}
],
"vocab_file": [
{
"name": "vocab.txt",
"description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.",
"type": "VOCABULARY"
}
]
}
}
]
}
]
}

View File

@ -143,20 +143,6 @@ filegroup(
], ],
) )
# Gestures related models. Visible to model_maker.
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
filegroup(
name = "test_gesture_models",
srcs = [
"hand_landmark_full.tflite",
"palm_detection_full.tflite",
],
visibility = [
"//mediapipe/model_maker:__subpackages__",
"//mediapipe/tasks:internal",
],
)
filegroup( filegroup(
name = "test_protos", name = "test_protos",
srcs = [ srcs = [

View File

@ -3,9 +3,22 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm") load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
load("@npm//@bazel/rollup:index.bzl", "rollup_bundle") load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
package(default_visibility = ["//mediapipe/tasks:internal"]) package(default_visibility = ["//mediapipe/tasks:internal"])
mediapipe_files(srcs = [
"wasm/audio_wasm_internal.js",
"wasm/audio_wasm_internal.wasm",
"wasm/text_wasm_internal.js",
"wasm/text_wasm_internal.wasm",
"wasm/vision_wasm_internal.js",
"wasm/vision_wasm_internal.wasm",
])
# Audio # Audio
mediapipe_ts_library( mediapipe_ts_library(
@ -28,15 +41,18 @@ rollup_bundle(
pkg_npm( pkg_npm(
name = "audio_pkg", name = "audio_pkg",
package_name = "__PACKAGE_NAME__", package_name = "@mediapipe/tasks-__NAME__",
srcs = ["package.json"], srcs = ["package.json"],
substitutions = { substitutions = {
"__PACKAGE_NAME__": "@mediapipe/tasks-audio", "__NAME__": "audio",
"__DESCRIPTION__": "MediaPipe Audio Tasks", "__DESCRIPTION__": "MediaPipe Audio Tasks",
"__BUNDLE__": "audio_bundle.js",
}, },
tgz = "audio.tgz", tgz = "audio.tgz",
deps = [":audio_bundle"], deps = [
"wasm/audio_wasm_internal.js",
"wasm/audio_wasm_internal.wasm",
":audio_bundle",
],
) )
# Text # Text
@ -61,15 +77,18 @@ rollup_bundle(
pkg_npm( pkg_npm(
name = "text_pkg", name = "text_pkg",
package_name = "__PACKAGE_NAME__", package_name = "@mediapipe/tasks-__NAME__",
srcs = ["package.json"], srcs = ["package.json"],
substitutions = { substitutions = {
"__PACKAGE_NAME__": "@mediapipe/tasks-text", "__NAME__": "text",
"__DESCRIPTION__": "MediaPipe Text Tasks", "__DESCRIPTION__": "MediaPipe Text Tasks",
"__BUNDLE__": "text_bundle.js",
}, },
tgz = "text.tgz", tgz = "text.tgz",
deps = [":text_bundle"], deps = [
"wasm/text_wasm_internal.js",
"wasm/text_wasm_internal.wasm",
":text_bundle",
],
) )
# Vision # Vision
@ -94,13 +113,16 @@ rollup_bundle(
pkg_npm( pkg_npm(
name = "vision_pkg", name = "vision_pkg",
package_name = "__PACKAGE_NAME__", package_name = "@mediapipe/tasks-__NAME__",
srcs = ["package.json"], srcs = ["package.json"],
substitutions = { substitutions = {
"__PACKAGE_NAME__": "@mediapipe/tasks-vision", "__NAME__": "vision",
"__DESCRIPTION__": "MediaPipe Vision Tasks", "__DESCRIPTION__": "MediaPipe Vision Tasks",
"__BUNDLE__": "vision_bundle.js",
}, },
tgz = "vision.tgz", tgz = "vision_pkg.tgz",
deps = [":vision_bundle"], deps = [
"wasm/vision_wasm_internal.js",
"wasm/vision_wasm_internal.wasm",
":vision_bundle",
],
) )

View File

@ -21,7 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {AudioClassifierOptions} from './audio_classifier_options'; import {AudioClassifierOptions} from './audio_classifier_options';
import {Classifications} from './audio_classifier_result'; import {AudioClassifierResult} from './audio_classifier_result';
const MEDIAPIPE_GRAPH = const MEDIAPIPE_GRAPH =
'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph'; 'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
// implementation // implementation
const AUDIO_STREAM = 'input_audio'; const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate'; const SAMPLE_RATE_STREAM = 'sample_rate';
const CLASSIFICATION_RESULT_STREAM = 'classification_result'; const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
// The OSS JS API does not support the builder pattern. // The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/** Performs audio classification. */ /** Performs audio classification. */
export class AudioClassifier extends TaskRunner { export class AudioClassifier extends TaskRunner {
private classifications: Classifications[] = []; private classificationResults: AudioClassifierResult[] = [];
private defaultSampleRate = 48000; private defaultSampleRate = 48000;
private readonly options = new AudioClassifierGraphOptions(); private readonly options = new AudioClassifierGraphOptions();
@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
* `48000` if no custom default was set. * `48000` if no custom default was set.
* @return The classification result of the audio datas * @return The classification result of the audio datas
*/ */
classify(audioData: Float32Array, sampleRate?: number): Classifications[] { classify(audioData: Float32Array, sampleRate?: number):
AudioClassifierResult[] {
sampleRate = sampleRate ?? this.defaultSampleRate; sampleRate = sampleRate ?? this.defaultSampleRate;
// Configures the number of samples in the WASM layer. We re-configure the // Configures the number of samples in the WASM layer. We re-configure the
@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp); this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
this.addAudioToStream(audioData, timestamp); this.addAudioToStream(audioData, timestamp);
this.classifications = []; this.classificationResults = [];
this.finishProcessing(); this.finishProcessing();
return [...this.classifications]; return [...this.classificationResults];
} }
/** /**
* Internal function for converting raw data into a classification, and * Internal function for converting raw data into classification results, and
* adding it to our classfications list. * adding them to our classfication results list.
**/ **/
private addJsAudioClassification(binaryProto: Uint8Array): void { private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
const classificationResult = binaryProtos.forEach(binaryProto => {
ClassificationResult.deserializeBinary(binaryProto); const classificationResult =
this.classifications.push( ClassificationResult.deserializeBinary(binaryProto);
...convertFromClassificationResultProto(classificationResult)); this.classificationResults.push(
convertFromClassificationResultProto(classificationResult));
});
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(AUDIO_STREAM); graphConfig.addInputStream(AUDIO_STREAM);
graphConfig.addInputStream(SAMPLE_RATE_STREAM); graphConfig.addInputStream(SAMPLE_RATE_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -198,14 +201,15 @@ export class AudioClassifier extends TaskRunner {
classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM); classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM); classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream(
'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM); 'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { this.attachProtoVectorListener(
this.addJsAudioClassification(binaryProto); TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
}); this.addJsAudioClassificationResults(binaryProtos);
});
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -15,4 +15,4 @@
*/ */
export {Category} from '../../../../tasks/web/components/containers/category'; export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -10,8 +10,8 @@ mediapipe_ts_library(
) )
mediapipe_ts_library( mediapipe_ts_library(
name = "classifications", name = "classification_result",
srcs = ["classifications.d.ts"], srcs = ["classification_result.d.ts"],
deps = [":category"], deps = [":category"],
) )

View File

@ -16,27 +16,14 @@
import {Category} from '../../../../tasks/web/components/containers/category'; import {Category} from '../../../../tasks/web/components/containers/category';
/** List of predicted categories with an optional timestamp. */ /** Classification results for a given classifier head. */
export interface ClassificationEntry { export interface Classifications {
/** /**
* The array of predicted categories, usually sorted by descending scores, * The array of predicted categories, usually sorted by descending scores,
* e.g., from high to low probability. * e.g., from high to low probability.
*/ */
categories: Category[]; categories: Category[];
/**
* The optional timestamp (in milliseconds) associated to the classification
* entry. This is useful for time series use cases, e.g., audio
* classification.
*/
timestampMs?: number;
}
/** Classifications for a given classifier head. */
export interface Classifications {
/** A list of classification entries. */
entries: ClassificationEntry[];
/** /**
* The index of the classifier head these categories refer to. This is * The index of the classifier head these categories refer to. This is
* useful for multi-head models. * useful for multi-head models.
@ -45,7 +32,24 @@ export interface Classifications {
/** /**
* The name of the classifier head, which is the corresponding tensor * The name of the classifier head, which is the corresponding tensor
* metadata name. * metadata name. Defaults to an empty string if there is no such metadata.
*/ */
headName: string; headName: string;
} }
/** Classification results of a model. */
export interface ClassificationResult {
/** The classification results for each head of the model. */
classifications: Classifications[];
/**
* The optional timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results.
*
* This is only used for classification on time series (e.g. audio
* classification). In these use cases, the amount of data to process might
* exceed the maximum size that the model can process: to solve this, the
* input data is split into multiple chunks starting at different timestamps.
*/
timestampMs?: number;
}

View File

@ -17,8 +17,9 @@ mediapipe_ts_library(
name = "classifier_result", name = "classifier_result",
srcs = ["classifier_result.ts"], srcs = ["classifier_result.ts"],
deps = [ deps = [
"//mediapipe/framework/formats:classification_jspb_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
], ],
) )

View File

@ -14,48 +14,46 @@
* limitations under the License. * limitations under the License.
*/ */
import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb'; import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
const DEFAULT_INDEX = -1; const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0; const DEFAULT_SCORE = 0.0;
/** /**
* Converts a ClassificationEntry proto to the ClassificationEntry result * Converts a Classifications proto to a Classifications object.
* type.
*/ */
function convertFromClassificationEntryProto(source: ClassificationEntryProto): function convertFromClassificationsProto(source: ClassificationsProto):
ClassificationEntry { Classifications {
const categories = source.getCategoriesList().map(category => { const categories =
return { source.getClassificationList()?.getClassificationList().map(
index: category.getIndex() ?? DEFAULT_INDEX, classification => {
score: category.getScore() ?? DEFAULT_SCORE, return {
displayName: category.getDisplayName() ?? '', index: classification.getIndex() ?? DEFAULT_INDEX,
categoryName: category.getCategoryName() ?? '', score: classification.getScore() ?? DEFAULT_SCORE,
}; categoryName: classification.getLabel() ?? '',
}); displayName: classification.getDisplayName() ?? '',
};
}) ??
[];
return { return {
categories, categories,
timestampMs: source.getTimestampMs(), headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
headName: source.getHeadName() ?? '',
}; };
} }
/** /**
* Converts a ClassificationResult proto to a list of classifications. * Converts a ClassificationResult proto to a ClassificationResult object.
*/ */
export function convertFromClassificationResultProto( export function convertFromClassificationResultProto(
classificationResult: ClassificationResult) : Classifications[] { source: ClassificationResultProto): ClassificationResult {
const result: Classifications[] = []; const result: ClassificationResult = {
for (const classificationsProto of classifications: source.getClassificationsList().map(
classificationResult.getClassificationsList()) { classififications => convertFromClassificationsProto(classififications))
const classifications: Classifications = { };
entries: classificationsProto.getEntriesList().map( if (source.hasTimestampMs()) {
entry => convertFromClassificationEntryProto(entry)), result.timestampMs = source.getTimestampMs();
headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
headName: classificationsProto.getHeadName() ?? '',
};
result.push(classifications);
} }
return result; return result;
} }

View File

@ -1,9 +1,14 @@
{ {
"name": "__PACKAGE_NAME__", "name": "@mediapipe/tasks-__NAME__",
"version": "__VERSION__", "version": "__VERSION__",
"description": "__DESCRIPTION__", "description": "__DESCRIPTION__",
"main": "__BUNDLE__", "main": "__NAME__bundle.js",
"module": "__BUNDLE__", "module": "__NAME__bundle.js",
"exports": {
".": "./__NAME__bundle.js",
"./loader": "./wasm/__NAME__wasm_internal.js",
"./wasm": "./wasm/__NAME__wasm_internal.wasm"
},
"author": "mediapipe@google.com", "author": "mediapipe@google.com",
"license": "Apache-2.0", "license": "Apache-2.0",
"type": "module", "type": "module",

View File

@ -22,7 +22,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {TextClassifierOptions} from './text_classifier_options'; import {TextClassifierOptions} from './text_classifier_options';
import {Classifications} from './text_classifier_result'; import {TextClassifierResult} from './text_classifier_result';
const INPUT_STREAM = 'text_in'; const INPUT_STREAM = 'text_in';
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out'; const CLASSIFICATIONS_STREAM = 'classifications_out';
const TEXT_CLASSIFIER_GRAPH = const TEXT_CLASSIFIER_GRAPH =
'mediapipe.tasks.text.text_classifier.TextClassifierGraph'; 'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =
/** Performs Natural Language classification. */ /** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner { export class TextClassifier extends TaskRunner {
private classifications: Classifications[] = []; private classificationResult: TextClassifierResult = {classifications: []};
private readonly options = new TextClassifierGraphOptions(); private readonly options = new TextClassifierGraphOptions();
/** /**
@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
* @param text The text to process. * @param text The text to process.
* @return The classification result of the text * @return The classification result of the text
*/ */
classify(text: string): Classifications[] { classify(text: string): TextClassifierResult {
// Get classification classes by running our MediaPipe graph. // Get classification result by running our MediaPipe graph.
this.classifications = []; this.classificationResult = {classifications: []};
this.addStringToStream( this.addStringToStream(
text, INPUT_STREAM, /* timestamp= */ performance.now()); text, INPUT_STREAM, /* timestamp= */ performance.now());
this.finishProcessing(); this.finishProcessing();
return [...this.classifications]; return this.classificationResult;
}
// Internal function for converting raw data into a classification, and
// adding it to our classifications list.
private addJsTextClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
console.log(classificationResult.toObject());
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
const classifierNode = new CalculatorGraphConfig.Node(); const classifierNode = new CalculatorGraphConfig.Node();
classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH); classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
classifierNode.addInputStream('TEXT:' + INPUT_STREAM); classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.addJsTextClassification(binaryProto); this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
}); });
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/ */
export {Category} from '../../../../tasks/web/components/containers/category'; export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -21,7 +21,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
"//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:category",
"//mediapipe/tasks/web/components/containers:classifications", "//mediapipe/tasks/web/components/containers:classification_result",
"//mediapipe/tasks/web/components/processors:base_options", "//mediapipe/tasks/web/components/processors:base_options",
"//mediapipe/tasks/web/components/processors:classifier_options", "//mediapipe/tasks/web/components/processors:classifier_options",
"//mediapipe/tasks/web/components/processors:classifier_result", "//mediapipe/tasks/web/components/processors:classifier_result",

View File

@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {ImageClassifierOptions} from './image_classifier_options'; import {ImageClassifierOptions} from './image_classifier_options';
import {Classifications} from './image_classifier_result'; import {ImageClassifierResult} from './image_classifier_result';
const IMAGE_CLASSIFIER_GRAPH = const IMAGE_CLASSIFIER_GRAPH =
'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph'; 'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
const INPUT_STREAM = 'input_image'; const INPUT_STREAM = 'input_image';
const CLASSIFICATION_RESULT_STREAM = 'classification_result'; const CLASSIFICATIONS_STREAM = 'classifications';
export {ImageSource}; // Used in the public API export {ImageSource}; // Used in the public API
@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API
/** Performs classification on images. */ /** Performs classification on images. */
export class ImageClassifier extends TaskRunner { export class ImageClassifier extends TaskRunner {
private classifications: Classifications[] = []; private classificationResult: ImageClassifierResult = {classifications: []};
private readonly options = new ImageClassifierGraphOptions(); private readonly options = new ImageClassifierGraphOptions();
/** /**
@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
* provided, defaults to `performance.now()`. * provided, defaults to `performance.now()`.
* @return The classification result of the image * @return The classification result of the image
*/ */
classify(imageSource: ImageSource, timestamp?: number): Classifications[] { classify(imageSource: ImageSource, timestamp?: number):
// Get classification classes by running our MediaPipe graph. ImageClassifierResult {
this.classifications = []; // Get classification result by running our MediaPipe graph.
this.classificationResult = {classifications: []};
this.addGpuBufferAsImageToStream( this.addGpuBufferAsImageToStream(
imageSource, INPUT_STREAM, timestamp ?? performance.now()); imageSource, INPUT_STREAM, timestamp ?? performance.now());
this.finishProcessing(); this.finishProcessing();
return [...this.classifications]; return this.classificationResult;
}
/**
* Internal function for converting raw data into a classification, and
* adding it to our classfications list.
**/
private addJsImageClassification(binaryProto: Uint8Array): void {
const classificationResult =
ClassificationResult.deserializeBinary(binaryProto);
this.classifications.push(
...convertFromClassificationResultProto(classificationResult));
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
private refreshGraph(): void { private refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(INPUT_STREAM); graphConfig.addInputStream(INPUT_STREAM);
graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM); graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
const classifierNode = new CalculatorGraphConfig.Node(); const classifierNode = new CalculatorGraphConfig.Node();
classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH); classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
classifierNode.addInputStream('IMAGE:' + INPUT_STREAM); classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
classifierNode.addOutputStream( classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
classifierNode.setOptions(calculatorOptions); classifierNode.setOptions(calculatorOptions);
graphConfig.addNode(classifierNode); graphConfig.addNode(classifierNode);
this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => { this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
this.addJsImageClassification(binaryProto); this.classificationResult = convertFromClassificationResultProto(
ClassificationResult.deserializeBinary(binaryProto));
}); });
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();

View File

@ -15,4 +15,4 @@
*/ */
export {Category} from '../../../../tasks/web/components/containers/category'; export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications'; export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

View File

@ -28,12 +28,36 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"], urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
) )
http_file(
name = "com_google_mediapipe_bert_text_classifier_no_metadata_tflite",
sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_no_metadata.tflite?generation=1667948360250899"],
)
http_file( http_file(
name = "com_google_mediapipe_bert_text_classifier_tflite", name = "com_google_mediapipe_bert_text_classifier_tflite",
sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600", sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"], urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
) )
http_file(
name = "com_google_mediapipe_bert_text_classifier_with_bert_tokenizer_json",
sha256 = "49f148a13a4e3b486b1d3c2400e46e5ebd0d375674c0154278b835760e873a95",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_bert_tokenizer.json?generation=1667948363241334"],
)
http_file(
name = "com_google_mediapipe_bert_text_classifier_with_sentence_piece_json",
sha256 = "113091f3892691de57e379387256b2ce0cc18a1b5185af866220a46da8221f26",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_sentence_piece.json?generation=1667948366009530"],
)
http_file(
name = "com_google_mediapipe_bert_tokenizer_meta_json",
sha256 = "116d70c7c3ef413a8bff54ab758f9ed3d6e51fdc5621d8c920ad2f0035831804",
urls = ["https://storage.googleapis.com/mediapipe-assets/bert_tokenizer_meta.json?generation=1667948368809108"],
)
http_file( http_file(
name = "com_google_mediapipe_bounding_box_tensor_meta_json", name = "com_google_mediapipe_bounding_box_tensor_meta_json",
sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a", sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",
@ -403,7 +427,7 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_labels_txt", name = "com_google_mediapipe_labels_txt",
sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9", sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667888034706429"], urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667892497527642"],
) )
http_file( http_file(
@ -553,13 +577,13 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_movie_review_json", name = "com_google_mediapipe_movie_review_json",
sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e", sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e",
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667888039053188"], urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667892501695336"],
) )
http_file( http_file(
name = "com_google_mediapipe_movie_review_labels_txt", name = "com_google_mediapipe_movie_review_labels_txt",
sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a", sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667888041670721"], urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667892504334882"],
) )
http_file( http_file(
@ -703,7 +727,7 @@ def external_files():
http_file( http_file(
name = "com_google_mediapipe_regex_vocab_txt", name = "com_google_mediapipe_regex_vocab_txt",
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923", sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667888047885461"], urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667892507770551"],
) )
http_file( http_file(
@ -790,6 +814,12 @@ def external_files():
urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"], urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
) )
http_file(
name = "com_google_mediapipe_sentence_piece_tokenizer_meta_json",
sha256 = "416bfe231710502e4a93e1b1950c0c6e5db49cffb256d241ef3d3f2d0d57718b",
urls = ["https://storage.googleapis.com/mediapipe-assets/sentence_piece_tokenizer_meta.json?generation=1667948375508564"],
)
http_file( http_file(
name = "com_google_mediapipe_speech_16000_hz_mono_wav", name = "com_google_mediapipe_speech_16000_hz_mono_wav",
sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6", sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6",

47
third_party/wasm_files.bzl vendored Normal file
View File

@ -0,0 +1,47 @@
"""
WASM dependencies for MediaPipe.
This file is auto-generated.
"""
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file")
# buildifier: disable=unnamed-macro
def wasm_files():
"""WASM dependencies for MediaPipe."""
http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"],
)
http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_js",
sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"],
)
http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"],
)
http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"],
)
http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"],
)
http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"],
)