Merge branch 'master' into image-embedder-python
This commit is contained in:
commit 0a6e21c212
@@ -546,3 +546,6 @@ rules_proto_toolchains()

load("//third_party:external_files.bzl", "external_files")

external_files()

load("//third_party:wasm_files.bzl", "wasm_files")

wasm_files()

@@ -200,3 +200,38 @@ cc_test(
        "@com_google_absl//absl/status",
    ],
)

cc_library(
    name = "embedding_aggregation_calculator",
    srcs = ["embedding_aggregation_calculator.cc"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:packet",
        "//mediapipe/framework/api2:port",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
        "@com_google_absl//absl/status",
    ],
    alwayslink = 1,
)

cc_test(
    name = "embedding_aggregation_calculator_test",
    srcs = ["embedding_aggregation_calculator_test.cc"],
    deps = [
        ":embedding_aggregation_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:output_stream_poller",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
    ],
)

@@ -0,0 +1,132 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <unordered_map>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;

// Aggregates EmbeddingResult packets into a vector of timestamped
// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
// needed.
//
// Inputs:
//   EMBEDDINGS: EmbeddingResult
//     The EmbeddingResult packets to aggregate.
//   TIMESTAMPS: std::vector<Timestamp> @Optional
//     The collection of timestamps that this calculator should aggregate. This
//     stream is optional: if provided, then the TIMESTAMPED_EMBEDDINGS output
//     will contain the aggregated results. Otherwise, as no timestamp
//     aggregation is required, the EMBEDDINGS output is used to pass the input
//     EmbeddingResults unchanged.
//
// Outputs:
//   EMBEDDINGS: EmbeddingResult @Optional
//     The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
//     input is not connected, as it signals that timestamp aggregation is not
//     required.
//   TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
//     The embedding results aggregated by timestamp. Must be connected if the
//     TIMESTAMPS input is connected, as it signals that timestamp aggregation
//     is required.
//
// Example without timestamp aggregation (pass-through):
// node {
//   calculator: "EmbeddingAggregationCalculator"
//   input_stream: "EMBEDDINGS:embeddings_in"
//   output_stream: "EMBEDDINGS:embeddings_out"
// }
//
// Example with timestamp aggregation:
// node {
//   calculator: "EmbeddingAggregationCalculator"
//   input_stream: "EMBEDDINGS:embeddings_in"
//   input_stream: "TIMESTAMPS:timestamps_in"
//   output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
// }
class EmbeddingAggregationCalculator : public Node {
 public:
  static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
  static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
      "TIMESTAMPS"};
  static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
      "EMBEDDINGS"};
  static constexpr Output<std::vector<EmbeddingResult>>::Optional
      kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
  MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
                          kTimestampedEmbeddingsOut);

  static absl::Status UpdateContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc);
  absl::Status Process(CalculatorContext* cc);

 private:
  bool time_aggregation_enabled_;
  std::unordered_map<int64, EmbeddingResult> cached_embeddings_;
};

absl::Status EmbeddingAggregationCalculator::UpdateContract(
    CalculatorContract* cc) {
  if (kTimestampsIn(cc).IsConnected()) {
    RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
  } else {
    RET_CHECK(kEmbeddingsOut(cc).IsConnected());
  }
  return absl::OkStatus();
}

absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
  time_aggregation_enabled_ = kTimestampsIn(cc).IsConnected();
  return absl::OkStatus();
}

absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
  if (time_aggregation_enabled_) {
    // Cache the incoming result keyed by its input timestamp.
    cached_embeddings_[cc->InputTimestamp().Value()] =
        std::move(*kEmbeddingsIn(cc));
    if (kTimestampsIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    // A TIMESTAMPS packet closes the aggregation window: flush the cached
    // results in the order the timestamps are listed.
    auto timestamps = kTimestampsIn(cc).Get();
    std::vector<EmbeddingResult> results;
    results.reserve(timestamps.size());
    for (const auto& timestamp : timestamps) {
      auto& result = cached_embeddings_[timestamp.Value()];
      result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) /
                              1000);
      results.push_back(std::move(result));
      cached_embeddings_.erase(timestamp.Value());
    }
    kTimestampedEmbeddingsOut(cc).Send(std::move(results));
  } else {
    kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc));
  }
  RET_CHECK(cached_embeddings_.empty());
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);

}  // namespace api2
}  // namespace mediapipe
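
For orientation, the pass-through mode above can be driven end to end with only a few lines of client code. The following is a minimal sketch, not part of this commit; the stream names are taken from the first example comment in the calculator and the graph config is built inline for illustration:

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

absl::Status RunPassThroughExample() {
  using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
  // Graph config matching the first example in the calculator comment.
  auto config =
      mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
        input_stream: "embeddings_in"
        output_stream: "embeddings_out"
        node {
          calculator: "EmbeddingAggregationCalculator"
          input_stream: "EMBEDDINGS:embeddings_in"
          output_stream: "EMBEDDINGS:embeddings_out"
        }
      )pb");
  mediapipe::CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(config));
  ASSIGN_OR_RETURN(auto poller, graph.AddOutputStreamPoller("embeddings_out"));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "embeddings_in",
      mediapipe::MakePacket<EmbeddingResult>().At(mediapipe::Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.CloseInputStream("embeddings_in"));
  mediapipe::Packet packet;
  RET_CHECK(poller.Next(&packet));  // The same EmbeddingResult, unchanged.
  return graph.WaitUntilDone();
}
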
@@ -0,0 +1,158 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <optional>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

namespace mediapipe {
namespace {

using ::mediapipe::ParseTextProtoOrDie;
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::testing::Pointwise;

constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kEmbeddingsInName[] = "embeddings_in";
constexpr char kEmbeddingsOutName[] = "embeddings_out";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps_in";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";

class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
 protected:
  absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
    Graph graph;
    auto& calculator = graph.AddNode("EmbeddingAggregationCalculator");
    graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >>
        calculator.In(kEmbeddingsTag);
    if (connect_timestamps) {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          calculator.In(kTimestampsTag);
      calculator.Out(kTimestampedEmbeddingsTag)
              .SetName(kTimestampedEmbeddingsName) >>
          graph[Output<std::vector<EmbeddingResult>>(
              kTimestampedEmbeddingsTag)];
    } else {
      calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >>
          graph[Output<EmbeddingResult>(kEmbeddingsTag)];
    }

    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    if (connect_timestamps) {
      ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                        kTimestampedEmbeddingsName));
      MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
      return poller;
    }
    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                      kEmbeddingsOutName));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  absl::Status Send(
      const EmbeddingResult& embeddings, int timestamp = 0,
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kEmbeddingsInName, MakePacket<EmbeddingResult>(std::move(embeddings))
                               .At(Timestamp(timestamp))));
    if (aggregation_timestamps.has_value()) {
      auto packet = std::make_unique<std::vector<Timestamp>>();
      for (const auto& timestamp : *aggregation_timestamps) {
        packet->emplace_back(Timestamp(timestamp));
      }
      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
    }
    return absl::OkStatus();
  }

  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());

    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    auto result = packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  CalculatorGraph calculator_graph_;
};

TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) {
  EmbeddingResult embedding = ParseTextProtoOrDie<EmbeddingResult>(
      R"pb(embeddings { head_index: 0 })pb");

  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(/*connect_timestamps=*/false));
  MP_ASSERT_OK(Send(embedding));
  MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<EmbeddingResult>(poller));

  EXPECT_THAT(result, EqualsProto(embedding));
}

TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) {
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
  MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(R"pb(embeddings {
                                                                head_index: 0
                                                              })pb")));
  MP_ASSERT_OK(Send(
      ParseTextProtoOrDie<EmbeddingResult>(
          R"pb(embeddings { head_index: 1 })pb"),
      /*timestamp=*/1000,
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
  MP_ASSERT_OK_AND_ASSIGN(auto results,
                          GetResult<std::vector<EmbeddingResult>>(poller));

  EXPECT_THAT(results,
              Pointwise(EqualsProto(), {ParseTextProtoOrDie<EmbeddingResult>(
                                            R"pb(embeddings { head_index: 0 }
                                                 timestamp_ms: 0)pb"),
                                        ParseTextProtoOrDie<EmbeddingResult>(
                                            R"pb(embeddings { head_index: 1 }
                                                 timestamp_ms: 1)pb")}));
}

}  // namespace
}  // namespace mediapipe
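
A note on the expected timestamp_ms values above: MediaPipe Timestamps are in microseconds, and the calculator emits milliseconds relative to the first aggregated timestamp, so inputs at 0 and 1000 microseconds map to timestamp_ms 0 and 1. A standalone sketch of the same arithmetic (hypothetical helper, not in this commit):

#include <cstdint>
#include <vector>

// Mirrors EmbeddingAggregationCalculator: converts absolute microsecond
// timestamps into milliseconds relative to the first one.
std::vector<int64_t> ToRelativeMs(const std::vector<int64_t>& timestamps_us) {
  std::vector<int64_t> out;
  out.reserve(timestamps_us.size());
  for (int64_t t : timestamps_us) {
    out.push_back((t - timestamps_us[0]) / 1000);  // {0, 1000} -> {0, 1}
  }
  return out;
}
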
@@ -30,15 +30,6 @@ cc_library(
    ],
)

cc_library(
    name = "hand_landmarks_detection_result",
    hdrs = ["hand_landmarks_detection_result.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
    ],
)

cc_library(
    name = "category",
    srcs = ["category.cc"],

@@ -82,6 +82,7 @@ cc_library(
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/tool:options_map",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator",
        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
        "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",

@@ -56,6 +56,14 @@ using TensorsSource =

constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";

// Struct holding the different output streams produced by the graph.
struct EmbeddingPostprocessingOutputStreams {
  Source<EmbeddingResult> embeddings;
  Source<std::vector<EmbeddingResult>> timestamped_embeddings;
};

// Identifies whether or not the model has quantized outputs, and performs
// sanity checks.

@@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing(
//   TENSORS - std::vector<Tensor>
//     The output tensors of an InferenceCalculator, to convert into
//     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
//   TIMESTAMPS - std::vector<Timestamp> @Optional
//     The collection of the timestamps that this calculator should aggregate.
//     This stream is optional: if provided, then the TIMESTAMPED_EMBEDDINGS
//     output is used for results. Otherwise, as no timestamp aggregation is
//     required, the EMBEDDINGS output is used for results.
//
// Outputs:
//   EMBEDDING_RESULT - EmbeddingResult
//     The output EmbeddingResult.
//   EMBEDDINGS - EmbeddingResult @Optional
//     The embedding results aggregated by head. Must be connected if the
//     TIMESTAMPS input is not connected, as it signals that timestamp
//     aggregation is not required.
//   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
//     The embedding result aggregated by timestamp, then by head. Must be
//     connected if the TIMESTAMPS input is connected, as it signals that
//     timestamp aggregation is required.
//
// The recommended way of using this graph is through the GraphBuilder API using
// the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
// details.
//
// TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation.
class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
 public:
  absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
      mediapipe::SubgraphContext* sc) override {
    Graph graph;
    ASSIGN_OR_RETURN(
        auto embedding_result_out,
        auto output_streams,
        BuildEmbeddingPostprocessing(
            sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
            graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
    embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
            graph[Input<std::vector<Tensor>>(kTensorsTag)],
            graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
    output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
    output_streams.timestamped_embeddings >>
        graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
    return graph.GetConfig();
  }

@@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
  //
  // options: the on-device EmbeddingPostprocessingGraphOptions
  // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
  // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
  //   timestamps that should be used to aggregate embedding results.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing(
  absl::StatusOr<EmbeddingPostprocessingOutputStreams>
  BuildEmbeddingPostprocessing(
      const proto::EmbeddingPostprocessingGraphOptions options,
      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
      Source<std::vector<Tensor>> tensors_in,
      Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
    // If output tensors are quantized, they must be dequantized first.
    TensorsSource dequantized_tensors(&tensors_in);
    if (options.has_quantized_outputs()) {

@@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
            .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
            .CopyFrom(options.tensors_to_embeddings_options());
    dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
    return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];

    // Adds EmbeddingAggregationCalculator.
    GenericNode& aggregation_node =
        graph.AddNode("EmbeddingAggregationCalculator");
    tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >>
        aggregation_node.In(kEmbeddingsTag);
    timestamps_in >> aggregation_node.In(kTimestampsTag);

    // Connects outputs.
    return EmbeddingPostprocessingOutputStreams{
        /*embeddings=*/aggregation_node[Output<EmbeddingResult>(
            kEmbeddingsTag)],
        /*timestamped_embeddings=*/aggregation_node
            [Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]};
  }
};
REGISTER_MEDIAPIPE_GRAPH(
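
For context, a task graph would typically add and configure this subgraph through the builder API roughly as in the sketch below. This is not code from the commit: the header include path and the caller-side availability of model_resources and embedder_options are assumptions.

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

// Adds the postprocessing subgraph to `graph`; the caller then connects
// TENSORS/TIMESTAMPS in and EMBEDDINGS/TIMESTAMPED_EMBEDDINGS out.
absl::StatusOr<mediapipe::api2::builder::GenericNode*>
AddEmbeddingPostprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources,
    const mediapipe::tasks::components::processors::proto::EmbedderOptions&
        embedder_options,
    mediapipe::api2::builder::Graph& graph) {
  auto& node = graph.AddNode(
      "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
  MP_RETURN_IF_ERROR(
      mediapipe::tasks::components::processors::ConfigureEmbeddingPostprocessing(
          model_resources, embedder_options,
          &node.GetOptions<mediapipe::tasks::components::processors::proto::
                               EmbeddingPostprocessingGraphOptions>()));
  return &node;
}
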
@@ -44,12 +44,20 @@ namespace processors {
//   TENSORS - std::vector<Tensor>
//     The output tensors of an InferenceCalculator, to convert into
//     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
//   TIMESTAMPS - std::vector<Timestamp> @Optional
//     The collection of the timestamps that this calculator should aggregate.
//     This stream is optional: if provided, then the TIMESTAMPED_EMBEDDINGS
//     output is used for results. Otherwise, as no timestamp aggregation is
//     required, the EMBEDDINGS output is used for results.
// Outputs:
//   EMBEDDINGS - EmbeddingResult
//     The output EmbeddingResult.
//
// TODO: add support for additional optional "TIMESTAMPS" input for
// embeddings aggregation.
//   EMBEDDINGS - EmbeddingResult @Optional
//     The embedding results aggregated by head. Must be connected if the
//     TIMESTAMPS input is not connected, as it signals that timestamp
//     aggregation is not required.
//   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
//     The embedding result aggregated by timestamp, then by head. Must be
//     connected if the TIMESTAMPS input is connected, as it signals that
//     timestamp aggregation is required.
absl::Status ConfigureEmbeddingPostprocessing(
    const tasks::core::ModelResources& model_resources,
    const proto::EmbedderOptions& embedder_options,

@@ -20,11 +20,20 @@ limitations under the License.
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/graph_runner.h"
#include "mediapipe/framework/output_stream_poller.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

@@ -37,7 +46,12 @@ namespace components {
namespace processors {
namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";

@@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] =
    "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";

constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024;

constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTensorsName[] = "tensors";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
constexpr char kTimestampsName[] = "timestamps";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
constexpr char kEmbeddingsName[] = "embeddings";
constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings";

// Helper function to get ModelResources.
absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(

@@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
                             has_quantized_outputs: false)pb")));
}

// TODO: add E2E Postprocessing tests once timestamp aggregation is
// supported.
class PostprocessingTest : public tflite_shims::testing::Test {
 protected:
  absl::StatusOr<OutputStreamPoller> BuildGraph(
      absl::string_view model_name, const proto::EmbedderOptions& options,
      bool connect_timestamps = false) {
    ASSIGN_OR_RETURN(auto model_resources,
                     CreateModelResourcesForModel(model_name));

    Graph graph;
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "EmbeddingPostprocessingGraph");
    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
        *model_resources, options,
        &postprocessing
             .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
        postprocessing.In(kTensorsTag);
    if (connect_timestamps) {
      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
          kTimestampsName) >>
          postprocessing.In(kTimestampsTag);
      postprocessing.Out(kTimestampedEmbeddingsTag)
              .SetName(kTimestampedEmbeddingsName) >>
          graph[Output<std::vector<EmbeddingResult>>(
              kTimestampedEmbeddingsTag)];
    } else {
      postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
          graph[Output<EmbeddingResult>(kEmbeddingsTag)];
    }

    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
    if (connect_timestamps) {
      ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
                                        kTimestampedEmbeddingsName));
      MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
      return poller;
    }
    ASSIGN_OR_RETURN(auto poller,
                     calculator_graph_.AddOutputStreamPoller(kEmbeddingsName));
    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
    return poller;
  }

  template <typename T>
  void AddTensor(
      const std::vector<T>& tensor, const Tensor::ElementType& element_type,
      const Tensor::QuantizationParameters& quantization_parameters = {}) {
    tensors_->emplace_back(element_type,
                           Tensor::Shape{1, static_cast<int>(tensor.size())},
                           quantization_parameters);
    auto view = tensors_->back().GetCpuWriteView();
    T* buffer = view.buffer<T>();
    std::copy(tensor.begin(), tensor.end(), buffer);
  }

  absl::Status Run(
      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
      int timestamp = 0) {
    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
        kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
    // Reset tensors for future calls.
    tensors_ = absl::make_unique<std::vector<Tensor>>();
    if (aggregation_timestamps.has_value()) {
      auto packet = absl::make_unique<std::vector<Timestamp>>();
      for (const auto& timestamp : *aggregation_timestamps) {
        packet->emplace_back(Timestamp(timestamp));
      }
      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
    }
    return absl::OkStatus();
  }

  template <typename T>
  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());

    Packet packet;
    if (!poller.Next(&packet)) {
      return absl::InternalError("Unable to get output packet");
    }
    auto result = packet.Get<T>();
    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
    return result;
  }

 private:
  CalculatorGraph calculator_graph_;
  std::unique_ptr<std::vector<Tensor>> tensors_ =
      absl::make_unique<std::vector<Tensor>>();
};

TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
  // Build graph.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(auto poller,
                          BuildGraph(kMobileNetV3Embedder, options));
  // Build input tensor.
  std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor[0] = 1.0;

  // Send tensor and get results.
  AddTensor(tensor, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));

  // Validate results.
  EXPECT_FALSE(results.has_timestamp_ms());
  EXPECT_EQ(results.embeddings_size(), 1);
  EXPECT_EQ(results.embeddings(0).head_index(), 0);
  EXPECT_EQ(results.embeddings(0).head_name(), "feature");
  EXPECT_EQ(results.embeddings(0).float_embedding().values_size(),
            kMobileNetV3EmbedderEmbeddingSize);
  EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(0), 1.0);
  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
    EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(i), 0.0);
  }
}

TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
  // Build graph.
  proto::EmbedderOptions options;
  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,
                                                  /*connect_timestamps=*/true));
  // Build input tensors.
  std::vector<float> tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor_0[0] = 1.0;
  std::vector<float> tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0);
  tensor_1[0] = 2.0;

  // Send tensors and get results.
  AddTensor(tensor_0, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run());
  AddTensor(tensor_1, Tensor::ElementType::kFloat32);
  MP_ASSERT_OK(Run(
      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
      /*timestamp=*/1000));
  MP_ASSERT_OK_AND_ASSIGN(auto results,
                          GetResult<std::vector<EmbeddingResult>>(poller));

  // Validate results.
  EXPECT_EQ(results.size(), 2);
  // First timestamp.
  EXPECT_EQ(results[0].timestamp_ms(), 0);
  EXPECT_EQ(results[0].embeddings(0).head_index(), 0);
  EXPECT_EQ(results[0].embeddings(0).head_name(), "feature");
  EXPECT_EQ(results[0].embeddings(0).float_embedding().values_size(),
            kMobileNetV3EmbedderEmbeddingSize);
  EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(0), 1.0);
  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
    EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(i), 0.0);
  }
  // Second timestamp.
  EXPECT_EQ(results[1].timestamp_ms(), 1);
  EXPECT_EQ(results[1].embeddings(0).head_index(), 0);
  EXPECT_EQ(results[1].embeddings(0).head_name(), "feature");
  EXPECT_EQ(results[1].embeddings(0).float_embedding().values_size(),
            kMobileNetV3EmbedderEmbeddingSize);
  EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(0), 2.0);
  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
    EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(i), 0.0);
  }
}

}  // namespace
}  // namespace processors

@@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions {

  // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
  optional bool has_quantized_outputs = 2;

  // TODO: add options to control whether timestamp aggregation
  // should be used or not.
}
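
As a quick illustration of these options, a hedged C++ sketch (hypothetical helper, not part of this commit) that builds the message for a float-output model:

#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"

// Constructs graph options for a model whose output tensors are kFloat32,
// so no dequantization step is needed downstream.
mediapipe::tasks::components::processors::proto::
    EmbeddingPostprocessingGraphOptions
MakeFloatOptions() {
  mediapipe::tasks::components::processors::proto::
      EmbeddingPostprocessingGraphOptions options;
  options.set_has_quantized_outputs(false);
  return options;
}
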
@@ -110,12 +110,22 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "hand_landmarker_result",
    hdrs = ["hand_landmarker_result.h"],
    deps = [
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
    ],
)

cc_library(
    name = "hand_landmarker",
    srcs = ["hand_landmarker.cc"],
    hdrs = ["hand_landmarker.h"],
    deps = [
        ":hand_landmarker_graph",
        ":hand_landmarker_result",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:classification_cc_proto",

@@ -124,7 +134,6 @@ cc_library(
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",

@@ -22,7 +22,6 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"

@@ -34,6 +33,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"

@@ -47,8 +47,6 @@ namespace {
using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
    hand_landmarker::proto::HandLandmarkerGraphOptions;

using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;

constexpr char kHandLandmarkerGraphTypeName[] =
    "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";

@@ -145,7 +143,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
          Packet empty_packet =
              status_or_packets.value()[kHandLandmarksStreamName];
          result_callback(
              {HandLandmarksDetectionResult()}, image_packet.Get<Image>(),
              {HandLandmarkerResult()}, image_packet.Get<Image>(),
              empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
          return;
        }

@@ -173,7 +171,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
      std::move(packets_callback));
}

absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
absl::StatusOr<HandLandmarkerResult> HandLandmarker::Detect(
    mediapipe::Image image,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {

@@ -192,7 +190,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
       {kNormRectStreamName,
        MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
    return {HandLandmarksDetectionResult()};
    return {HandLandmarkerResult()};
  }
  return {{/* handedness= */
           {output_packets[kHandednessStreamName]

@@ -205,7 +203,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
               .Get<std::vector<mediapipe::LandmarkList>>()}}};
}

absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<core::ImageProcessingOptions> image_processing_options) {
  if (image.UsesGpu()) {

@@ -227,7 +225,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
    return {HandLandmarksDetectionResult()};
    return {HandLandmarkerResult()};
  }
  return {
      {/* handedness= */

@@ -24,12 +24,12 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"

namespace mediapipe {
namespace tasks {

@@ -70,9 +70,7 @@ struct HandLandmarkerOptions {
  // The user-defined result callback for processing live stream data.
  // The result callback should only be specified when the running mode is set
  // to RunningMode::LIVE_STREAM.
  std::function<void(
      absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
      const Image&, int64)>
  std::function<void(absl::StatusOr<HandLandmarkerResult>, const Image&, int64)>
      result_callback = nullptr;
};

@@ -92,7 +90,7 @@ struct HandLandmarkerOptions {
//      'y_center', 'width' and 'height' fields is NOT supported and will
//      result in an invalid argument error being returned.
// Outputs:
//   HandLandmarksDetectionResult
//   HandLandmarkerResult
//     - The hand landmarks detection results.
class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
 public:

@@ -129,7 +127,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describes how the input image will be preprocessed
  // after the yuv support is implemented.
  absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect(
  absl::StatusOr<HandLandmarkerResult> Detect(
      Image image,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

@@ -147,10 +145,10 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
  // The image can be of any size with format RGB or RGBA. It's required to
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  absl::StatusOr<components::containers::HandLandmarksDetectionResult>
  DetectForVideo(Image image, int64 timestamp_ms,
                 std::optional<core::ImageProcessingOptions>
                     image_processing_options = std::nullopt);
  absl::StatusOr<HandLandmarkerResult> DetectForVideo(
      Image image, int64 timestamp_ms,
      std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);

  // Sends live image data to perform hand landmarks detection, and the results
  // will be available via the "result_callback" provided in the

@@ -169,7 +167,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
  //   invalid argument error being returned.
  //
  // The "result_callback" provides
  //   - A vector of HandLandmarksDetectionResult, each is the detected results
  //   - A vector of HandLandmarkerResult, each is the detected results
  //     for an input frame.
  //   - The const reference to the corresponding input image that the hand
  //     landmarker runs on. Note that the const reference to the image will no

@@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_

#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"

namespace mediapipe {
namespace tasks {
namespace components {
namespace containers {
namespace vision {
namespace hand_landmarker {

// The hand landmarks detection result from HandLandmarker, where each vector
// element represents a single hand detected in the image.
struct HandLandmarksDetectionResult {
struct HandLandmarkerResult {
  // Classification of handedness.
  std::vector<mediapipe::ClassificationList> handedness;
  // Detected hand landmarks in normalized image coordinates.

@@ -35,9 +35,9 @@ struct HandLandmarksDetectionResult {
  std::vector<mediapipe::LandmarkList> hand_world_landmarks;
};

}  // namespace containers
}  // namespace components
}  // namespace hand_landmarker
}  // namespace vision
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
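
To illustrate the renamed type's shape, a small hypothetical helper (not part of this commit) that consumes a HandLandmarkerResult:

#include <cstddef>

#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"

// Counts landmarks across all detected hands; hand_landmarks holds one
// NormalizedLandmarkList per detected hand.
size_t TotalHandLandmarks(
    const mediapipe::tasks::vision::hand_landmarker::HandLandmarkerResult&
        result) {
  size_t total = 0;
  for (const auto& hand : result.hand_landmarks) {
    total += hand.landmark_size();
  }
  return total;
}
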
@@ -32,12 +32,12 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
#include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

@@ -50,7 +50,6 @@ namespace {

using ::file::Defaults;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::components::containers::proto::LandmarksDetectionResult;
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;

@@ -95,9 +94,9 @@ LandmarksDetectionResult GetLandmarksDetectionResult(
  return result;
}

HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
HandLandmarkerResult GetExpectedHandLandmarkerResult(
    const std::vector<absl::string_view>& landmarks_file_names) {
  HandLandmarksDetectionResult expected_results;
  HandLandmarkerResult expected_results;
  for (const auto& file_name : landmarks_file_names) {
    const auto landmarks_detection_result =
        GetLandmarksDetectionResult(file_name);

|
|||
return expected_results;
|
||||
}
|
||||
|
||||
void ExpectHandLandmarksDetectionResultsCorrect(
|
||||
const HandLandmarksDetectionResult& actual_results,
|
||||
const HandLandmarksDetectionResult& expected_results) {
|
||||
void ExpectHandLandmarkerResultsCorrect(
|
||||
const HandLandmarkerResult& actual_results,
|
||||
const HandLandmarkerResult& expected_results) {
|
||||
const auto& actual_landmarks = actual_results.hand_landmarks;
|
||||
const auto& actual_handedness = actual_results.handedness;
|
||||
|
||||
|
@@ -145,7 +144,7 @@ struct TestParams {
  //  clockwise.
  int rotation;
  // Expected results from the hand landmarker model output.
  HandLandmarksDetectionResult expected_results;
  HandLandmarkerResult expected_results;
};

class ImageModeTest : public testing::TestWithParam<TestParams> {};

@@ -213,7 +212,7 @@ TEST_P(ImageModeTest, Succeeds) {

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));
  HandLandmarksDetectionResult hand_landmarker_results;
  HandLandmarkerResult hand_landmarker_results;
  if (GetParam().rotation != 0) {
    ImageProcessingOptions image_processing_options;
    image_processing_options.rotation_degrees = GetParam().rotation;

@@ -224,8 +223,8 @@ TEST_P(ImageModeTest, Succeeds) {
    MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
                            hand_landmarker->Detect(image));
  }
  ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
                                             GetParam().expected_results);
  ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
                                     GetParam().expected_results);
  MP_ASSERT_OK(hand_landmarker->Close());
}

@@ -237,8 +236,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ 0,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
                {kThumbUpLandmarksFilename}),
            GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
        },
        TestParams{
            /* test_name= */ "LandmarksPointingUp",

@@ -246,8 +244,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ 0,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
                {kPointingUpLandmarksFilename}),
            GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
        },
        TestParams{
            /* test_name= */ "LandmarksPointingUpRotated",

@@ -255,7 +252,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ -90,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
            GetExpectedHandLandmarkerResult(
                {kPointingUpRotatedLandmarksFilename}),
        },
        TestParams{

@@ -315,7 +312,7 @@ TEST_P(VideoModeTest, Succeeds) {
                          HandLandmarker::Create(std::move(options)));
  const auto expected_results = GetParam().expected_results;
  for (int i = 0; i < iterations; ++i) {
    HandLandmarksDetectionResult hand_landmarker_results;
    HandLandmarkerResult hand_landmarker_results;
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;

@@ -326,8 +323,8 @@ TEST_P(VideoModeTest, Succeeds) {
      MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
                              hand_landmarker->DetectForVideo(image, i));
    }
    ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
                                               expected_results);
    ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
                                       expected_results);
  }
  MP_ASSERT_OK(hand_landmarker->Close());
}

@@ -340,8 +337,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ 0,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
                {kThumbUpLandmarksFilename}),
            GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
        },
        TestParams{
            /* test_name= */ "LandmarksPointingUp",

@@ -349,8 +345,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ 0,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
                {kPointingUpLandmarksFilename}),
            GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
        },
        TestParams{
            /* test_name= */ "LandmarksPointingUpRotated",

@@ -358,7 +353,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ -90,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
            GetExpectedHandLandmarkerResult(
                {kPointingUpRotatedLandmarksFilename}),
        },
        TestParams{

@@ -383,9 +378,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  options->result_callback =
      [](absl::StatusOr<HandLandmarksDetectionResult> results,
         const Image& image, int64 timestamp_ms) {};
  options->result_callback = [](absl::StatusOr<HandLandmarkerResult> results,
                                const Image& image, int64 timestamp_ms) {};

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));

@@ -416,23 +410,23 @@ TEST_P(LiveStreamModeTest, Succeeds) {
  options->base_options.model_asset_path =
      JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
  options->running_mode = core::RunningMode::LIVE_STREAM;
  std::vector<HandLandmarksDetectionResult> hand_landmarker_results;
  std::vector<HandLandmarkerResult> hand_landmarker_results;
  std::vector<std::pair<int, int>> image_sizes;
  std::vector<int64> timestamps;
  options->result_callback =
      [&hand_landmarker_results, &image_sizes, &timestamps](
          absl::StatusOr<HandLandmarksDetectionResult> results,
          const Image& image, int64 timestamp_ms) {
        MP_ASSERT_OK(results.status());
        hand_landmarker_results.push_back(std::move(results.value()));
        image_sizes.push_back({image.width(), image.height()});
        timestamps.push_back(timestamp_ms);
      };
  options->result_callback = [&hand_landmarker_results, &image_sizes,
                              &timestamps](
                                 absl::StatusOr<HandLandmarkerResult> results,
                                 const Image& image, int64 timestamp_ms) {
    MP_ASSERT_OK(results.status());
    hand_landmarker_results.push_back(std::move(results.value()));
    image_sizes.push_back({image.width(), image.height()});
    timestamps.push_back(timestamp_ms);
  };

  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                          HandLandmarker::Create(std::move(options)));
  for (int i = 0; i < iterations; ++i) {
    HandLandmarksDetectionResult hand_landmarker_results;
    HandLandmarkerResult hand_landmarker_results;
    if (GetParam().rotation != 0) {
      ImageProcessingOptions image_processing_options;
      image_processing_options.rotation_degrees = GetParam().rotation;

@@ -450,8 +444,8 @@ TEST_P(LiveStreamModeTest, Succeeds) {

  const auto expected_results = GetParam().expected_results;
  for (int i = 0; i < hand_landmarker_results.size(); ++i) {
    ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i],
                                               expected_results);
    ExpectHandLandmarkerResultsCorrect(hand_landmarker_results[i],
                                       expected_results);
  }
  for (const auto& image_size : image_sizes) {
    EXPECT_EQ(image_size.first, image.width());

@@ -472,8 +466,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ 0,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
                {kThumbUpLandmarksFilename}),
            GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
        },
        TestParams{
            /* test_name= */ "LandmarksPointingUp",

@@ -481,8 +474,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ 0,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
                {kPointingUpLandmarksFilename}),
            GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
        },
        TestParams{
            /* test_name= */ "LandmarksPointingUpRotated",

@@ -490,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
            /* test_model_file= */ kHandLandmarkerBundleAsset,
            /* rotation= */ -90,
            /* expected_results = */
            GetExpectedHandLandmarksDetectionResult(
            GetExpectedHandLandmarkerResult(
                {kPointingUpRotatedLandmarksFilename}),
        },
        TestParams{

@@ -142,6 +142,36 @@ android_library(
    ],
)

android_library(
    name = "handlandmarker",
    srcs = [
        "handlandmarker/HandLandmarker.java",
        "handlandmarker/HandLandmarkerResult.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "handlandmarker/AndroidManifest.xml",
    deps = [
        ":core",
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/framework/formats:classification_java_proto_lite",
        "//mediapipe/framework/formats:landmark_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/java/com/google/mediapipe/framework/image",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)

load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")

mediapipe_tasks_vision_aar(

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.handlandmarker">

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

</manifest>

@@ -0,0 +1,501 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.handlandmarker;

import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarkerGraphOptionsProto;
import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * Performs hand landmarks detection on images.
 *
 * <p>This API expects a pre-trained hand landmarks model asset bundle. See <TODO link
 * to the DevSite documentation page>.
 *
 * <ul>
 *   <li>Input image {@link MPImage}
 *       <ul>
 *         <li>The image that hand landmarks detection runs on.
 *       </ul>
 *   <li>Output HandLandmarkerResult {@link HandLandmarkerResult}
 *       <ul>
 *         <li>A HandLandmarkerResult containing hand landmarks.
 *       </ul>
 * </ul>
 */
public final class HandLandmarker extends BaseVisionTaskApi {
  private static final String TAG = HandLandmarker.class.getSimpleName();
  private static final String IMAGE_IN_STREAM_NAME = "image_in";
  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList(
              "LANDMARKS:hand_landmarks",
              "WORLD_LANDMARKS:world_hand_landmarks",
              "HANDEDNESS:handedness",
              "IMAGE:image_out"));
  private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
  private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
  private static final int HANDEDNESS_OUT_STREAM_INDEX = 2;
  private static final int IMAGE_OUT_STREAM_INDEX = 3;
  private static final String TASK_GRAPH_NAME =
      "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";

  /**
   * Creates a {@link HandLandmarker} instance from a model file and the default {@link
   * HandLandmarkerOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelPath path to the hand landmarks model with metadata in the assets.
   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
   */
  public static HandLandmarker createFromFile(Context context, String modelPath) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
    return createFromOptions(
        context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates a {@link HandLandmarker} instance from a model file and the default {@link
   * HandLandmarkerOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelFile the hand landmarks model {@link File} instance.
   * @throws IOException if an I/O error occurs when opening the tflite model file.
   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
   */
  public static HandLandmarker createFromFile(Context context, File modelFile) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      BaseOptions baseOptions =
          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
      return createFromOptions(
          context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
    }
  }

  /**
   * Creates a {@link HandLandmarker} instance from a model buffer and the default {@link
   * HandLandmarkerOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
   *     model.
   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
   */
  public static HandLandmarker createFromBuffer(Context context, final ByteBuffer modelBuffer) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
    return createFromOptions(
        context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates a {@link HandLandmarker} instance from a {@link HandLandmarkerOptions}.
   *
   * @param context an Android {@link Context}.
   * @param landmarkerOptions a {@link HandLandmarkerOptions} instance.
   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
   */
  public static HandLandmarker createFromOptions(
      Context context, HandLandmarkerOptions landmarkerOptions) {
    // TODO: Consolidate OutputHandler and TaskRunner.
    OutputHandler<HandLandmarkerResult, MPImage> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<HandLandmarkerResult, MPImage>() {
          @Override
          public HandLandmarkerResult convertToTaskResult(List<Packet> packets) {
            // If no hands are detected in the image, just return empty lists.
            if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
              return HandLandmarkerResult.create(
                  new ArrayList<>(),
                  new ArrayList<>(),
                  new ArrayList<>(),
                  packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp());
            }
            return HandLandmarkerResult.create(
                PacketGetter.getProtoVector(
                    packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
                PacketGetter.getProtoVector(
                    packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
                PacketGetter.getProtoVector(
                    packets.get(HANDEDNESS_OUT_STREAM_INDEX), ClassificationList.parser()),
                packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp());
          }

          @Override
          public MPImage convertToTaskInput(List<Packet> packets) {
            return new BitmapImageBuilder(
                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
                .build();
          }
        });
    landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
    landmarkerOptions.errorListener().ifPresent(handler::setErrorListener);
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<HandLandmarkerOptions>builder()
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(INPUT_STREAMS)
                .setOutputStreams(OUTPUT_STREAMS)
                .setTaskOptions(landmarkerOptions)
                .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM)
                .build(),
            handler);
    return new HandLandmarker(runner, landmarkerOptions.runningMode());
  }

  /**
   * Constructor to initialize a {@link HandLandmarker} from a {@link TaskRunner} and a {@link
   * RunningMode}.
   *
   * @param taskRunner a {@link TaskRunner}.
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   */
  private HandLandmarker(TaskRunner taskRunner, RunningMode runningMode) {
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
  }

  /**
   * Performs hand landmarks detection on the provided single image with default image processing
   * options, i.e. without any rotation applied. Only use this method when the {@link
   * HandLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java doc
   * for input image format.
   *
   * <p>{@link HandLandmarker} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @throws MediaPipeException if there is an internal error.
   */
  public HandLandmarkerResult detect(MPImage image) {
    return detect(image, ImageProcessingOptions.builder().build());
  }

  /**
   * Performs hand landmarks detection on the provided single image. Only use this method when the
   * {@link HandLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java
   * doc for input image format.
   *
   * <p>{@link HandLandmarker} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public HandLandmarkerResult detect(
      MPImage image, ImageProcessingOptions imageProcessingOptions) {
    validateImageProcessingOptions(imageProcessingOptions);
    return (HandLandmarkerResult) processImageData(image, imageProcessingOptions);
  }

  /**
   * Performs hand landmarks detection on the provided video frame with default image processing
   * options, i.e. without any rotation applied. Only use this method when the {@link
   * HandLandmarker} is created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
   * <p>{@link HandLandmarker} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public HandLandmarkerResult detectForVideo(MPImage image, long timestampMs) {
    return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**
   * Performs hand landmarks detection on the provided video frame. Only use this method when the
   * {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
   *
   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
   * must be monotonically increasing.
   *
   * <p>{@link HandLandmarker} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public HandLandmarkerResult detectForVideo(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    validateImageProcessingOptions(imageProcessingOptions);
    return (HandLandmarkerResult)
        processVideoData(image, imageProcessingOptions, timestampMs);
  }

  /**
   * Sends live image data to perform hand landmarks detection with default image processing
   * options, i.e. without any rotation applied, and the results will be available via the {@link
   * ResultListener} provided in the {@link HandLandmarkerOptions}. Only use this method when the
   * {@link HandLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the hand landmarker. The input timestamps must be monotonically increasing.
   *
   * <p>{@link HandLandmarker} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws MediaPipeException if there is an internal error.
   */
  public void detectAsync(MPImage image, long timestampMs) {
    detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
  }

  /**
   * Sends live image data to perform hand landmarks detection, and the results will be available
   * via the {@link ResultListener} provided in the {@link HandLandmarkerOptions}. Only use this
   * method when the {@link HandLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
   *
   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
   * sent to the hand landmarker. The input timestamps must be monotonically increasing.
   *
   * <p>{@link HandLandmarker} supports the following color space types:
   *
   * <ul>
   *   <li>{@link Bitmap.Config.ARGB_8888}
   * </ul>
   *
   * @param image a MediaPipe {@link MPImage} object for processing.
   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
   *     this method throwing an IllegalArgumentException.
   * @param timestampMs the input timestamp (in milliseconds).
   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
   *     region-of-interest.
   * @throws MediaPipeException if there is an internal error.
   */
  public void detectAsync(
      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
    validateImageProcessingOptions(imageProcessingOptions);
    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
  }

  /** Options for setting up an {@link HandLandmarker}. */
  @AutoValue
  public abstract static class HandLandmarkerOptions extends TaskOptions {

    /** Builder for {@link HandLandmarkerOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the base options for the hand landmarker task. */
      public abstract Builder setBaseOptions(BaseOptions value);

      /**
       * Sets the running mode for the hand landmarker task. Defaults to the image mode. Hand
       * landmarker has three modes:
       *
       * <ul>
       *   <li>IMAGE: The mode for detecting hand landmarks on single image inputs.
       *   <li>VIDEO: The mode for detecting hand landmarks on the decoded frames of a video.
       *   <li>LIVE_STREAM: The mode for detecting hand landmarks on a live stream of input
       *       data, such as from a camera. In this mode, {@code setResultListener} must be called
       *       to set up a listener to receive the detection results asynchronously.
       * </ul>
       */
      public abstract Builder setRunningMode(RunningMode value);

      /** Sets the maximum number of hands that can be detected by the HandLandmarker. */
      public abstract Builder setNumHands(Integer value);

      /** Sets the minimum confidence score for the hand detection to be considered successful. */
      public abstract Builder setMinHandDetectionConfidence(Float value);

      /** Sets the minimum confidence score of hand presence in the hand landmark detection. */
      public abstract Builder setMinHandPresenceConfidence(Float value);

      /** Sets the minimum confidence score for the hand tracking to be considered successful. */
      public abstract Builder setMinTrackingConfidence(Float value);

      /**
       * Sets the result listener to receive the detection results asynchronously when the hand
       * landmarker is in the live stream mode.
       */
      public abstract Builder setResultListener(
          ResultListener<HandLandmarkerResult, MPImage> value);

      /** Sets an optional error listener. */
      public abstract Builder setErrorListener(ErrorListener value);

      abstract HandLandmarkerOptions autoBuild();

      /**
       * Validates and builds the {@link HandLandmarkerOptions} instance.
       *
       * @throws IllegalArgumentException if the result listener and the running mode are not
       *     properly configured. The result listener should only be set when the hand landmarker
       *     is in the live stream mode.
       */
      public final HandLandmarkerOptions build() {
        HandLandmarkerOptions options = autoBuild();
        if (options.runningMode() == RunningMode.LIVE_STREAM) {
          if (!options.resultListener().isPresent()) {
            throw new IllegalArgumentException(
                "The hand landmarker is in the live stream mode, a user-defined result listener"
                    + " must be provided in HandLandmarkerOptions.");
          }
        } else if (options.resultListener().isPresent()) {
          throw new IllegalArgumentException(
              "The hand landmarker is in the image or the video mode, a user-defined result"
                  + " listener shouldn't be provided in HandLandmarkerOptions.");
        }
        return options;
      }
    }

    abstract BaseOptions baseOptions();

    abstract RunningMode runningMode();

    abstract Optional<Integer> numHands();

    abstract Optional<Float> minHandDetectionConfidence();

    abstract Optional<Float> minHandPresenceConfidence();

    abstract Optional<Float> minTrackingConfidence();

    abstract Optional<ResultListener<HandLandmarkerResult, MPImage>> resultListener();

    abstract Optional<ErrorListener> errorListener();

    public static Builder builder() {
      return new AutoValue_HandLandmarker_HandLandmarkerOptions.Builder()
          .setRunningMode(RunningMode.IMAGE)
          .setNumHands(1)
          .setMinHandDetectionConfidence(0.5f)
          .setMinHandPresenceConfidence(0.5f)
          .setMinTrackingConfidence(0.5f);
    }

    /** Converts a {@link HandLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */
    @Override
    public CalculatorOptions convertToCalculatorOptionsProto() {
      HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder taskOptionsBuilder =
          HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
              .setBaseOptions(
                  BaseOptionsProto.BaseOptions.newBuilder()
                      .setUseStreamMode(runningMode() != RunningMode.IMAGE)
                      .mergeFrom(convertBaseOptionsToProto(baseOptions()))
                      .build());

      // Setup HandDetectorGraphOptions.
      HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder
          handDetectorGraphOptionsBuilder =
              HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder();
      numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands);
      minHandDetectionConfidence()
          .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence);

      // Setup HandLandmarksDetectorGraphOptions.
      HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder
          handLandmarksDetectorGraphOptionsBuilder =
              HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder();
      minHandPresenceConfidence()
          .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence);
      minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence);

      taskOptionsBuilder
          .setHandDetectorGraphOptions(handDetectorGraphOptionsBuilder.build())
          .setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptionsBuilder.build());

      return CalculatorOptions.newBuilder()
          .setExtension(
              HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.ext,
              taskOptionsBuilder.build())
          .build();
    }
  }

  /**
   * Validates that the provided {@link ImageProcessingOptions} doesn't contain a
   * region-of-interest.
   */
  private static void validateImageProcessingOptions(
      ImageProcessingOptions imageProcessingOptions) {
    if (imageProcessingOptions.regionOfInterest().isPresent()) {
      throw new IllegalArgumentException("HandLandmarker doesn't support region-of-interest.");
    }
  }
}
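
For orientation, a minimal image-mode usage sketch of the API added above; `context`, `bitmap`, and the "hand_landmarker.task" asset name are caller-supplied placeholders, not part of this change:

    // Assumes a hand landmarker model bundle packaged in the app's assets.
    HandLandmarker handLandmarker =
        HandLandmarker.createFromFile(context, "hand_landmarker.task");
    MPImage mpImage = new BitmapImageBuilder(bitmap).build();
    HandLandmarkerResult result = handLandmarker.detect(mpImage);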
@@ -0,0 +1,109 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.handlandmarker;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Represents the hand landmarks detection results generated by {@link HandLandmarker}. */
@AutoValue
public abstract class HandLandmarkerResult implements TaskResult {

  /**
   * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and
   * handedness protobuf messages.
   *
   * @param landmarksProto a List of {@link NormalizedLandmarkList}
   * @param worldLandmarksProto a List of {@link LandmarkList}
   * @param handednessesProto a List of {@link ClassificationList}
   */
  static HandLandmarkerResult create(
      List<NormalizedLandmarkList> landmarksProto,
      List<LandmarkList> worldLandmarksProto,
      List<ClassificationList> handednessesProto,
      long timestampMs) {
    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
        new ArrayList<>();
    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
        new ArrayList<>();
    List<List<Category>> multiHandHandednesses = new ArrayList<>();
    for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
      List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
          new ArrayList<>();
      multiHandLandmarks.add(handLandmarks);
      for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
        handLandmarks.add(
            com.google.mediapipe.tasks.components.containers.Landmark.create(
                handLandmarkProto.getX(),
                handLandmarkProto.getY(),
                handLandmarkProto.getZ(),
                true));
      }
    }
    for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
      List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks =
          new ArrayList<>();
      multiHandWorldLandmarks.add(handWorldLandmarks);
      for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
        handWorldLandmarks.add(
            com.google.mediapipe.tasks.components.containers.Landmark.create(
                handWorldLandmarkProto.getX(),
                handWorldLandmarkProto.getY(),
                handWorldLandmarkProto.getZ(),
                false));
      }
    }
    for (ClassificationList handednessProto : handednessesProto) {
      List<Category> handedness = new ArrayList<>();
      multiHandHandednesses.add(handedness);
      for (Classification classification : handednessProto.getClassificationList()) {
        handedness.add(
            Category.create(
                classification.getScore(),
                classification.getIndex(),
                classification.getLabel(),
                classification.getDisplayName()));
      }
    }
    return new AutoValue_HandLandmarkerResult(
        timestampMs,
        Collections.unmodifiableList(multiHandLandmarks),
        Collections.unmodifiableList(multiHandWorldLandmarks),
        Collections.unmodifiableList(multiHandHandednesses));
  }

  @Override
  public abstract long timestampMs();

  /** Hand landmarks of detected hands. */
  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();

  /** Hand landmarks in world coordinates of detected hands. */
  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
      worldLandmarks();

  /** Handedness of detected hands. */
  public abstract List<List<Category>> handednesses();
}
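
A sketch of live-stream usage combining the two files above, grounded in the options validation and the try-with-resources pattern used by the tests below; `context`, `mpImage`, and `frameTimestampMs` are illustrative placeholders:

    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("hand_landmarker.task").build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            // Required in LIVE_STREAM mode; results arrive asynchronously here.
            .setResultListener(
                (result, inputImage) -> {
                  // landmarks(), worldLandmarks(), and handednesses() each hold
                  // one inner list per detected hand.
                })
            .build();
    try (HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(context, options)) {
      handLandmarker.detectAsync(mpImage, /*timestampMs=*/ frameTimestampMs);
    }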
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.handlandmarkertest"
    android:versionCode="1"
    android:versionName="1.0" >

  <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
  <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

  <application
      android:label="handlandmarkertest"
      android:name="android.support.multidex.MultiDexApplication"
      android:taskAffinity="">
    <uses-library android:name="android.test.runner" />
  </application>

  <instrumentation
      android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
      android:targetPackage="com.google.mediapipe.tasks.vision.handlandmarkertest" />

</manifest>
@@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# TODO: Enable this in OSS
@@ -0,0 +1,424 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.vision.handlandmarker;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.truth.Correspondence;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Landmark;
import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarker.HandLandmarkerOptions;
import java.io.InputStream;
import java.util.Arrays;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;

/** Test for {@link HandLandmarker}. */
@RunWith(Suite.class)
@SuiteClasses({HandLandmarkerTest.General.class, HandLandmarkerTest.RunningModeTest.class})
public class HandLandmarkerTest {
  private static final String HAND_LANDMARKER_BUNDLE_ASSET_FILE = "hand_landmarker.task";
  private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
  private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
  private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
  private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
  private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
  private static final String POINTING_UP_ROTATED_LANDMARKS = "pointing_up_rotated_landmarks.pb";
  private static final String TAG = "Hand Landmarker Test";
  private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
  private static final int IMAGE_WIDTH = 382;
  private static final int IMAGE_HEIGHT = 406;

  @RunWith(AndroidJUnit4.class)
  public static final class General extends HandLandmarkerTest {

    @Test
    public void detect_successWithValidModels() throws Exception {
      HandLandmarkerOptions options =
          HandLandmarkerOptions.builder()
              .setBaseOptions(
                  BaseOptions.builder()
                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                      .build())
              .build();
      HandLandmarker handLandmarker =
          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      HandLandmarkerResult actualResult =
          handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE));
      HandLandmarkerResult expectedResult =
          getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
    }

    @Test
    public void detect_successWithEmptyResult() throws Exception {
      HandLandmarkerOptions options =
          HandLandmarkerOptions.builder()
              .setBaseOptions(
                  BaseOptions.builder()
                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                      .build())
              .build();
      HandLandmarker handLandmarker =
          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      HandLandmarkerResult actualResult =
          handLandmarker.detect(getImageFromAsset(NO_HANDS_IMAGE));
      assertThat(actualResult.landmarks()).isEmpty();
      assertThat(actualResult.worldLandmarks()).isEmpty();
      assertThat(actualResult.handednesses()).isEmpty();
    }

    @Test
    public void detect_successWithNumHands() throws Exception {
      HandLandmarkerOptions options =
          HandLandmarkerOptions.builder()
              .setBaseOptions(
                  BaseOptions.builder()
                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                      .build())
              .setNumHands(2)
              .build();
      HandLandmarker handLandmarker =
          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      HandLandmarkerResult actualResult =
          handLandmarker.detect(getImageFromAsset(TWO_HANDS_IMAGE));
      assertThat(actualResult.handednesses()).hasSize(2);
    }

    @Test
    public void recognize_successWithRotation() throws Exception {
      HandLandmarkerOptions options =
          HandLandmarkerOptions.builder()
              .setBaseOptions(
                  BaseOptions.builder()
                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                      .build())
              .setNumHands(1)
              .build();
      HandLandmarker handLandmarker =
          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      ImageProcessingOptions imageProcessingOptions =
          ImageProcessingOptions.builder().setRotationDegrees(-90).build();
      HandLandmarkerResult actualResult =
          handLandmarker.detect(
              getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
      HandLandmarkerResult expectedResult =
          getExpectedHandLandmarkerResult(POINTING_UP_ROTATED_LANDMARKS);
      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
    }

    @Test
    public void recognize_failsWithRegionOfInterest() throws Exception {
      HandLandmarkerOptions options =
          HandLandmarkerOptions.builder()
              .setBaseOptions(
                  BaseOptions.builder()
                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                      .build())
              .setNumHands(1)
              .build();
      HandLandmarker handLandmarker =
          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
      ImageProcessingOptions imageProcessingOptions =
          ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
      IllegalArgumentException exception =
          assertThrows(
              IllegalArgumentException.class,
              () ->
                  handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions));
      assertThat(exception)
          .hasMessageThat()
          .contains("HandLandmarker doesn't support region-of-interest");
    }
  }

  @RunWith(AndroidJUnit4.class)
  public static final class RunningModeTest extends HandLandmarkerTest {
    @Test
    public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
      for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
        IllegalArgumentException exception =
            assertThrows(
                IllegalArgumentException.class,
                () ->
                    HandLandmarkerOptions.builder()
                        .setBaseOptions(
                            BaseOptions.builder()
                                .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                                .build())
                        .setRunningMode(mode)
                        .setResultListener((HandLandmarkerResults, inputImage) -> {})
                        .build());
        assertThat(exception)
            .hasMessageThat()
            .contains("a user-defined result listener shouldn't be provided");
      }
    }
  }

  @Test
  public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
    IllegalArgumentException exception =
        assertThrows(
            IllegalArgumentException.class,
            () ->
                HandLandmarkerOptions.builder()
                    .setBaseOptions(
                        BaseOptions.builder()
                            .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
                            .build())
                    .setRunningMode(RunningMode.LIVE_STREAM)
                    .build());
    assertThat(exception)
        .hasMessageThat()
        .contains("a user-defined result listener must be provided");
  }

  @Test
  public void recognize_failsWithCallingWrongApiInImageMode() throws Exception {
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.IMAGE)
            .build();

    HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () ->
                handLandmarker.detectForVideo(
                    getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
    exception =
        assertThrows(
            MediaPipeException.class,
            () ->
                handLandmarker.detectAsync(getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
  }

  @Test
  public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception {
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.VIDEO)
            .build();

    HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () -> handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE)));
    assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
    exception =
        assertThrows(
            MediaPipeException.class,
            () ->
                handLandmarker.detectAsync(getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
  }

  @Test
  public void recognize_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener((HandLandmarkerResults, inputImage) -> {})
            .build();

    HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () -> handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE)));
    assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
    exception =
        assertThrows(
            MediaPipeException.class,
            () ->
                handLandmarker.detectForVideo(
                    getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ 0));
    assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
  }

  @Test
  public void recognize_successWithImageMode() throws Exception {
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.IMAGE)
            .build();

    HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    HandLandmarkerResult actualResult =
        handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE));
    HandLandmarkerResult expectedResult =
        getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
    assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
  }

  @Test
  public void recognize_successWithVideoMode() throws Exception {
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.VIDEO)
            .build();
    HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
    HandLandmarkerResult expectedResult =
        getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
    for (int i = 0; i < 3; i++) {
      HandLandmarkerResult actualResult =
          handLandmarker.detectForVideo(getImageFromAsset(THUMB_UP_IMAGE), /*timestampMs=*/ i);
      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
    }
  }

  @Test
  public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
    MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
    HandLandmarkerResult expectedResult =
        getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener(
                (actualResult, inputImage) -> {
                  assertActualResultApproximatelyEqualsToExpectedResult(
                      actualResult, expectedResult);
                  assertImageSizeIsExpected(inputImage);
                })
            .build();
    try (HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
      handLandmarker.detectAsync(image, /*timestampMs=*/ 1);
      MediaPipeException exception =
          assertThrows(
              MediaPipeException.class,
              () -> handLandmarker.detectAsync(image, /*timestampMs=*/ 0));
      assertThat(exception)
          .hasMessageThat()
          .contains("having a smaller timestamp than the processed timestamp");
    }
  }

  @Test
  public void recognize_successWithLiveStreamMode() throws Exception {
    MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
    HandLandmarkerResult expectedResult =
        getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
    HandLandmarkerOptions options =
        HandLandmarkerOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
            .setRunningMode(RunningMode.LIVE_STREAM)
            .setResultListener(
                (actualResult, inputImage) -> {
                  assertActualResultApproximatelyEqualsToExpectedResult(
                      actualResult, expectedResult);
                  assertImageSizeIsExpected(inputImage);
                })
            .build();
    try (HandLandmarker handLandmarker =
        HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
      for (int i = 0; i < 3; i++) {
        handLandmarker.detectAsync(image, /*timestampMs=*/ i);
      }
    }
  }

  private static MPImage getImageFromAsset(String filePath) throws Exception {
    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
    InputStream istr = assetManager.open(filePath);
    return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
  }

  private static HandLandmarkerResult getExpectedHandLandmarkerResult(
      String filePath) throws Exception {
    AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
    InputStream istr = assetManager.open(filePath);
    LandmarksDetectionResult landmarksDetectionResultProto =
        LandmarksDetectionResult.parser().parseFrom(istr);
    return HandLandmarkerResult.create(
        Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
        Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
        Arrays.asList(landmarksDetectionResultProto.getClassifications()),
        /*timestampMs=*/ 0);
  }

  private static void assertActualResultApproximatelyEqualsToExpectedResult(
      HandLandmarkerResult actualResult, HandLandmarkerResult expectedResult) {
    // Expects to have the same number of hands detected.
    assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
    assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size());
    assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size());

    // Actual landmarks match expected landmarks.
    assertThat(actualResult.landmarks().get(0))
        .comparingElementsUsing(
            Correspondence.from(
                (Correspondence.BinaryPredicate<Landmark, Landmark>)
                    (actual, expected) -> {
                      return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
                              .compare(actual.x(), expected.x())
                          && Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
                              .compare(actual.y(), expected.y());
                    },
                "landmarks approximately equal to"))
        .containsExactlyElementsIn(expectedResult.landmarks().get(0));

    // Actual handedness matches expected handedness.
    Category actualTopHandedness = actualResult.handednesses().get(0).get(0);
    Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0);
    assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index());
    assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName());
  }

  private static void assertImageSizeIsExpected(MPImage inputImage) {
    assertThat(inputImage).isNotNull();
    assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
    assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
  }
}
@@ -53,7 +53,13 @@ class LandmarksDetectionResult:

  def to_pb2(self) -> _LandmarksDetectionResultProto:
    """Generates a LandmarksDetectionResult protobuf object."""

    landmarks = _NormalizedLandmarkListProto()
    classifications = _ClassificationListProto()
    world_landmarks = _LandmarkListProto()

    for landmark in self.landmarks:
      landmarks.landmark.append(landmark.to_pb2())

    for category in self.categories:
      classifications.classification.append(
          _ClassificationProto(

@@ -63,9 +69,9 @@ class LandmarksDetectionResult:
              display_name=category.display_name))

    return _LandmarksDetectionResultProto(
        landmarks=_NormalizedLandmarkListProto(self.landmarks),
        landmarks=landmarks,
        classifications=classifications,
        world_landmarks=_LandmarkListProto(self.world_landmarks),
        world_landmarks=world_landmarks,
        rect=self.rect.to_pb2())

  @classmethod

@@ -73,9 +79,11 @@ class LandmarksDetectionResult:
  def create_from_pb2(
      cls,
      pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult':
    """Creates a `LandmarksDetectionResult` object from the given protobuf object.
    """
    """Creates a `LandmarksDetectionResult` object from the given protobuf object."""
    categories = []
    landmarks = []
    world_landmarks = []

    for classification in pb2_obj.classifications.classification:
      categories.append(
          category_module.Category(

@@ -83,14 +91,14 @@ class LandmarksDetectionResult:
              index=classification.index,
              category_name=classification.label,
              display_name=classification.display_name))

    for landmark in pb2_obj.landmarks.landmark:
      landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))

    for landmark in pb2_obj.world_landmarks.landmark:
      world_landmarks.append(_Landmark.create_from_pb2(landmark))
    return LandmarksDetectionResult(
        landmarks=[
            _NormalizedLandmark.create_from_pb2(landmark)
            for landmark in pb2_obj.landmarks.landmark
        ],
        landmarks=landmarks,
        categories=categories,
        world_landmarks=[
            _Landmark.create_from_pb2(landmark)
            for landmark in pb2_obj.world_landmarks.landmark
        ],
        world_landmarks=world_landmarks,
        rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
@@ -12,9 +12,9 @@ py_library(
    srcs = [
        "metadata_info.py",
    ],
    srcs_version = "PY3",
    visibility = ["//visibility:public"],
    deps = [
        ":writer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_py",
        "//mediapipe/tasks/metadata:schema_py",
    ],
@ -14,12 +14,14 @@
|
|||
# ==============================================================================
|
||||
"""Helper classes for common model metadata information."""
|
||||
|
||||
import collections
|
||||
import csv
|
||||
import os
|
||||
from typing import List, Optional, Type
|
||||
from typing import List, Optional, Type, Union
|
||||
|
||||
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
|
||||
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
|
||||
from mediapipe.tasks.python.metadata.metadata_writers import writer_utils
|
||||
|
||||
# Min and max values for UINT8 tensors.
|
||||
_MIN_UINT8 = 0
|
||||
|
@ -267,6 +269,86 @@ class RegexTokenizerMd:
|
|||
return tokenizer
|
||||
|
||||
|
||||
class BertTokenizerMd:
|
||||
"""A container for the Bert tokenizer [1] metadata information.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_file_path: str):
|
||||
"""Initializes a BertTokenizerMd object.
|
||||
|
||||
Args:
|
||||
vocab_file_path: path to the vocabulary file.
|
||||
"""
|
||||
self._vocab_file_path = vocab_file_path
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
|
||||
"""Creates the Bert tokenizer metadata based on the information.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the Bert tokenizer metadata.
|
||||
"""
|
||||
vocab = _metadata_fb.AssociatedFileT()
|
||||
vocab.name = self._vocab_file_path
|
||||
vocab.description = _VOCAB_FILE_DESCRIPTION
|
||||
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
|
||||
tokenizer = _metadata_fb.ProcessUnitT()
|
||||
tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
|
||||
tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
|
||||
tokenizer.options.vocabFile = [vocab]
|
||||
return tokenizer
|
||||
|
||||
|
||||
class SentencePieceTokenizerMd:
|
||||
"""A container for the sentence piece tokenizer [1] metadata information.
|
||||
|
||||
[1]:
|
||||
https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
|
||||
"""
|
||||
|
||||
_SP_MODEL_DESCRIPTION = "The sentence piece model file."
|
||||
_SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
|
||||
" This file is optional during tokenization, while the sentence piece "
|
||||
"model is mandatory.")
|
||||
|
||||
def __init__(self,
|
||||
sentence_piece_model_path: str,
|
||||
vocab_file_path: Optional[str] = None):
|
||||
"""Initializes a SentencePieceTokenizerMd object.
|
||||
|
||||
Args:
|
||||
sentence_piece_model_path: path to the sentence piece model file.
|
||||
vocab_file_path: path to the vocabulary file.
|
||||
"""
|
||||
self._sentence_piece_model_path = sentence_piece_model_path
|
||||
self._vocab_file_path = vocab_file_path
|
||||
|
||||
def create_metadata(self) -> _metadata_fb.ProcessUnitT:
|
||||
"""Creates the sentence piece tokenizer metadata based on the information.
|
||||
|
||||
Returns:
|
||||
A Flatbuffers Python object of the sentence piece tokenizer metadata.
|
||||
"""
|
||||
tokenizer = _metadata_fb.ProcessUnitT()
|
||||
tokenizer.optionsType = (
|
||||
_metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
|
||||
tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
|
||||
|
||||
sp_model = _metadata_fb.AssociatedFileT()
|
||||
sp_model.name = self._sentence_piece_model_path
|
||||
sp_model.description = self._SP_MODEL_DESCRIPTION
|
||||
tokenizer.options.sentencePieceModel = [sp_model]
|
||||
if self._vocab_file_path:
|
||||
vocab = _metadata_fb.AssociatedFileT()
|
||||
vocab.name = self._vocab_file_path
|
||||
vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
|
||||
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
|
||||
tokenizer.options.vocabFile = [vocab]
|
||||
return tokenizer
|
||||
|
||||
|
||||
class TensorMd:
|
||||
"""A container for common tensor metadata information.
|
||||
|
||||
|
@ -486,6 +568,145 @@ class InputTextTensorMd(TensorMd):
|
|||
return tensor_metadata
|
||||
|
||||
|
||||
def _get_file_paths(files: List[_metadata_fb.AssociatedFileT]) -> List[str]:
  """Gets file paths from a list of associated files."""
  if not files:
    return []
  return [file.name for file in files]


def _get_tokenizer_associated_files(
    tokenizer_options: Optional[
        Union[_metadata_fb.BertTokenizerOptionsT,
              _metadata_fb.SentencePieceTokenizerOptionsT]]
) -> List[str]:
  """Gets a list of associated files packed in the tokenizer_options.

  Args:
    tokenizer_options: a tokenizer metadata object. Supports the following
      tokenizer types:
      1. BertTokenizerOptions:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
      2. SentencePieceTokenizerOptions:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485

  Returns:
    A list of associated files included in tokenizer_options.
  """
  if not tokenizer_options:
    return []

  if isinstance(tokenizer_options, _metadata_fb.BertTokenizerOptionsT):
    return _get_file_paths(tokenizer_options.vocabFile)
  elif isinstance(tokenizer_options,
                  _metadata_fb.SentencePieceTokenizerOptionsT):
    return _get_file_paths(tokenizer_options.vocabFile) + _get_file_paths(
        tokenizer_options.sentencePieceModel)
  else:
    return []

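A short sketch of what the helper above returns for each supported option type, built from the same generated flatbuffer classes; the file name is a hypothetical example:

# Hedged sketch: exercise _get_tokenizer_associated_files directly.
bert_options = _metadata_fb.BertTokenizerOptionsT()
vocab = _metadata_fb.AssociatedFileT()
vocab.name = "vocab.txt"  # hypothetical path
bert_options.vocabFile = [vocab]
assert _get_tokenizer_associated_files(bert_options) == ["vocab.txt"]
# Missing or unrecognized options yield an empty list.
assert _get_tokenizer_associated_files(None) == []
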
class BertInputTensorsMd:
  """A container for the input tensor metadata information of Bert models."""

  _IDS_NAME = "ids"
  _IDS_DESCRIPTION = "Tokenized ids of the input text."
  _MASK_NAME = "mask"
  _MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
                       "tokens.")
  _SEGMENT_IDS_NAME = "segment_ids"
  _SEGMENT_IDS_DESCRIPTION = (
      "0 for the first sequence, 1 for the second sequence if it exists.")

  def __init__(self,
               model_buffer: bytearray,
               ids_name: str,
               mask_name: str,
               segment_name: str,
               tokenizer_md: Union[None, BertTokenizerMd,
                                   SentencePieceTokenizerMd] = None):
    """Initializes a BertInputTensorsMd object.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata.

    Args:
      model_buffer: valid buffer of the model file.
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with `1`
        for real tokens and `0` for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if it exists.
      tokenizer_md: information of the tokenizer used to process the input
        string, if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
        refer to `InputTensorsMd`.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
      [3]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
    """
    # Verify that tflite_input_names (read from the model) and
    # input_name (collected from users) are aligned.
    tflite_input_names = writer_utils.get_input_tensor_names(model_buffer)
    input_names = [ids_name, mask_name, segment_name]
    if collections.Counter(tflite_input_names) != collections.Counter(
        input_names):
      raise ValueError(
          f"The input tensor names ({input_names}) do not match the tensor "
          f"names read from the model ({tflite_input_names}).")

    ids_md = TensorMd(
        name=self._IDS_NAME,
        description=self._IDS_DESCRIPTION,
        tensor_name=ids_name)

    mask_md = TensorMd(
        name=self._MASK_NAME,
        description=self._MASK_DESCRIPTION,
        tensor_name=mask_name)

    segment_ids_md = TensorMd(
        name=self._SEGMENT_IDS_NAME,
        description=self._SEGMENT_IDS_DESCRIPTION,
        tensor_name=segment_name)

    self._input_md = [ids_md, mask_md, segment_ids_md]

    if not isinstance(tokenizer_md,
                      (type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
      raise ValueError(
          f"The type of tokenizer_md, {type(tokenizer_md)}, is unsupported")

    self._tokenizer_md = tokenizer_md

  def create_input_process_unit_metadata(
      self) -> List[_metadata_fb.ProcessUnitT]:
    """Creates the input process unit metadata."""
    if self._tokenizer_md:
      return [self._tokenizer_md.create_metadata()]
    else:
      return []

  def get_tokenizer_associated_files(self) -> List[str]:
    """Gets the associated files that are packed in the tokenizer."""
    if self._tokenizer_md:
      return _get_tokenizer_associated_files(
          self._tokenizer_md.create_metadata().options)
    else:
      return []

  @property
  def input_md(self) -> List[TensorMd]:
    return self._input_md

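The `collections.Counter` comparison in the constructor above checks that the user-supplied names and the names read from the model agree as multisets, so ordering differences are tolerated. A standalone sketch of the same check, with hypothetical tensor names:

import collections

# Hypothetical names: one list as read from the model, one as supplied.
tflite_input_names = ["ids", "segment_ids", "mask"]
input_names = ["ids", "mask", "segment_ids"]
# Order differs, but the multisets match, so no ValueError would be raised.
assert collections.Counter(tflite_input_names) == collections.Counter(input_names)
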
class ClassificationTensorMd(TensorMd):
  """A container for the classification tensor metadata information.

@ -19,7 +19,7 @@ import csv
import dataclasses
import os
import tempfile
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import flatbuffers
from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb

@ -101,6 +101,34 @@ class RegexTokenizer:
  vocab_file_path: str


@dataclasses.dataclass
class BertTokenizer:
  """Parameters of the Bert tokenizer [1] metadata information.

  [1]:
    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477

  Attributes:
    vocab_file_path: path to the vocabulary file.
  """
  vocab_file_path: str


@dataclasses.dataclass
class SentencePieceTokenizer:
  """Parameters of the sentence piece tokenizer [1] metadata information.

  [1]:
    https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485

  Attributes:
    sentence_piece_model_path: path to the sentence piece model file.
    vocab_file_path: path to the vocabulary file.
  """
  sentence_piece_model_path: str
  vocab_file_path: Optional[str] = None


class Labels(object):
  """Simple container holding classification labels of a particular tensor.

@ -282,7 +310,9 @@ def _create_metadata_buffer(
    model_buffer: bytearray,
    general_md: Optional[metadata_info.GeneralMd] = None,
    input_md: Optional[List[metadata_info.TensorMd]] = None,
    output_md: Optional[List[metadata_info.TensorMd]] = None) -> bytearray:
    output_md: Optional[List[metadata_info.TensorMd]] = None,
    input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None
) -> bytearray:
  """Creates a buffer of the metadata.

  Args:
@ -290,7 +320,9 @@ def _create_metadata_buffer(
    general_md: general information about the model.
    input_md: metadata information of the input tensors.
    output_md: metadata information of the output tensors.
    input_process_units: a list of metadata of the input process units [1].
    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655

  Returns:
    A buffer of the metadata.

@ -325,6 +357,8 @@ def _create_metadata_buffer(
  subgraph_metadata = metadata_fb.SubGraphMetadataT()
  subgraph_metadata.inputTensorMetadata = input_metadata
  subgraph_metadata.outputTensorMetadata = output_metadata
  if input_process_units:
    subgraph_metadata.inputProcessUnits = input_process_units

  # Create the whole model metadata.
  if general_md is None:
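A sketch of how the new `input_process_units` parameter might be threaded through a call, assuming a tokenizer metadata object built as in metadata_info; all variables here are hypothetical placeholders:

# Hedged sketch: model_buffer, input_mds, output_mds and tokenizer_md are
# hypothetical; tokenizer_md would be a BertTokenizerMd or
# SentencePieceTokenizerMd instance.
metadata_buffer = _create_metadata_buffer(
    model_buffer=model_buffer,
    input_md=input_mds,
    output_md=output_mds,
    input_process_units=[tokenizer_md.create_metadata()])
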
@ -366,6 +400,7 @@ class MetadataWriter(object):
    self._model_buffer = model_buffer
    self._general_md = None
    self._input_mds = []
    self._input_process_units = []
    self._output_mds = []
    self._associated_files = []
    self._temp_folder = tempfile.TemporaryDirectory()
@ -416,7 +451,7 @@ class MetadataWriter(object):
      description: Description of the input tensor.

    Returns:
      The MetaWriter instance, can be used for chained operation.
      The MetadataWriter instance, can be used for chained operation.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
@ -448,7 +483,7 @@ class MetadataWriter(object):
      description: Description of the input tensor.

    Returns:
      The MetaWriter instance, can be used for chained operation.
      The MetadataWriter instance, can be used for chained operation.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
@ -462,6 +497,63 @@ class MetadataWriter(object):
    self._associated_files.append(regex_tokenizer.vocab_file_path)
    return self

  def add_bert_text_input(self, tokenizer: Union[BertTokenizer,
                                                 SentencePieceTokenizer],
                          ids_name: str, mask_name: str,
                          segment_name: str) -> 'MetadataWriter':
    """Adds metadata for the text input with a Bert / SentencePiece tokenizer.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata.

    Args:
      tokenizer: information of the tokenizer used to process the input string,
        if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2].
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with `1`
        for real tokens and `0` for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if it exists.

    Returns:
      The MetadataWriter instance, can be used for chained operation.

    Raises:
      ValueError: if the type of tokenizer is not BertTokenizer or
        SentencePieceTokenizer.

    [1]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
    [2]:
      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
    """
    if isinstance(tokenizer, BertTokenizer):
      tokenizer_md = metadata_info.BertTokenizerMd(
          vocab_file_path=tokenizer.vocab_file_path)
    elif isinstance(tokenizer, SentencePieceTokenizer):
      tokenizer_md = metadata_info.SentencePieceTokenizerMd(
          sentence_piece_model_path=tokenizer.sentence_piece_model_path,
          vocab_file_path=tokenizer.vocab_file_path)
    else:
      raise ValueError(
          f'The type of tokenizer, {type(tokenizer)}, is unsupported')
    bert_input_md = metadata_info.BertInputTensorsMd(
        self._model_buffer,
        ids_name,
        mask_name,
        segment_name,
        tokenizer_md=tokenizer_md)

    self._input_mds.extend(bert_input_md.input_md)
    self._associated_files.extend(
        bert_input_md.get_tokenizer_associated_files())
    self._input_process_units.extend(
        bert_input_md.create_input_process_unit_metadata())
    return self

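A hedged usage sketch of the chained-writer call above; the model file is hypothetical, and the tensor names must match the model's actual `Tensor.name` entries (the values shown are the Model Maker defaults used elsewhere in this change):

# Hedged sketch: "bert_classifier.tflite" and "vocab.txt" are hypothetical.
with open("bert_classifier.tflite", "rb") as f:
  model_buffer = bytearray(f.read())
writer = MetadataWriter(model_buffer)
writer.add_bert_text_input(
    tokenizer=BertTokenizer(vocab_file_path="vocab.txt"),
    ids_name="serving_default_input_word_ids:0",
    mask_name="serving_default_input_mask:0",
    segment_name="serving_default_input_type_ids:0")
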
  def add_classification_output(
      self,
      labels: Optional[Labels] = None,
@ -546,7 +638,8 @@ class MetadataWriter(object):
        model_buffer=self._model_buffer,
        general_md=self._general_md,
        input_md=self._input_mds,
        output_md=self._output_mds)
        output_md=self._output_mds,
        input_process_units=self._input_process_units)
    populator.load_metadata_buffer(metadata_buffer)
    if self._associated_files:
      populator.load_associated_files(self._associated_files)

@ -14,11 +14,18 @@
# ==============================================================================
"""Writes metadata and label file to the Text classifier models."""

from typing import Union

from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

_MODEL_NAME = "TextClassifier"
_MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")

# The input tensor names of models created by Model Maker.
_DEFAULT_ID_NAME = "serving_default_input_word_ids:0"
_DEFAULT_MASK_NAME = "serving_default_input_mask:0"
_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0"


class MetadataWriter(metadata_writer.MetadataWriterBase):
  """MetadataWriter to write the metadata into the text classifier."""

@ -62,3 +69,51 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
    writer.add_regex_text_input(regex_tokenizer)
    writer.add_classification_output(labels)
    return cls(writer)

  @classmethod
  def create_for_bert_model(
      cls,
      model_buffer: bytearray,
      tokenizer: Union[metadata_writer.BertTokenizer,
                       metadata_writer.SentencePieceTokenizer],
      labels: metadata_writer.Labels,
      ids_name: str = _DEFAULT_ID_NAME,
      mask_name: str = _DEFAULT_MASK_NAME,
      segment_name: str = _DEFAULT_SEGMENT_ID_NAME,
  ) -> "MetadataWriter":
    """Creates MetadataWriter for models with {Bert/SentencePiece}Tokenizer.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata. The default values come from Model Maker.

    Args:
      model_buffer: valid buffer of the model file.
      tokenizer: information of the tokenizer used to process the input string,
        if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
        refer to `create_for_regex_model`.
      labels: an instance of Labels helper class used in the output
        classification tensor [4].
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with `1`
        for real tokens and `0` for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if it exists.
      [1]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
      [2]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
      [3]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
      [4]:
        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99

    Returns:
      A MetadataWriter object.
    """
    writer = metadata_writer.MetadataWriter(model_buffer)
    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
    writer.add_bert_text_input(tokenizer, ids_name, mask_name, segment_name)
    writer.add_classification_output(labels)
    return cls(writer)

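A sketch of driving this factory end to end, mirroring the defaults above; the file paths are hypothetical:

# Hedged sketch: paths are hypothetical; ids/mask/segment names fall back
# to the Model Maker defaults declared at the top of this file.
with open("bert_text_classifier.tflite", "rb") as f:
  model_buffer = bytearray(f.read())
writer = MetadataWriter.create_for_bert_model(
    model_buffer,
    tokenizer=metadata_writer.BertTokenizer("vocab.txt"),
    labels=metadata_writer.Labels().add_from_file("labels.txt"))
tflite_with_metadata, metadata_json = writer.populate()
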
@ -367,6 +367,42 @@ class ScoreThresholdingMdTest(absltest.TestCase):
    self.assertEqual(metadata_json, expected_json)


class BertTokenizerMdTest(absltest.TestCase):

  _VOCAB_FILE = "vocab.txt"
  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "bert_tokenizer_meta.json"))

  def test_create_metadata_should_succeed(self):
    tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
    tokenizer_metadata = tokenizer_md.create_metadata()

    metadata_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)


class SentencePieceTokenizerMdTest(absltest.TestCase):

  _VOCAB_FILE = "vocab.txt"
  _SP_MODEL = "sp.model"
  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
      os.path.join(_TEST_DATA_DIR, "sentence_piece_tokenizer_meta.json"))

  def test_create_metadata_should_succeed(self):
    tokenizer_md = metadata_info.SentencePieceTokenizerMd(
        self._SP_MODEL, self._VOCAB_FILE)
    tokenizer_metadata = tokenizer_md.create_metadata()

    metadata_json = _metadata.convert_to_json(
        _create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)


def _create_dummy_model_metadata_with_tensor(
    tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
  # Create a dummy model using the tensor metadata.

@ -21,28 +21,64 @@ from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
from mediapipe.tasks.python.test import test_utils

_TEST_DIR = "mediapipe/tasks/testdata/metadata/"
_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_REGEX_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
_LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
                                            "movie_review_labels.txt")
_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_REGEX_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
_REGEX_JSON_FILE = test_utils.get_test_data_path("movie_review.json")

_BERT_MODEL = test_utils.get_test_data_path(
    _TEST_DIR + "bert_text_classifier_no_metadata.tflite")
_BERT_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR +
                                                 "mobilebert_vocab.txt")
_SP_MODEL_FILE = test_utils.get_test_data_path(_TEST_DIR + "30k-clean.model")
_BERT_JSON_FILE = test_utils.get_test_data_path(
    _TEST_DIR + "bert_text_classifier_with_bert_tokenizer.json")
_SENTENCE_PIECE_JSON_FILE = test_utils.get_test_data_path(
    _TEST_DIR + "bert_text_classifier_with_sentence_piece.json")


class TextClassifierTest(absltest.TestCase):

  def test_write_metadata(self,):
    with open(_MODEL, "rb") as f:
  def test_write_metadata_for_regex_model(self):
    with open(_REGEX_MODEL, "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_regex_model(
        model_buffer,
        regex_tokenizer=metadata_writer.RegexTokenizer(
            delim_regex_pattern=_DELIM_REGEX_PATTERN,
            vocab_file_path=_VOCAB_FILE),
            vocab_file_path=_REGEX_VOCAB_FILE),
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
    _, metadata_json = writer.populate()

    with open(_JSON_FILE, "r") as f:
    with open(_REGEX_JSON_FILE, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)

  def test_write_metadata_for_bert(self):
    with open(_BERT_MODEL, "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_bert_model(
        model_buffer,
        tokenizer=metadata_writer.BertTokenizer(_BERT_VOCAB_FILE),
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
    _, metadata_json = writer.populate()

    with open(_BERT_JSON_FILE, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)

  def test_write_metadata_for_sentence_piece(self):
    with open(_BERT_MODEL, "rb") as f:
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_bert_model(
        model_buffer,
        tokenizer=metadata_writer.SentencePieceTokenizer(_SP_MODEL_FILE),
        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
    _, metadata_json = writer.populate()

    with open(_SENTENCE_PIECE_JSON_FILE, "r") as f:
      expected_json = f.read()
    self.assertEqual(metadata_json, expected_json)

@ -94,3 +94,26 @@ py_test(
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)

py_test(
    name = "hand_landmarker_test",
    srcs = ["hand_landmarker_test.py"],
    data = [
        "//mediapipe/tasks/testdata/vision:test_images",
        "//mediapipe/tasks/testdata/vision:test_models",
        "//mediapipe/tasks/testdata/vision:test_protos",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
        "//mediapipe/tasks/python/components/containers:landmark",
        "//mediapipe/tasks/python/components/containers:landmark_detection_result",
        "//mediapipe/tasks/python/components/containers:rect",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/test:test_utils",
        "//mediapipe/tasks/python/vision:hand_landmarker",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
        "@com_google_protobuf//:protobuf_python",
    ],
)

mediapipe/tasks/python/test/vision/hand_landmarker_test.py (new file, 428 lines)
@ -0,0 +1,428 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for hand landmarker."""

import enum
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import hand_landmarker
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_BaseOptions = base_options_module.BaseOptions
_Rect = rect_module.Rect
_Landmark = landmark_module.Landmark
_NormalizedLandmark = landmark_module.NormalizedLandmark
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_Image = image_module.Image
_HandLandmarker = hand_landmarker.HandLandmarker
_HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
_HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_HAND_LANDMARKER_BUNDLE_ASSET_FILE = 'hand_landmarker.task'
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
_TWO_HANDS_IMAGE = 'right_hands.jpg'
_THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_POINTING_UP_IMAGE = 'pointing_up.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt'
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_LANDMARKS_ERROR_TOLERANCE = 0.03
_HANDEDNESS_MARGIN = 0.05

def _get_expected_hand_landmarker_result(
    file_path: str) -> _HandLandmarkerResult:
  landmarks_detection_result_file_path = test_utils.get_test_data_path(
      file_path)
  with open(landmarks_detection_result_file_path, 'rb') as f:
    landmarks_detection_result_proto = _LandmarksDetectionResultProto()
    # Use this if a .pb file is available.
    # landmarks_detection_result_proto.ParseFromString(f.read())
    text_format.Parse(f.read(), landmarks_detection_result_proto)
    landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
        landmarks_detection_result_proto)
  return _HandLandmarkerResult(
      handedness=[landmarks_detection_result.categories],
      hand_landmarks=[landmarks_detection_result.landmarks],
      hand_world_landmarks=[landmarks_detection_result.world_landmarks])

class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class HandLandmarkerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(_THUMB_UP_IMAGE))
    self.model_path = test_utils.get_test_data_path(
        _HAND_LANDMARKER_BUNDLE_ASSET_FILE)

  def _assert_actual_result_approximately_matches_expected_result(
      self, actual_result: _HandLandmarkerResult,
      expected_result: _HandLandmarkerResult):
    # Expects to have the same number of hands detected.
    self.assertLen(actual_result.hand_landmarks,
                   len(expected_result.hand_landmarks))
    self.assertLen(actual_result.hand_world_landmarks,
                   len(expected_result.hand_world_landmarks))
    self.assertLen(actual_result.handedness, len(expected_result.handedness))
    # Actual landmarks match expected landmarks.
    self.assertLen(actual_result.hand_landmarks[0],
                   len(expected_result.hand_landmarks[0]))
    actual_landmarks = actual_result.hand_landmarks[0]
    expected_landmarks = expected_result.hand_landmarks[0]
    for i, actual_landmark in enumerate(actual_landmarks):
      self.assertAlmostEqual(
          actual_landmark.x,
          expected_landmarks[i].x,
          delta=_LANDMARKS_ERROR_TOLERANCE)
      self.assertAlmostEqual(
          actual_landmark.y,
          expected_landmarks[i].y,
          delta=_LANDMARKS_ERROR_TOLERANCE)
    # Actual handedness matches expected handedness.
    actual_top_handedness = actual_result.handedness[0][0]
    expected_top_handedness = expected_result.handedness[0][0]
    self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
    self.assertEqual(actual_top_handedness.category_name,
                     expected_top_handedness.category_name)
    self.assertAlmostEqual(
        actual_top_handedness.score,
        expected_top_handedness.score,
        delta=_HANDEDNESS_MARGIN)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _HandLandmarker.create_from_model_path(self.model_path) as landmarker:
      self.assertIsInstance(landmarker, _HandLandmarker)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      self.assertIsInstance(landmarker, _HandLandmarker)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite')
      options = _HandLandmarkerOptions(base_options=base_options)
      _HandLandmarker.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _HandLandmarkerOptions(base_options=base_options)
      landmarker = _HandLandmarker.create_from_options(options)
      self.assertIsInstance(landmarker, _HandLandmarker)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (ModelFileType.FILE_CONTENT,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
  def test_detect(self, model_file_type, expected_detection_result):
    # Creates hand landmarker.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _HandLandmarkerOptions(base_options=base_options)
    landmarker = _HandLandmarker.create_from_options(options)

    # Performs hand landmarks detection on the input.
    detection_result = landmarker.detect(self.test_image)
    # Comparing results.
    self._assert_actual_result_approximately_matches_expected_result(
        detection_result, expected_detection_result)
    # Closes the hand landmarker explicitly when the hand landmarker is not used
    # in a context.
    landmarker.close()

  @parameterized.parameters(
      (ModelFileType.FILE_NAME,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (ModelFileType.FILE_CONTENT,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
  def test_detect_in_context(self, model_file_type, expected_detection_result):
    # Creates hand landmarker.
    if model_file_type is ModelFileType.FILE_NAME:
      base_options = _BaseOptions(model_asset_path=self.model_path)
    elif model_file_type is ModelFileType.FILE_CONTENT:
      with open(self.model_path, 'rb') as f:
        model_content = f.read()
      base_options = _BaseOptions(model_asset_buffer=model_content)
    else:
      # Should never happen
      raise ValueError('model_file_type is invalid.')

    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(self.test_image)
      # Comparing results.
      self._assert_actual_result_approximately_matches_expected_result(
          detection_result, expected_detection_result)

  def test_detect_succeeds_with_num_hands(self):
    # Creates hand landmarker.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options, num_hands=2)
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Load the two hands image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_TWO_HANDS_IMAGE))
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(test_image)
      # Comparing results.
      self.assertLen(detection_result.handedness, 2)

  def test_detect_succeeds_with_rotation(self):
    # Creates hand landmarker.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Load the pointing up rotated image.
      test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_POINTING_UP_ROTATED_IMAGE))
      # Set rotation parameters using ImageProcessingOptions.
      image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(test_image, image_processing_options)
      expected_detection_result = _get_expected_hand_landmarker_result(
          _POINTING_UP_ROTATED_LANDMARKS)
      # Comparing results.
      self._assert_actual_result_approximately_matches_expected_result(
          detection_result, expected_detection_result)

  def test_detect_fails_with_region_of_interest(self):
    # Creates hand landmarker.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with self.assertRaisesRegex(
        ValueError, "This task doesn't support region-of-interest."):
      with _HandLandmarker.create_from_options(options) as landmarker:
        # Set the `region_of_interest` parameter using `ImageProcessingOptions`.
        image_processing_options = _ImageProcessingOptions(
            region_of_interest=_Rect(0, 0, 1, 1))
        # Attempt to perform hand landmarks detection on the cropped input.
        landmarker.detect(self.test_image, image_processing_options)

  def test_empty_detection_outputs(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path))
    with _HandLandmarker.create_from_options(options) as landmarker:
      # Load the image with no hands.
      no_hands_test_image = _Image.create_from_file(
          test_utils.get_test_data_path(_NO_HANDS_IMAGE))
      # Performs hand landmarks detection on the input.
      detection_result = landmarker.detect(no_hands_test_image)
      self.assertEmpty(detection_result.hand_landmarks)
      self.assertEmpty(detection_result.hand_world_landmarks)
      self.assertEmpty(detection_result.handedness)

  def test_missing_result_callback(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM)
    with self.assertRaisesRegex(ValueError,
                                r'result callback must be provided'):
      with _HandLandmarker.create_from_options(options) as unused_landmarker:
        pass

  @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
  def test_illegal_result_callback(self, running_mode):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=running_mode,
        result_callback=mock.MagicMock())
    with self.assertRaisesRegex(ValueError,
                                r'result callback should not be provided'):
      with _HandLandmarker.create_from_options(options) as unused_landmarker:
        pass

  def test_calling_detect_for_video_in_image_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        landmarker.detect_for_video(self.test_image, 0)

  def test_calling_detect_async_in_image_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.IMAGE)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        landmarker.detect_async(self.test_image, 0)

  def test_calling_detect_in_video_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        landmarker.detect(self.test_image)

  def test_calling_detect_async_in_video_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the live stream mode'):
        landmarker.detect_async(self.test_image, 0)

  def test_detect_for_video_with_out_of_order_timestamp(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      unused_result = landmarker.detect_for_video(self.test_image, 1)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        landmarker.detect_for_video(self.test_image, 0)

  @parameterized.parameters(
      (_THUMB_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (_POINTING_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
      (_POINTING_UP_ROTATED_IMAGE, -90,
       _get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
      (_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
  def test_detect_for_video(self, image_path, rotation, expected_result):
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(image_path))
    # Set rotation parameters using ImageProcessingOptions.
    image_processing_options = _ImageProcessingOptions(
        rotation_degrees=rotation)
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.VIDEO)
    with _HandLandmarker.create_from_options(options) as landmarker:
      for timestamp in range(0, 300, 30):
        result = landmarker.detect_for_video(test_image, timestamp,
                                             image_processing_options)
        if result.hand_landmarks and result.hand_world_landmarks and result.handedness:
          self._assert_actual_result_approximately_matches_expected_result(
              result, expected_result)
        else:
          self.assertEqual(result, expected_result)

  def test_calling_detect_in_live_stream_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the image mode'):
        landmarker.detect(self.test_image)

  def test_calling_detect_for_video_in_live_stream_mode(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _HandLandmarker.create_from_options(options) as landmarker:
      with self.assertRaisesRegex(ValueError,
                                  r'not initialized with the video mode'):
        landmarker.detect_for_video(self.test_image, 0)

  def test_detect_async_calls_with_illegal_timestamp(self):
    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=mock.MagicMock())
    with _HandLandmarker.create_from_options(options) as landmarker:
      landmarker.detect_async(self.test_image, 100)
      with self.assertRaisesRegex(
          ValueError, r'Input timestamp must be monotonically increasing'):
        landmarker.detect_async(self.test_image, 0)

  @parameterized.parameters(
      (_THUMB_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (_POINTING_UP_IMAGE, 0,
       _get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
      (_POINTING_UP_ROTATED_IMAGE, -90,
       _get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
      (_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
  def test_detect_async_calls(self, image_path, rotation, expected_result):
    test_image = _Image.create_from_file(
        test_utils.get_test_data_path(image_path))
    # Set rotation parameters using ImageProcessingOptions.
    image_processing_options = _ImageProcessingOptions(
        rotation_degrees=rotation)
    # Track the last timestamp seen by the callback so we can assert that
    # results arrive in monotonically increasing timestamp order.
    self.observed_timestamp_ms = -1

    def check_result(result: _HandLandmarkerResult, output_image: _Image,
                     timestamp_ms: int):
      if result.hand_landmarks and result.hand_world_landmarks and result.handedness:
        self._assert_actual_result_approximately_matches_expected_result(
            result, expected_result)
      else:
        self.assertEqual(result, expected_result)
      self.assertTrue(
          np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
      self.assertLess(self.observed_timestamp_ms, timestamp_ms)
      self.observed_timestamp_ms = timestamp_ms

    options = _HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path=self.model_path),
        running_mode=_RUNNING_MODE.LIVE_STREAM,
        result_callback=check_result)
    with _HandLandmarker.create_from_options(options) as landmarker:
      for timestamp in range(0, 300, 30):
        landmarker.detect_async(test_image, timestamp, image_processing_options)


if __name__ == '__main__':
  absltest.main()

@ -79,6 +79,28 @@ py_library(
    ],
)

py_library(
    name = "image_embedder",
    srcs = [
        "image_embedder.py",
    ],
    deps = [
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
        "//mediapipe/tasks/python/components/containers:embedding_result",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
        "//mediapipe/tasks/python/vision/core:image_processing_options",
        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
    ],
)

py_library(
    name = "gesture_recognizer",
    srcs = [
@ -104,18 +126,19 @@ py_library(
)

py_library(
    name = "image_embedder",
    name = "hand_landmarker",
    srcs = [
        "image_embedder.py",
        "hand_landmarker.py",
    ],
    deps = [
        "//mediapipe/framework/formats:classification_py_pb2",
        "//mediapipe/framework/formats:landmark_py_pb2",
        "//mediapipe/python:_framework_bindings",
        "//mediapipe/python:packet_creator",
        "//mediapipe/python:packet_getter",
        "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
        "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
        "//mediapipe/tasks/python/components/containers:embedding_result",
        "//mediapipe/tasks/python/components/processors:classifier_options",
        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
        "//mediapipe/tasks/python/components/containers:category",
        "//mediapipe/tasks/python/components/containers:landmark",
        "//mediapipe/tasks/python/core:base_options",
        "//mediapipe/tasks/python/core:optional_dependencies",
        "//mediapipe/tasks/python/core:task_info",
@ -16,12 +16,17 @@
import mediapipe.tasks.python.vision.core
import mediapipe.tasks.python.vision.gesture_recognizer
import mediapipe.tasks.python.vision.hand_landmarker
import mediapipe.tasks.python.vision.image_classifier
import mediapipe.tasks.python.vision.image_segmenter
import mediapipe.tasks.python.vision.object_detector

GestureRecognizer = gesture_recognizer.GestureRecognizer
GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
HandLandmarker = hand_landmarker.HandLandmarker
HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
ImageClassifier = image_classifier.ImageClassifier
ImageClassifierOptions = image_classifier.ImageClassifierOptions
ImageSegmenter = image_segmenter.ImageSegmenter
@ -33,6 +38,7 @@ RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
# Remove unnecessary modules to avoid duplication in API docs.
del core
del gesture_recognizer
del hand_landmarker
del image_classifier
del image_segmenter
del object_detector

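Given the aliases exported above, a minimal sketch of how a caller reaches the new task through the public package; the model bundle path is hypothetical:

# Hedged sketch: 'hand_landmarker.task' is a hypothetical model bundle path.
from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core import base_options as base_options_module

options = vision.HandLandmarkerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='hand_landmarker.task'),
    running_mode=vision.RunningMode.IMAGE)
landmarker = vision.HandLandmarker.create_from_options(options)
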
@ -59,7 +59,7 @@ _GESTURE_DEFAULT_INDEX = -1

@dataclasses.dataclass
class GestureRecognitionResult:
class GestureRecognizerResult:
  """The gesture recognition result from GestureRecognizer, where each vector element represents a single hand detected in the image.

  Attributes:
@ -79,8 +79,8 @@ class GestureRecognitionResult:

def _build_recognition_result(
    output_packets: Mapping[str,
                            packet_module.Packet]) -> GestureRecognitionResult:
  """Constructs a `GestureRecognitionResult` from output packets."""
                            packet_module.Packet]) -> GestureRecognizerResult:
  """Constructs a `GestureRecognizerResult` from output packets."""
  gestures_proto_list = packet_getter.get_proto_list(
      output_packets[_HAND_GESTURE_STREAM_NAME])
  handedness_proto_list = packet_getter.get_proto_list(
@ -122,23 +122,25 @@ def _build_recognition_result(
  for proto in hand_landmarks_proto_list:
    hand_landmarks = landmark_pb2.NormalizedLandmarkList()
    hand_landmarks.MergeFrom(proto)
    hand_landmarks_results.append([
        landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
        for hand_landmark in hand_landmarks.landmark
    ])
    hand_landmarks_list = []
    for hand_landmark in hand_landmarks.landmark:
      hand_landmarks_list.append(
          landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
    hand_landmarks_results.append(hand_landmarks_list)

  hand_world_landmarks_results = []
  for proto in hand_world_landmarks_proto_list:
    hand_world_landmarks = landmark_pb2.LandmarkList()
    hand_world_landmarks.MergeFrom(proto)
    hand_world_landmarks_results.append([
        landmark_module.Landmark.create_from_pb2(hand_world_landmark)
        for hand_world_landmark in hand_world_landmarks.landmark
    ])
    hand_world_landmarks_list = []
    for hand_world_landmark in hand_world_landmarks.landmark:
      hand_world_landmarks_list.append(
          landmark_module.Landmark.create_from_pb2(hand_world_landmark))
    hand_world_landmarks_results.append(hand_world_landmarks_list)

  return GestureRecognitionResult(gesture_results, handedness_results,
                                  hand_landmarks_results,
                                  hand_world_landmarks_results)
  return GestureRecognizerResult(gesture_results, handedness_results,
                                 hand_landmarks_results,
                                 hand_world_landmarks_results)


@dataclasses.dataclass
@ -183,7 +185,7 @@ class GestureRecognizerOptions:
  custom_gesture_classifier_options: Optional[
      _ClassifierOptions] = _ClassifierOptions()
  result_callback: Optional[Callable[
      [GestureRecognitionResult, image_module.Image, int], None]] = None
      [GestureRecognizerResult, image_module.Image, int], None]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
@ -264,7 +266,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
      if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
        empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
        options.result_callback(
            GestureRecognitionResult([], [], [], []), image,
            GestureRecognizerResult([], [], [], []), image,
            empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
        return

@ -299,7 +301,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> GestureRecognitionResult:
  ) -> GestureRecognizerResult:
    """Performs hand gesture recognition on the given image.

    Only use this method when the GestureRecognizer is created with the image
@ -330,7 +332,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
    })

    if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
      return GestureRecognitionResult([], [], [], [])
      return GestureRecognizerResult([], [], [], [])

    return _build_recognition_result(output_packets)

@ -339,7 +341,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> GestureRecognitionResult:
  ) -> GestureRecognizerResult:
    """Performs gesture recognition on the provided video frame.

    Only use this method when the GestureRecognizer is created with the video
@ -374,7 +376,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
    })

    if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
      return GestureRecognitionResult([], [], [], [])
      return GestureRecognizerResult([], [], [], [])

    return _build_recognition_result(output_packets)

mediapipe/tasks/python/vision/hand_landmarker.py (new file, 379 lines)
@ -0,0 +1,379 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe hand landmarker task."""

import dataclasses
from typing import Callable, Mapping, Optional, List

from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_BaseOptions = base_options_module.BaseOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_HANDEDNESS_STREAM_NAME = 'handedness'
_HANDEDNESS_TAG = 'HANDEDNESS'
_HAND_LANDMARKS_STREAM_NAME = 'landmarks'
_HAND_LANDMARKS_TAG = 'LANDMARKS'
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000

@dataclasses.dataclass
class HandLandmarkerResult:
  """The hand landmarks result from HandLandmarker, where each vector element represents a single hand detected in the image.

  Attributes:
    handedness: Classification of handedness.
    hand_landmarks: Detected hand landmarks in normalized image coordinates.
    hand_world_landmarks: Detected hand landmarks in world coordinates.
  """

  handedness: List[List[category_module.Category]]
  hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
  hand_world_landmarks: List[List[landmark_module.Landmark]]


def _build_landmarker_result(
    output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult:
  """Constructs a `HandLandmarkerResult` from output packets."""
  handedness_proto_list = packet_getter.get_proto_list(
      output_packets[_HANDEDNESS_STREAM_NAME])
  hand_landmarks_proto_list = packet_getter.get_proto_list(
      output_packets[_HAND_LANDMARKS_STREAM_NAME])
  hand_world_landmarks_proto_list = packet_getter.get_proto_list(
      output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])

  handedness_results = []
  for proto in handedness_proto_list:
    handedness_categories = []
    handedness_classifications = classification_pb2.ClassificationList()
    handedness_classifications.MergeFrom(proto)
    for handedness in handedness_classifications.classification:
      handedness_categories.append(
          category_module.Category(
              index=handedness.index,
              score=handedness.score,
              display_name=handedness.display_name,
              category_name=handedness.label))
    handedness_results.append(handedness_categories)

  hand_landmarks_results = []
  for proto in hand_landmarks_proto_list:
    hand_landmarks = landmark_pb2.NormalizedLandmarkList()
    hand_landmarks.MergeFrom(proto)
    hand_landmarks_list = []
    for hand_landmark in hand_landmarks.landmark:
      hand_landmarks_list.append(
          landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
    hand_landmarks_results.append(hand_landmarks_list)

  hand_world_landmarks_results = []
  for proto in hand_world_landmarks_proto_list:
    hand_world_landmarks = landmark_pb2.LandmarkList()
    hand_world_landmarks.MergeFrom(proto)
    hand_world_landmarks_list = []
    for hand_world_landmark in hand_world_landmarks.landmark:
      hand_world_landmarks_list.append(
          landmark_module.Landmark.create_from_pb2(hand_world_landmark))
    hand_world_landmarks_results.append(hand_world_landmarks_list)

  return HandLandmarkerResult(handedness_results, hand_landmarks_results,
                              hand_world_landmarks_results)

@dataclasses.dataclass
class HandLandmarkerOptions:
  """Options for the hand landmarker task.

  Attributes:
    base_options: Base options for the hand landmarker task.
    running_mode: The running mode of the task. Defaults to the image mode.
      HandLandmarker has three running modes: 1) The image mode for detecting
      hand landmarks on single image inputs. 2) The video mode for detecting
      hand landmarks on the decoded frames of a video. 3) The live stream mode
      for detecting hand landmarks on a live stream of input data, such as
      from a camera. In this mode, the "result_callback" below must be
      specified to receive the detection results asynchronously.
    num_hands: The maximum number of hands that can be detected by the hand
      landmarker.
    min_hand_detection_confidence: The minimum confidence score for the hand
      detection to be considered successful.
    min_hand_presence_confidence: The minimum confidence score of hand
      presence in the hand landmark detection.
    min_tracking_confidence: The minimum confidence score for the hand
      tracking to be considered successful.
    result_callback: The user-defined result callback for processing live
      stream data. The result callback should only be specified when the
      running mode is set to the live stream mode.
  """
  base_options: _BaseOptions
  running_mode: _RunningMode = _RunningMode.IMAGE
  num_hands: Optional[int] = 1
  min_hand_detection_confidence: Optional[float] = 0.5
  min_hand_presence_confidence: Optional[float] = 0.5
  min_tracking_confidence: Optional[float] = 0.5
  result_callback: Optional[Callable[
      [HandLandmarkerResult, image_module.Image, int], None]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
    """Generates a HandLandmarkerGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = (
        self.running_mode != _RunningMode.IMAGE)

    # Initialize the hand landmarker options from base options.
    hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
        base_options=base_options_proto)
    hand_landmarker_options_proto.min_tracking_confidence = (
        self.min_tracking_confidence)
    hand_landmarker_options_proto.hand_detector_graph_options.num_hands = (
        self.num_hands)
    hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = (
        self.min_hand_detection_confidence)
    hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = (
        self.min_hand_presence_confidence)
    return hand_landmarker_options_proto

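A hedged configuration sketch for the live stream mode ('hand_landmarker.task' and the callback body are placeholders, not assets defined in this change):

# Hypothetical live-stream configuration.
def print_result(result: HandLandmarkerResult, output_image, timestamp_ms: int):
  print(f'{timestamp_ms} ms: {len(result.hand_landmarks)} hand(s) detected')

options = HandLandmarkerOptions(
    base_options=_BaseOptions(model_asset_path='hand_landmarker.task'),
    running_mode=_RunningMode.LIVE_STREAM,
    num_hands=2,
    result_callback=print_result)
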
class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
  """Class that performs hand landmarks detection on images."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'HandLandmarker':
    """Creates a `HandLandmarker` object from a TensorFlow Lite model and the default `HandLandmarkerOptions`.

    Note that the created `HandLandmarker` instance is in image mode, for
    detecting hand landmarks on single image inputs.

    Args:
      model_path: Path to the model.

    Returns:
      `HandLandmarker` object that's created from the model file and the
      default `HandLandmarkerOptions`.

    Raises:
      ValueError: If the `HandLandmarker` object failed to be created from the
        provided file, e.g. because of an invalid file path.
      RuntimeError: If other types of error occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = HandLandmarkerOptions(
        base_options=base_options, running_mode=_RunningMode.IMAGE)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(cls,
                          options: HandLandmarkerOptions) -> 'HandLandmarker':
    """Creates the `HandLandmarker` object from hand landmarker options.

    Args:
      options: Options for the hand landmarker task.

    Returns:
      `HandLandmarker` object that's created from `options`.

    Raises:
      ValueError: If the `HandLandmarker` object failed to be created from
        `HandLandmarkerOptions`, e.g. because the model is missing.
      RuntimeError: If other types of error occurred.
    """

    def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
        return

      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])

      if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
        empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
        options.result_callback(
            HandLandmarkerResult([], [], []), image,
            empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
        return

      hand_landmarks_detection_result = _build_landmarker_result(
          output_packets)
      timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
      options.result_callback(hand_landmarks_detection_result, image,
                              timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
            ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
            ':'.join([
                _HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME
            ]),
            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
        ],
        task_options=options)
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode ==
            _RunningMode.LIVE_STREAM), options.running_mode,
        packets_callback if options.result_callback else None)

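A minimal creation sketch, assuming a local model file exists at the placeholder path:

# Both calls are equivalent; 'hand_landmarker.task' is a placeholder path.
landmarker = HandLandmarker.create_from_model_path('hand_landmarker.task')
landmarker = HandLandmarker.create_from_options(
    HandLandmarkerOptions(
        base_options=_BaseOptions(model_asset_path='hand_landmarker.task')))
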
  def detect(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> HandLandmarkerResult:
    """Performs hand landmarks detection on the given image.

    Only use this method when the HandLandmarker is created with the image
    running mode.

    The image can be of any size with format RGB or RGBA.
    TODO: Describe how the input image will be preprocessed after the YUV
    support is implemented.

    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      The hand landmarks detection results.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If hand landmarker detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False)
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME:
            packet_creator.create_proto(normalized_rect.to_pb2())
    })

    if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
      return HandLandmarkerResult([], [], [])

    return _build_landmarker_result(output_packets)

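An image-mode usage sketch, reusing the `landmarker` created above ('hand.jpg' is a placeholder input; `mp` is assumed to be the standard top-level package):

import mediapipe as mp

mp_image = mp.Image.create_from_file('hand.jpg')  # placeholder input file
result = landmarker.detect(mp_image)
print(f'{len(result.hand_landmarks)} hand(s) detected')
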
  def detect_for_video(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> HandLandmarkerResult:
    """Performs hand landmarks detection on the provided video frame.

    Only use this method when the HandLandmarker is created with the video
    running mode. It's required to provide the video frame's timestamp (in
    milliseconds) along with the video frame. The input timestamps should be
    monotonically increasing for adjacent calls of this method.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
      image_processing_options: Options for image processing.

    Returns:
      The hand landmarks detection results.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If hand landmarker detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False)
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_STREAM_NAME:
            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })

    if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
      return HandLandmarkerResult([], [], [])

    return _build_landmarker_result(output_packets)

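A video-mode sketch under the same assumptions; OpenCV (`cv2`) and 'clip.mp4' are illustrative, not dependencies of this module. Deriving timestamps from the frame index keeps them monotonically increasing, as the docstring requires:

import cv2

cap = cv2.VideoCapture('clip.mp4')  # placeholder input video
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
frame_index = 0
while cap.isOpened():
  ok, frame = cap.read()
  if not ok:
    break
  # MediaPipe expects RGB; OpenCV decodes to BGR.
  mp_image = mp.Image(
      image_format=mp.ImageFormat.SRGB,
      data=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
  result = landmarker.detect_for_video(
      mp_image, int(frame_index * 1000 / fps))
  frame_index += 1
cap.release()
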
  def detect_async(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> None:
    """Sends live image data to perform hand landmarks detection.

    Only use this method when the HandLandmarker is created with the live
    stream running mode. The input timestamps should be monotonically
    increasing for adjacent calls of this method. This method will return
    immediately after the input image is accepted. The results will be
    available via the `result_callback` provided in the
    `HandLandmarkerOptions`. The `detect_async` method is designed to process
    live stream data such as camera input. To lower the overall latency, hand
    landmarker may drop the input images if needed. In other words, it's not
    guaranteed to have output per input image.

    The `result_callback` provides:
      - The hand landmarks detection results.
      - The input image that the hand landmarker runs on.
      - The input timestamp in milliseconds.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
      image_processing_options: Options for image processing.

    Raises:
      ValueError: If the current input timestamp is smaller than what the
        hand landmarker has already processed.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False)
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_STREAM_NAME:
            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })
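A live-stream sketch tying the pieces together; `camera_frames()` is a hypothetical frame source, `options` is the live-stream configuration sketched earlier, and wall-clock-derived timestamps preserve the required monotonic ordering (assumes the base vision task API supports use as a context manager):

import time

with HandLandmarker.create_from_options(options) as landmarker:
  start = time.monotonic()
  for mp_image in camera_frames():  # hypothetical frame generator
    landmarker.detect_async(
        mp_image, int((time.monotonic() - start) * 1000))
  # Results arrive asynchronously via options.result_callback.
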
mediapipe/tasks/testdata/metadata/BUILD (14 changes)

@ -23,10 +23,13 @@ package(
)

mediapipe_files(srcs = [
    "30k-clean.model",
    "bert_text_classifier_no_metadata.tflite",
    "mobile_ica_8bit-with-metadata.tflite",
    "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
    "mobile_ica_8bit-without-model-metadata.tflite",
    "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
    "mobilebert_vocab.txt",
    "mobilenet_v1_0.25_224_1_default_1.tflite",
    "mobilenet_v2_1.0_224_quant.tflite",
    "mobilenet_v2_1.0_224_quant_without_metadata.tflite",

@ -60,11 +63,17 @@ exports_files([
    "movie_review_labels.txt",
    "regex_vocab.txt",
    "movie_review.json",
    "bert_tokenizer_meta.json",
    "bert_text_classifier_with_sentence_piece.json",
    "sentence_piece_tokenizer_meta.json",
    "bert_text_classifier_with_bert_tokenizer.json",
])

filegroup(
    name = "model_files",
    srcs = [
        "30k-clean.model",
        "bert_text_classifier_no_metadata.tflite",
        "mobile_ica_8bit-with-metadata.tflite",
        "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
        "mobile_ica_8bit-without-model-metadata.tflite",

@ -81,6 +90,9 @@ filegroup(
    name = "data_files",
    srcs = [
        "associated_file_meta.json",
        "bert_text_classifier_with_bert_tokenizer.json",
        "bert_text_classifier_with_sentence_piece.json",
        "bert_tokenizer_meta.json",
        "bounding_box_tensor_meta.json",
        "classification_tensor_float_meta.json",
        "classification_tensor_uint8_meta.json",

@ -96,6 +108,7 @@ filegroup(
        "input_text_tensor_default_meta.json",
        "input_text_tensor_meta.json",
        "labels.txt",
        "mobilebert_vocab.txt",
        "mobilenet_v2_1.0_224.json",
        "mobilenet_v2_1.0_224_quant.json",
        "movie_review.json",

@ -105,5 +118,6 @@ filegroup(
        "score_calibration_file_meta.json",
        "score_calibration_tensor_meta.json",
        "score_thresholding_meta.json",
        "sentence_piece_tokenizer_meta.json",
    ],
)

mediapipe/tasks/testdata/metadata/bert_text_classifier_with_bert_tokenizer.json (new file, 84 lines)

@ -0,0 +1,84 @@
{
  "name": "TextClassifier",
  "description": "Classify the input text into a set of known categories.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "ids",
          "description": "Tokenized ids of the input text.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {}
        },
        {
          "name": "segment_ids",
          "description": "0 for the first sequence, 1 for the second sequence if exists.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {}
        },
        {
          "name": "mask",
          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {}
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "score",
          "description": "Score of the labels respectively.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {
            "max": [1.0],
            "min": [0.0]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ],
      "input_process_units": [
        {
          "options_type": "BertTokenizerOptions",
          "options": {
            "vocab_file": [
              {
                "name": "mobilebert_vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ],
  "min_parser_version": "1.1.0"
}

mediapipe/tasks/testdata/metadata/bert_text_classifier_with_sentence_piece.json (new file, 83 lines)

@ -0,0 +1,83 @@
{
  "name": "TextClassifier",
  "description": "Classify the input text into a set of known categories.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "ids",
          "description": "Tokenized ids of the input text.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {}
        },
        {
          "name": "segment_ids",
          "description": "0 for the first sequence, 1 for the second sequence if exists.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {}
        },
        {
          "name": "mask",
          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {}
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "score",
          "description": "Score of the labels respectively.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {}
          },
          "stats": {
            "max": [1.0],
            "min": [0.0]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ],
      "input_process_units": [
        {
          "options_type": "SentencePieceTokenizerOptions",
          "options": {
            "sentencePiece_model": [
              {
                "name": "30k-clean.model",
                "description": "The sentence piece model file."
              }
            ]
          }
        }
      ]
    }
  ],
  "min_parser_version": "1.1.0"
}

mediapipe/tasks/testdata/metadata/bert_tokenizer_meta.json (new file, 20 lines)

@ -0,0 +1,20 @@
{
  "subgraph_metadata": [
    {
      "input_process_units": [
        {
          "options_type": "BertTokenizerOptions",
          "options": {
            "vocab_file": [
              {
                "name": "vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ]
}

mediapipe/tasks/testdata/metadata/mobilebert_vocab.txt (new file, 30522 lines; diff suppressed because it is too large)

mediapipe/tasks/testdata/metadata/sentence_piece_tokenizer_meta.json (new file, 26 lines)

@ -0,0 +1,26 @@
{
  "subgraph_metadata": [
    {
      "input_process_units": [
        {
          "options_type": "SentencePieceTokenizerOptions",
          "options": {
            "sentencePiece_model": [
              {
                "name": "sp.model",
                "description": "The sentence piece model file."
              }
            ],
            "vocab_file": [
              {
                "name": "vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ]
}

mediapipe/tasks/testdata/vision/BUILD (14 changes)

@ -143,20 +143,6 @@ filegroup(
    ],
)

# Gestures related models. Visible to model_maker.
# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
filegroup(
    name = "test_gesture_models",
    srcs = [
        "hand_landmark_full.tflite",
        "palm_detection_full.tflite",
    ],
    visibility = [
        "//mediapipe/model_maker:__subpackages__",
        "//mediapipe/tasks:internal",
    ],
)

filegroup(
    name = "test_protos",
    srcs = [

@ -3,9 +3,22 @@
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
load(
    "//mediapipe/framework/tool:mediapipe_files.bzl",
    "mediapipe_files",
)

package(default_visibility = ["//mediapipe/tasks:internal"])

mediapipe_files(srcs = [
    "wasm/audio_wasm_internal.js",
    "wasm/audio_wasm_internal.wasm",
    "wasm/text_wasm_internal.js",
    "wasm/text_wasm_internal.wasm",
    "wasm/vision_wasm_internal.js",
    "wasm/vision_wasm_internal.wasm",
])

# Audio

mediapipe_ts_library(

@ -28,15 +41,18 @@ rollup_bundle(

pkg_npm(
    name = "audio_pkg",
    package_name = "__PACKAGE_NAME__",
    package_name = "@mediapipe/tasks-__NAME__",
    srcs = ["package.json"],
    substitutions = {
        "__PACKAGE_NAME__": "@mediapipe/tasks-audio",
        "__NAME__": "audio",
        "__DESCRIPTION__": "MediaPipe Audio Tasks",
        "__BUNDLE__": "audio_bundle.js",
    },
    tgz = "audio.tgz",
    deps = [":audio_bundle"],
    deps = [
        "wasm/audio_wasm_internal.js",
        "wasm/audio_wasm_internal.wasm",
        ":audio_bundle",
    ],
)

# Text

@ -61,15 +77,18 @@ rollup_bundle(

pkg_npm(
    name = "text_pkg",
    package_name = "__PACKAGE_NAME__",
    package_name = "@mediapipe/tasks-__NAME__",
    srcs = ["package.json"],
    substitutions = {
        "__PACKAGE_NAME__": "@mediapipe/tasks-text",
        "__NAME__": "text",
        "__DESCRIPTION__": "MediaPipe Text Tasks",
        "__BUNDLE__": "text_bundle.js",
    },
    tgz = "text.tgz",
    deps = [":text_bundle"],
    deps = [
        "wasm/text_wasm_internal.js",
        "wasm/text_wasm_internal.wasm",
        ":text_bundle",
    ],
)

# Vision

@ -94,13 +113,16 @@ rollup_bundle(

pkg_npm(
    name = "vision_pkg",
    package_name = "__PACKAGE_NAME__",
    package_name = "@mediapipe/tasks-__NAME__",
    srcs = ["package.json"],
    substitutions = {
        "__PACKAGE_NAME__": "@mediapipe/tasks-vision",
        "__NAME__": "vision",
        "__DESCRIPTION__": "MediaPipe Vision Tasks",
        "__BUNDLE__": "vision_bundle.js",
    },
    tgz = "vision.tgz",
    deps = [":vision_bundle"],
    tgz = "vision_pkg.tgz",
    deps = [
        "wasm/vision_wasm_internal.js",
        "wasm/vision_wasm_internal.wasm",
        ":vision_bundle",
    ],
)

@ -21,7 +21,7 @@ mediapipe_ts_library(
        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/containers:classification_result",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",

@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url

import {AudioClassifierOptions} from './audio_classifier_options';
import {Classifications} from './audio_classifier_result';
import {AudioClassifierResult} from './audio_classifier_result';

const MEDIAPIPE_GRAPH =
    'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';

@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
// implementation
const AUDIO_STREAM = 'input_audio';
const SAMPLE_RATE_STREAM = 'sample_rate';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';

// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern

/** Performs audio classification. */
export class AudioClassifier extends TaskRunner {
  private classifications: Classifications[] = [];
  private classificationResults: AudioClassifierResult[] = [];
  private defaultSampleRate = 48000;
  private readonly options = new AudioClassifierGraphOptions();

@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
   * `48000` if no custom default was set.
   * @return The classification result of the audio data
   */
  classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
  classify(audioData: Float32Array, sampleRate?: number):
      AudioClassifierResult[] {
    sampleRate = sampleRate ?? this.defaultSampleRate;

    // Configures the number of samples in the WASM layer. We re-configure the

@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
    this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
    this.addAudioToStream(audioData, timestamp);

    this.classifications = [];
    this.classificationResults = [];
    this.finishProcessing();
    return [...this.classifications];
    return [...this.classificationResults];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   * Internal function for converting raw data into classification results,
   * and adding them to our classification results list.
   **/
  private addJsAudioClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
  private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
    binaryProtos.forEach(binaryProto => {
      const classificationResult =
          ClassificationResult.deserializeBinary(binaryProto);
      this.classificationResults.push(
          convertFromClassificationResultProto(classificationResult));
    });
  }

  /** Updates the MediaPipe graph configuration. */

@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(AUDIO_STREAM);
    graphConfig.addInputStream(SAMPLE_RATE_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
    graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(

@ -198,14 +201,15 @@ export class AudioClassifier extends TaskRunner {
    classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
    classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
        'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsAudioClassification(binaryProto);
    });
    this.attachProtoVectorListener(
        TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
          this.addJsAudioClassificationResults(binaryProtos);
        });

    const binaryGraph = graphConfig.serializeBinary();
    this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

@ -15,4 +15,4 @@
 */

export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

@ -10,8 +10,8 @@ mediapipe_ts_library(
)

mediapipe_ts_library(
    name = "classifications",
    srcs = ["classifications.d.ts"],
    name = "classification_result",
    srcs = ["classification_result.d.ts"],
    deps = [":category"],
)

@ -16,27 +16,14 @@

import {Category} from '../../../../tasks/web/components/containers/category';

/** List of predicted categories with an optional timestamp. */
export interface ClassificationEntry {
/** Classification results for a given classifier head. */
export interface Classifications {
  /**
   * The array of predicted categories, usually sorted by descending scores,
   * e.g., from high to low probability.
   */
  categories: Category[];

  /**
   * The optional timestamp (in milliseconds) associated to the classification
   * entry. This is useful for time series use cases, e.g., audio
   * classification.
   */
  timestampMs?: number;
}

/** Classifications for a given classifier head. */
export interface Classifications {
  /** A list of classification entries. */
  entries: ClassificationEntry[];

  /**
   * The index of the classifier head these categories refer to. This is
   * useful for multi-head models.

@ -45,7 +32,24 @@ export interface Classifications {

  /**
   * The name of the classifier head, which is the corresponding tensor
   * metadata name.
   * metadata name. Defaults to an empty string if there is no such metadata.
   */
  headName: string;
}

/** Classification results of a model. */
export interface ClassificationResult {
  /** The classification results for each head of the model. */
  classifications: Classifications[];

  /**
   * The optional timestamp (in milliseconds) of the start of the chunk of
   * data corresponding to these results.
   *
   * This is only used for classification on time series (e.g. audio
   * classification). In these use cases, the amount of data to process might
   * exceed the maximum size that the model can process: to solve this, the
   * input data is split into multiple chunks starting at different
   * timestamps.
   */
  timestampMs?: number;
}

@ -17,8 +17,9 @@ mediapipe_ts_library(
    name = "classifier_result",
    srcs = ["classifier_result.ts"],
    deps = [
        "//mediapipe/framework/formats:classification_jspb_proto",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/containers:classification_result",
    ],
)

@ -14,48 +14,46 @@
 * limitations under the License.
 */

import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

const DEFAULT_INDEX = -1;
const DEFAULT_SCORE = 0.0;

/**
 * Converts a ClassificationEntry proto to the ClassificationEntry result
 * type.
 * Converts a Classifications proto to a Classifications object.
 */
function convertFromClassificationEntryProto(source: ClassificationEntryProto):
    ClassificationEntry {
  const categories = source.getCategoriesList().map(category => {
    return {
      index: category.getIndex() ?? DEFAULT_INDEX,
      score: category.getScore() ?? DEFAULT_SCORE,
      displayName: category.getDisplayName() ?? '',
      categoryName: category.getCategoryName() ?? '',
    };
  });

function convertFromClassificationsProto(source: ClassificationsProto):
    Classifications {
  const categories =
      source.getClassificationList()?.getClassificationList().map(
          classification => {
            return {
              index: classification.getIndex() ?? DEFAULT_INDEX,
              score: classification.getScore() ?? DEFAULT_SCORE,
              categoryName: classification.getLabel() ?? '',
              displayName: classification.getDisplayName() ?? '',
            };
          }) ??
      [];
  return {
    categories,
    timestampMs: source.getTimestampMs(),
    headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
    headName: source.getHeadName() ?? '',
  };
}

/**
 * Converts a ClassificationResult proto to a list of classifications.
 * Converts a ClassificationResult proto to a ClassificationResult object.
 */
export function convertFromClassificationResultProto(
    classificationResult: ClassificationResult): Classifications[] {
  const result: Classifications[] = [];
  for (const classificationsProto of
       classificationResult.getClassificationsList()) {
    const classifications: Classifications = {
      entries: classificationsProto.getEntriesList().map(
          entry => convertFromClassificationEntryProto(entry)),
      headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
      headName: classificationsProto.getHeadName() ?? '',
    };
    result.push(classifications);
    source: ClassificationResultProto): ClassificationResult {
  const result: ClassificationResult = {
    classifications: source.getClassificationsList().map(
        classifications => convertFromClassificationsProto(classifications))
  };
  if (source.hasTimestampMs()) {
    result.timestampMs = source.getTimestampMs();
  }
  return result;
}

@ -1,9 +1,14 @@
{
  "name": "__PACKAGE_NAME__",
  "name": "@mediapipe/tasks-__NAME__",
  "version": "__VERSION__",
  "description": "__DESCRIPTION__",
  "main": "__BUNDLE__",
  "module": "__BUNDLE__",
  "main": "__NAME__bundle.js",
  "module": "__NAME__bundle.js",
  "exports": {
    ".": "./__NAME__bundle.js",
    "./loader": "./wasm/__NAME__wasm_internal.js",
    "./wasm": "./wasm/__NAME__wasm_internal.wasm"
  },
  "author": "mediapipe@google.com",
  "license": "Apache-2.0",
  "type": "module",

@ -22,7 +22,7 @@ mediapipe_ts_library(
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/containers:classification_result",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",

@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
// Placeholder for internal dependency on trusted resource url

import {TextClassifierOptions} from './text_classifier_options';
import {Classifications} from './text_classifier_result';
import {TextClassifierResult} from './text_classifier_result';

const INPUT_STREAM = 'text_in';
const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
const CLASSIFICATIONS_STREAM = 'classifications_out';
const TEXT_CLASSIFIER_GRAPH =
    'mediapipe.tasks.text.text_classifier.TextClassifierGraph';

@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =

/** Performs Natural Language classification. */
export class TextClassifier extends TaskRunner {
  private classifications: Classifications[] = [];
  private classificationResult: TextClassifierResult = {classifications: []};
  private readonly options = new TextClassifierGraphOptions();

  /**

@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
   * @param text The text to process.
   * @return The classification result of the text
   */
  classify(text: string): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
  classify(text: string): TextClassifierResult {
    // Get classification result by running our MediaPipe graph.
    this.classificationResult = {classifications: []};
    this.addStringToStream(
        text, INPUT_STREAM, /* timestamp= */ performance.now());
    this.finishProcessing();
    return [...this.classifications];
  }

  // Internal function for converting raw data into a classification, and
  // adding it to our classifications list.
  private addJsTextClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    console.log(classificationResult.toObject());
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
    return this.classificationResult;
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(

@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsTextClassification(binaryProto);
    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
      this.classificationResult = convertFromClassificationResultProto(
          ClassificationResult.deserializeBinary(binaryProto));
    });

    const binaryGraph = graphConfig.serializeBinary();

@ -15,4 +15,4 @@
 */

export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

@ -21,7 +21,7 @@ mediapipe_ts_library(
        "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
        "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
        "//mediapipe/tasks/web/components/containers:category",
        "//mediapipe/tasks/web/components/containers:classifications",
        "//mediapipe/tasks/web/components/containers:classification_result",
        "//mediapipe/tasks/web/components/processors:base_options",
        "//mediapipe/tasks/web/components/processors:classifier_options",
        "//mediapipe/tasks/web/components/processors:classifier_result",

@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
// Placeholder for internal dependency on trusted resource url

import {ImageClassifierOptions} from './image_classifier_options';
import {Classifications} from './image_classifier_result';
import {ImageClassifierResult} from './image_classifier_result';

const IMAGE_CLASSIFIER_GRAPH =
    'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
const INPUT_STREAM = 'input_image';
const CLASSIFICATION_RESULT_STREAM = 'classification_result';
const CLASSIFICATIONS_STREAM = 'classifications';

export {ImageSource};  // Used in the public API

@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API

/** Performs classification on images. */
export class ImageClassifier extends TaskRunner {
  private classifications: Classifications[] = [];
  private classificationResult: ImageClassifierResult = {classifications: []};
  private readonly options = new ImageClassifierGraphOptions();

  /**

@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
   * provided, defaults to `performance.now()`.
   * @return The classification result of the image
   */
  classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
    // Get classification classes by running our MediaPipe graph.
    this.classifications = [];
  classify(imageSource: ImageSource, timestamp?: number):
      ImageClassifierResult {
    // Get classification result by running our MediaPipe graph.
    this.classificationResult = {classifications: []};
    this.addGpuBufferAsImageToStream(
        imageSource, INPUT_STREAM, timestamp ?? performance.now());
    this.finishProcessing();
    return [...this.classifications];
  }

  /**
   * Internal function for converting raw data into a classification, and
   * adding it to our classifications list.
   **/
  private addJsImageClassification(binaryProto: Uint8Array): void {
    const classificationResult =
        ClassificationResult.deserializeBinary(binaryProto);
    this.classifications.push(
        ...convertFromClassificationResultProto(classificationResult));
    return this.classificationResult;
  }

  /** Updates the MediaPipe graph configuration. */
  private refreshGraph(): void {
    const graphConfig = new CalculatorGraphConfig();
    graphConfig.addInputStream(INPUT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
    graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);

    const calculatorOptions = new CalculatorOptions();
    calculatorOptions.setExtension(

@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
    const classifierNode = new CalculatorGraphConfig.Node();
    classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
    classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
    classifierNode.addOutputStream(
        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
    classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
    classifierNode.setOptions(calculatorOptions);

    graphConfig.addNode(classifierNode);

    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
      this.addJsImageClassification(binaryProto);
    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
      this.classificationResult = convertFromClassificationResultProto(
          ClassificationResult.deserializeBinary(binaryProto));
    });

    const binaryGraph = graphConfig.serializeBinary();

@ -15,4 +15,4 @@
 */

export {Category} from '../../../../tasks/web/components/containers/category';
export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';

third_party/external_files.bzl (38 changes)

@ -28,12 +28,36 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_no_metadata_tflite",
        sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_no_metadata.tflite?generation=1667948360250899"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_tflite",
        sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_with_bert_tokenizer_json",
        sha256 = "49f148a13a4e3b486b1d3c2400e46e5ebd0d375674c0154278b835760e873a95",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_bert_tokenizer.json?generation=1667948363241334"],
    )

    http_file(
        name = "com_google_mediapipe_bert_text_classifier_with_sentence_piece_json",
        sha256 = "113091f3892691de57e379387256b2ce0cc18a1b5185af866220a46da8221f26",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_sentence_piece.json?generation=1667948366009530"],
    )

    http_file(
        name = "com_google_mediapipe_bert_tokenizer_meta_json",
        sha256 = "116d70c7c3ef413a8bff54ab758f9ed3d6e51fdc5621d8c920ad2f0035831804",
        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_tokenizer_meta.json?generation=1667948368809108"],
    )

    http_file(
        name = "com_google_mediapipe_bounding_box_tensor_meta_json",
        sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",

@ -403,7 +427,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_labels_txt",
        sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667888034706429"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667892497527642"],
    )

    http_file(

@ -553,13 +577,13 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_movie_review_json",
        sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e",
        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667888039053188"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667892501695336"],
    )

    http_file(
        name = "com_google_mediapipe_movie_review_labels_txt",
        sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667888041670721"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667892504334882"],
    )

    http_file(

@ -703,7 +727,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_regex_vocab_txt",
        sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
        urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667888047885461"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667892507770551"],
    )

    http_file(

@ -790,6 +814,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
    )

    http_file(
        name = "com_google_mediapipe_sentence_piece_tokenizer_meta_json",
        sha256 = "416bfe231710502e4a93e1b1950c0c6e5db49cffb256d241ef3d3f2d0d57718b",
        urls = ["https://storage.googleapis.com/mediapipe-assets/sentence_piece_tokenizer_meta.json?generation=1667948375508564"],
    )

    http_file(
        name = "com_google_mediapipe_speech_16000_hz_mono_wav",
        sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6",

third_party/wasm_files.bzl (new file, 47 lines)

@ -0,0 +1,47 @@
"""
WASM dependencies for MediaPipe.

This file is auto-generated.
"""

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file")

# buildifier: disable=unnamed-macro
def wasm_files():
    """WASM dependencies for MediaPipe."""

    http_file(
        name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
        sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_text_wasm_internal_js",
        sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
        sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
        sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
        sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
        sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"],
    )