Merge branch 'master' into image-embedder-python

commit 0a6e21c212
@@ -546,3 +546,6 @@ rules_proto_toolchains()
 load("//third_party:external_files.bzl", "external_files")
 
 external_files()
+
+load("//third_party:wasm_files.bzl", "wasm_files")
+
+wasm_files()
@@ -200,3 +200,38 @@ cc_test(
         "@com_google_absl//absl/status",
     ],
 )
+
+cc_library(
+    name = "embedding_aggregation_calculator",
+    srcs = ["embedding_aggregation_calculator.cc"],
+    deps = [
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
+        "@com_google_absl//absl/status",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "embedding_aggregation_calculator_test",
+    srcs = ["embedding_aggregation_calculator_test.cc"],
+    deps = [
+        ":embedding_aggregation_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:output_stream_poller",
+        "//mediapipe/framework:packet",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
+    ],
+)
@@ -0,0 +1,132 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unordered_map>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/api2/packet.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+
+namespace mediapipe {
+namespace api2 {
+
+using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
+
+// Aggregates EmbeddingResult packets into a vector of timestamped
+// EmbeddingResult. Acts as a pass-through if no timestamp aggregation is
+// needed.
+//
+// Inputs:
+//   EMBEDDINGS: EmbeddingResult
+//     The EmbeddingResult packets to aggregate.
+//   TIMESTAMPS: std::vector<Timestamp> @Optional.
+//     The collection of timestamps that this calculator should aggregate. This
+//     stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS output
+//     will contain the aggregated results. Otherwise, as no timestamp
+//     aggregation is required, the EMBEDDINGS output is used to pass the input
+//     EmbeddingResults unchanged.
+//
+// Outputs:
+//   EMBEDDINGS: EmbeddingResult @Optional
+//     The input EmbeddingResult, unchanged. Must be connected if the TIMESTAMPS
+//     input is not connected, as it signals that timestamp aggregation is not
+//     required.
+//   TIMESTAMPED_EMBEDDINGS: std::vector<EmbeddingResult> @Optional
+//     The embedding results aggregated by timestamp. Must be connected if the
+//     TIMESTAMPS input is connected, as it signals that timestamp aggregation
+//     is required.
+//
+// Example without timestamp aggregation (pass-through):
+// node {
+//   calculator: "EmbeddingAggregationCalculator"
+//   input_stream: "EMBEDDINGS:embeddings_in"
+//   output_stream: "EMBEDDINGS:embeddings_out"
+// }
+//
+// Example with timestamp aggregation:
+// node {
+//   calculator: "EmbeddingAggregationCalculator"
+//   input_stream: "EMBEDDINGS:embeddings_in"
+//   input_stream: "TIMESTAMPS:timestamps_in"
+//   output_stream: "TIMESTAMPED_EMBEDDINGS:timestamped_embeddings_out"
+// }
+class EmbeddingAggregationCalculator : public Node {
+ public:
+  static constexpr Input<EmbeddingResult> kEmbeddingsIn{"EMBEDDINGS"};
+  static constexpr Input<std::vector<Timestamp>>::Optional kTimestampsIn{
+      "TIMESTAMPS"};
+  static constexpr Output<EmbeddingResult>::Optional kEmbeddingsOut{
+      "EMBEDDINGS"};
+  static constexpr Output<std::vector<EmbeddingResult>>::Optional
+      kTimestampedEmbeddingsOut{"TIMESTAMPED_EMBEDDINGS"};
+  MEDIAPIPE_NODE_CONTRACT(kEmbeddingsIn, kTimestampsIn, kEmbeddingsOut,
+                          kTimestampedEmbeddingsOut);
+
+  static absl::Status UpdateContract(CalculatorContract* cc);
+  absl::Status Open(CalculatorContext* cc);
+  absl::Status Process(CalculatorContext* cc);
+
+ private:
+  bool time_aggregation_enabled_;
+  std::unordered_map<int64, EmbeddingResult> cached_embeddings_;
+};
+
+absl::Status EmbeddingAggregationCalculator::UpdateContract(
+    CalculatorContract* cc) {
+  if (kTimestampsIn(cc).IsConnected()) {
+    RET_CHECK(kTimestampedEmbeddingsOut(cc).IsConnected());
+  } else {
+    RET_CHECK(kEmbeddingsOut(cc).IsConnected());
+  }
+  return absl::OkStatus();
+}
+
+absl::Status EmbeddingAggregationCalculator::Open(CalculatorContext* cc) {
+  time_aggregation_enabled_ = kTimestampsIn(cc).IsConnected();
+  return absl::OkStatus();
+}
+
+absl::Status EmbeddingAggregationCalculator::Process(CalculatorContext* cc) {
+  if (time_aggregation_enabled_) {
+    cached_embeddings_[cc->InputTimestamp().Value()] =
+        std::move(*kEmbeddingsIn(cc));
+    if (kTimestampsIn(cc).IsEmpty()) {
+      return absl::OkStatus();
+    }
+    auto timestamps = kTimestampsIn(cc).Get();
+    std::vector<EmbeddingResult> results;
+    results.reserve(timestamps.size());
+    for (const auto& timestamp : timestamps) {
+      auto& result = cached_embeddings_[timestamp.Value()];
+      result.set_timestamp_ms((timestamp.Value() - timestamps[0].Value()) /
+                              1000);
+      results.push_back(std::move(result));
+      cached_embeddings_.erase(timestamp.Value());
+    }
+    kTimestampedEmbeddingsOut(cc).Send(std::move(results));
+  } else {
+    kEmbeddingsOut(cc).Send(kEmbeddingsIn(cc));
+  }
+  RET_CHECK(cached_embeddings_.empty());
+  return absl::OkStatus();
+}
+
+MEDIAPIPE_REGISTER_NODE(EmbeddingAggregationCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
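Note on the timestamp arithmetic in Process() above: MediaPipe Timestamp::Value() is in microseconds, and set_timestamp_ms() stores the offset from the first aggregated timestamp in milliseconds. A minimal standalone sketch of just that mapping (plain C++, no MediaPipe dependency; the values mirror the unit test added below):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Mirrors EmbeddingAggregationCalculator's timestamp math: each aggregated
    // result is stamped with (t - t0) / 1000, i.e. milliseconds relative to
    // the first timestamp in the batch (inputs are in microseconds).
    int64_t RelativeMs(int64_t timestamp_us, int64_t first_timestamp_us) {
      return (timestamp_us - first_timestamp_us) / 1000;
    }

    int main() {
      // Illustrative microsecond timestamps, matching the test's 0 / 1000 us.
      std::vector<int64_t> timestamps_us = {0, 1000};
      assert(RelativeMs(timestamps_us[0], timestamps_us[0]) == 0);  // 0 ms
      assert(RelativeMs(timestamps_us[1], timestamps_us[0]) == 1);  // 1 ms
      return 0;
    }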
@@ -0,0 +1,158 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <optional>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/output_stream_poller.h"
+#include "mediapipe/framework/packet.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/framework/port/parse_text_proto.h"
+#include "mediapipe/framework/port/status_macros.h"
+#include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
+#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
+
+namespace mediapipe {
+namespace {
+
+using ::mediapipe::ParseTextProtoOrDie;
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
+using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
+using ::testing::Pointwise;
+
+constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
+constexpr char kEmbeddingsInName[] = "embeddings_in";
+constexpr char kEmbeddingsOutName[] = "embeddings_out";
+constexpr char kTimestampsTag[] = "TIMESTAMPS";
+constexpr char kTimestampsName[] = "timestamps_in";
+constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
+constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out";
+
+class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test {
+ protected:
+  absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) {
+    Graph graph;
+    auto& calculator = graph.AddNode("EmbeddingAggregationCalculator");
+    graph[Input<EmbeddingResult>(kEmbeddingsTag)].SetName(kEmbeddingsInName) >>
+        calculator.In(kEmbeddingsTag);
+    if (connect_timestamps) {
+      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
+          kTimestampsName) >>
+          calculator.In(kTimestampsTag);
+      calculator.Out(kTimestampedEmbeddingsTag)
+              .SetName(kTimestampedEmbeddingsName) >>
+          graph[Output<std::vector<EmbeddingResult>>(
+              kTimestampedEmbeddingsTag)];
+    } else {
+      calculator.Out(kEmbeddingsTag).SetName(kEmbeddingsOutName) >>
+          graph[Output<EmbeddingResult>(kEmbeddingsTag)];
+    }
+
+    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
+    if (connect_timestamps) {
+      ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
+                                        kTimestampedEmbeddingsName));
+      MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+      return poller;
+    }
+    ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
+                                      kEmbeddingsOutName));
+    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+    return poller;
+  }
+
+  absl::Status Send(
+      const EmbeddingResult& embeddings, int timestamp = 0,
+      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt) {
+    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+        kEmbeddingsInName, MakePacket<EmbeddingResult>(std::move(embeddings))
+                               .At(Timestamp(timestamp))));
+    if (aggregation_timestamps.has_value()) {
+      auto packet = std::make_unique<std::vector<Timestamp>>();
+      for (const auto& timestamp : *aggregation_timestamps) {
+        packet->emplace_back(Timestamp(timestamp));
+      }
+      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
+    }
+    return absl::OkStatus();
+  }
+
+  template <typename T>
+  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
+    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
+
+    Packet packet;
+    if (!poller.Next(&packet)) {
+      return absl::InternalError("Unable to get output packet");
+    }
+    auto result = packet.Get<T>();
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
+    return result;
+  }
+
+ private:
+  CalculatorGraph calculator_graph_;
+};
+
+TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithoutTimestamps) {
+  EmbeddingResult embedding = ParseTextProtoOrDie<EmbeddingResult>(
+      R"pb(embeddings { head_index: 0 })pb");
+
+  MP_ASSERT_OK_AND_ASSIGN(auto poller,
+                          BuildGraph(/*connect_timestamps=*/false));
+  MP_ASSERT_OK(Send(embedding));
+  MP_ASSERT_OK_AND_ASSIGN(auto result, GetResult<EmbeddingResult>(poller));
+
+  EXPECT_THAT(result, EqualsProto(embedding));
+}
+
+TEST_F(EmbeddingAggregationCalculatorTest, SucceedsWithTimestamps) {
+  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(/*connect_timestamps=*/true));
+  MP_ASSERT_OK(Send(ParseTextProtoOrDie<EmbeddingResult>(
+      R"pb(embeddings { head_index: 0 })pb")));
+  MP_ASSERT_OK(Send(
+      ParseTextProtoOrDie<EmbeddingResult>(
+          R"pb(embeddings { head_index: 1 })pb"),
+      /*timestamp=*/1000,
+      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000})));
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<std::vector<EmbeddingResult>>(poller));
+
+  EXPECT_THAT(results,
+              Pointwise(EqualsProto(), {ParseTextProtoOrDie<EmbeddingResult>(
+                                            R"pb(embeddings { head_index: 0 }
+                                                 timestamp_ms: 0)pb"),
+                                        ParseTextProtoOrDie<EmbeddingResult>(
+                                            R"pb(embeddings { head_index: 1 }
+                                                 timestamp_ms: 1)pb")}));
+}
+
+}  // namespace
+}  // namespace mediapipe
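The pass-through wiring that BuildGraph(/*connect_timestamps=*/false) assembles above is equivalent to the node config shown in the calculator's own documentation. A minimal sketch expressing it directly as a graph config text proto, assuming only framework types this test already uses:

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/port/parse_text_proto.h"

    namespace mediapipe {

    // Pass-through configuration: no TIMESTAMPS input, so the calculator
    // forwards each EmbeddingResult unchanged on the EMBEDDINGS output.
    CalculatorGraphConfig PassThroughConfig() {
      return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        input_stream: "embeddings_in"
        output_stream: "embeddings_out"
        node {
          calculator: "EmbeddingAggregationCalculator"
          input_stream: "EMBEDDINGS:embeddings_in"
          output_stream: "EMBEDDINGS:embeddings_out"
        }
      )pb");
    }

    }  // namespace mediapipe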
@@ -30,15 +30,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "hand_landmarks_detection_result",
-    hdrs = ["hand_landmarks_detection_result.h"],
-    deps = [
-        "//mediapipe/framework/formats:classification_cc_proto",
-        "//mediapipe/framework/formats:landmark_cc_proto",
-    ],
-)
-
 cc_library(
     name = "category",
     srcs = ["category.cc"],
@@ -82,6 +82,7 @@ cc_library(
         "//mediapipe/framework/formats:tensor",
         "//mediapipe/framework/tool:options_map",
         "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/calculators:embedding_aggregation_calculator",
         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator",
         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto",
         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
@@ -56,6 +56,14 @@ using TensorsSource =
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
+constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
+constexpr char kTimestampsTag[] = "TIMESTAMPS";
+
+// Struct holding the different output streams produced by the graph.
+struct EmbeddingPostprocessingOutputStreams {
+  Source<EmbeddingResult> embeddings;
+  Source<std::vector<EmbeddingResult>> timestamped_embeddings;
+};
 
 // Identifies whether or not the model has quantized outputs, and performs
 // sanity checks.
@@ -168,27 +176,39 @@ absl::Status ConfigureEmbeddingPostprocessing(
 //   TENSORS - std::vector<Tensor>
 //     The output tensors of an InferenceCalculator, to convert into
 //     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
+//   TIMESTAMPS - std::vector<Timestamp> @Optional
+//     The collection of the timestamps that this calculator should aggregate.
+//     This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
+//     output is used for results. Otherwise, as no timestamp aggregation is
+//     required, the EMBEDDINGS output is used for results.
+//
 // Outputs:
-//   EMBEDDING_RESULT - EmbeddingResult
-//     The output EmbeddingResult.
+//   EMBEDDINGS - EmbeddingResult @Optional
+//     The embedding results aggregated by head. Must be connected if the
+//     TIMESTAMPS input is not connected, as it signals that timestamp
+//     aggregation is not required.
+//   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
+//     The embedding results aggregated by timestamp, then by head. Must be
+//     connected if the TIMESTAMPS input is connected, as it signals that
+//     timestamp aggregation is required.
 //
 // The recommended way of using this graph is through the GraphBuilder API using
 // the 'ConfigureEmbeddingPostprocessing()' function. See header file for more
 // details.
-//
-// TODO: add support for additional optional "TIMESTAMPS" input for
-// embeddings aggregation.
 class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
  public:
   absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
       mediapipe::SubgraphContext* sc) override {
     Graph graph;
     ASSIGN_OR_RETURN(
-        auto embedding_result_out,
+        auto output_streams,
         BuildEmbeddingPostprocessing(
             sc->Options<proto::EmbeddingPostprocessingGraphOptions>(),
-            graph[Input<std::vector<Tensor>>(kTensorsTag)], graph));
-    embedding_result_out >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
+            graph[Input<std::vector<Tensor>>(kTensorsTag)],
+            graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
+    output_streams.embeddings >> graph[Output<EmbeddingResult>(kEmbeddingsTag)];
+    output_streams.timestamped_embeddings >>
+        graph[Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)];
     return graph.GetConfig();
   }
@@ -200,10 +220,14 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
   //
   // options: the on-device EmbeddingPostprocessingGraphOptions
   // tensors_in: (std::vector<mediapipe::Tensor>) tensors to postprocess.
+  // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
+  //   timestamps that should be used to aggregate embedding results.
   // graph: the mediapipe builder::Graph instance to be updated.
-  absl::StatusOr<Source<EmbeddingResult>> BuildEmbeddingPostprocessing(
+  absl::StatusOr<EmbeddingPostprocessingOutputStreams>
+  BuildEmbeddingPostprocessing(
       const proto::EmbeddingPostprocessingGraphOptions options,
-      Source<std::vector<Tensor>> tensors_in, Graph& graph) {
+      Source<std::vector<Tensor>> tensors_in,
+      Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
     // If output tensors are quantized, they must be dequantized first.
     TensorsSource dequantized_tensors(&tensors_in);
     if (options.has_quantized_outputs()) {
@@ -220,7 +244,20 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
         .GetOptions<mediapipe::TensorsToEmbeddingsCalculatorOptions>()
         .CopyFrom(options.tensors_to_embeddings_options());
     dequantized_tensors >> tensors_to_embeddings_node.In(kTensorsTag);
-    return tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)];
+
+    // Adds EmbeddingAggregationCalculator.
+    GenericNode& aggregation_node =
+        graph.AddNode("EmbeddingAggregationCalculator");
+    tensors_to_embeddings_node[Output<EmbeddingResult>(kEmbeddingsTag)] >>
+        aggregation_node.In(kEmbeddingsTag);
+    timestamps_in >> aggregation_node.In(kTimestampsTag);
+
+    // Connects outputs.
+    return EmbeddingPostprocessingOutputStreams{
+        /*embeddings=*/aggregation_node[Output<EmbeddingResult>(
+            kEmbeddingsTag)],
+        /*timestamped_embeddings=*/aggregation_node
+            [Output<std::vector<EmbeddingResult>>(kTimestampedEmbeddingsTag)]};
   }
 };
 REGISTER_MEDIAPIPE_GRAPH(
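With the two hunks above, BuildEmbeddingPostprocessing returns both output streams and the caller picks the one it needs. A hedged sketch of client-side wiring with timestamp aggregation enabled, using only builder calls that appear elsewhere in this commit (the function name is illustrative):

    #include <vector>

    #include "mediapipe/framework/api2/builder.h"
    #include "mediapipe/framework/api2/port.h"
    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/formats/tensor.h"
    #include "mediapipe/framework/timestamp.h"
    #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"

    namespace {

    using ::mediapipe::Timestamp;
    using ::mediapipe::api2::Input;
    using ::mediapipe::api2::Output;
    using ::mediapipe::api2::builder::Graph;
    using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;

    // Connects the postprocessing subgraph with the optional TIMESTAMPS
    // input, so downstream consumers receive per-timestamp aggregated
    // results on TIMESTAMPED_EMBEDDINGS.
    mediapipe::CalculatorGraphConfig BuildTimestampedGraph() {
      Graph graph;
      auto& postprocessing = graph.AddNode(
          "mediapipe.tasks.components.processors.EmbeddingPostprocessingGraph");
      graph[Input<std::vector<mediapipe::Tensor>>("TENSORS")] >>
          postprocessing.In("TENSORS");
      graph[Input<std::vector<Timestamp>>("TIMESTAMPS")] >>
          postprocessing.In("TIMESTAMPS");
      postprocessing.Out("TIMESTAMPED_EMBEDDINGS") >>
          graph[Output<std::vector<EmbeddingResult>>("TIMESTAMPED_EMBEDDINGS")];
      return graph.GetConfig();
    }

    }  // namespace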
@@ -44,12 +44,20 @@ namespace processors {
 //   TENSORS - std::vector<Tensor>
 //     The output tensors of an InferenceCalculator, to convert into
 //     EmbeddingResult objects. Expected to be of type kFloat32 or kUInt8.
+//   TIMESTAMPS - std::vector<Timestamp> @Optional
+//     The collection of the timestamps that this calculator should aggregate.
+//     This stream is optional: if provided then the TIMESTAMPED_EMBEDDINGS
+//     output is used for results. Otherwise, as no timestamp aggregation is
+//     required, the EMBEDDINGS output is used for results.
 // Outputs:
-//   EMBEDDINGS - EmbeddingResult
-//     The output EmbeddingResult.
-//
-// TODO: add support for additional optional "TIMESTAMPS" input for
-// embeddings aggregation.
+//   EMBEDDINGS - EmbeddingResult @Optional
+//     The embedding results aggregated by head. Must be connected if the
+//     TIMESTAMPS input is not connected, as it signals that timestamp
+//     aggregation is not required.
+//   TIMESTAMPED_EMBEDDINGS - std::vector<EmbeddingResult> @Optional
+//     The embedding results aggregated by timestamp, then by head. Must be
+//     connected if the TIMESTAMPS input is connected, as it signals that
+//     timestamp aggregation is required.
 absl::Status ConfigureEmbeddingPostprocessing(
     const tasks::core::ModelResources& model_resources,
     const proto::EmbedderOptions& embedder_options,
@@ -20,11 +20,20 @@ limitations under the License.
 #include "absl/flags/flag.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/string_view.h"
+#include "mediapipe/framework/api2/builder.h"
+#include "mediapipe/framework/api2/port.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/calculator_runner.h"
 #include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/formats/tensor.h"
+#include "mediapipe/framework/graph_runner.h"
+#include "mediapipe/framework/output_stream_poller.h"
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/timestamp.h"
+#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
 #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
@@ -37,7 +46,12 @@ namespace components {
 namespace processors {
 namespace {
 
+using ::mediapipe::api2::Input;
+using ::mediapipe::api2::Output;
+using ::mediapipe::api2::builder::Graph;
+using ::mediapipe::api2::builder::Source;
 using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
 using ::mediapipe::tasks::core::ModelResources;
 
 constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/";
@@ -51,6 +65,16 @@ constexpr char kQuantizedImageClassifierWithoutMetadata[] =
     "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite";
 
 constexpr char kTestModelResourcesTag[] = "test_model_resources";
+constexpr int kMobileNetV3EmbedderEmbeddingSize = 1024;
+
+constexpr char kTensorsTag[] = "TENSORS";
+constexpr char kTensorsName[] = "tensors";
+constexpr char kTimestampsTag[] = "TIMESTAMPS";
+constexpr char kTimestampsName[] = "timestamps";
+constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
+constexpr char kEmbeddingsName[] = "embeddings";
+constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS";
+constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings";
 
 // Helper function to get ModelResources.
 absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel(
@@ -128,8 +152,171 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) {
                   has_quantized_outputs: false)pb")));
 }
 
-// TODO: add E2E Postprocessing tests once timestamp aggregation is
-// supported.
+class PostprocessingTest : public tflite_shims::testing::Test {
+ protected:
+  absl::StatusOr<OutputStreamPoller> BuildGraph(
+      absl::string_view model_name, const proto::EmbedderOptions& options,
+      bool connect_timestamps = false) {
+    ASSIGN_OR_RETURN(auto model_resources,
+                     CreateModelResourcesForModel(model_name));
+
+    Graph graph;
+    auto& postprocessing = graph.AddNode(
+        "mediapipe.tasks.components.processors."
+        "EmbeddingPostprocessingGraph");
+    MP_RETURN_IF_ERROR(ConfigureEmbeddingPostprocessing(
+        *model_resources, options,
+        &postprocessing
+             .GetOptions<proto::EmbeddingPostprocessingGraphOptions>()));
+    graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
+        postprocessing.In(kTensorsTag);
+    if (connect_timestamps) {
+      graph[Input<std::vector<Timestamp>>(kTimestampsTag)].SetName(
+          kTimestampsName) >>
+          postprocessing.In(kTimestampsTag);
+      postprocessing.Out(kTimestampedEmbeddingsTag)
+              .SetName(kTimestampedEmbeddingsName) >>
+          graph[Output<std::vector<EmbeddingResult>>(
+              kTimestampedEmbeddingsTag)];
+    } else {
+      postprocessing.Out(kEmbeddingsTag).SetName(kEmbeddingsName) >>
+          graph[Output<EmbeddingResult>(kEmbeddingsTag)];
+    }
+
+    MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig()));
+    if (connect_timestamps) {
+      ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller(
+                                        kTimestampedEmbeddingsName));
+      MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+      return poller;
+    }
+    ASSIGN_OR_RETURN(auto poller,
+                     calculator_graph_.AddOutputStreamPoller(kEmbeddingsName));
+    MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{}));
+    return poller;
+  }
+
+  template <typename T>
+  void AddTensor(
+      const std::vector<T>& tensor, const Tensor::ElementType& element_type,
+      const Tensor::QuantizationParameters& quantization_parameters = {}) {
+    tensors_->emplace_back(element_type,
+                           Tensor::Shape{1, static_cast<int>(tensor.size())},
+                           quantization_parameters);
+    auto view = tensors_->back().GetCpuWriteView();
+    T* buffer = view.buffer<T>();
+    std::copy(tensor.begin(), tensor.end(), buffer);
+  }
+
+  absl::Status Run(
+      std::optional<std::vector<int>> aggregation_timestamps = std::nullopt,
+      int timestamp = 0) {
+    MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+        kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp))));
+    // Reset tensors for future calls.
+    tensors_ = absl::make_unique<std::vector<Tensor>>();
+    if (aggregation_timestamps.has_value()) {
+      auto packet = absl::make_unique<std::vector<Timestamp>>();
+      for (const auto& timestamp : *aggregation_timestamps) {
+        packet->emplace_back(Timestamp(timestamp));
+      }
+      MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream(
+          kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp))));
+    }
+    return absl::OkStatus();
+  }
+
+  template <typename T>
+  absl::StatusOr<T> GetResult(OutputStreamPoller& poller) {
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle());
+    MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams());
+
+    Packet packet;
+    if (!poller.Next(&packet)) {
+      return absl::InternalError("Unable to get output packet");
+    }
+    auto result = packet.Get<T>();
+    MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone());
+    return result;
+  }
+
+ private:
+  CalculatorGraph calculator_graph_;
+  std::unique_ptr<std::vector<Tensor>> tensors_ =
+      absl::make_unique<std::vector<Tensor>>();
+};
+
+TEST_F(PostprocessingTest, SucceedsWithoutTimestamps) {
+  // Build graph.
+  proto::EmbedderOptions options;
+  MP_ASSERT_OK_AND_ASSIGN(auto poller,
+                          BuildGraph(kMobileNetV3Embedder, options));
+  // Build input tensor.
+  std::vector<float> tensor(kMobileNetV3EmbedderEmbeddingSize, 0);
+  tensor[0] = 1.0;
+
+  // Send tensor and get results.
+  AddTensor(tensor, Tensor::ElementType::kFloat32);
+  MP_ASSERT_OK(Run());
+  MP_ASSERT_OK_AND_ASSIGN(auto results, GetResult<EmbeddingResult>(poller));
+
+  // Validate results.
+  EXPECT_FALSE(results.has_timestamp_ms());
+  EXPECT_EQ(results.embeddings_size(), 1);
+  EXPECT_EQ(results.embeddings(0).head_index(), 0);
+  EXPECT_EQ(results.embeddings(0).head_name(), "feature");
+  EXPECT_EQ(results.embeddings(0).float_embedding().values_size(),
+            kMobileNetV3EmbedderEmbeddingSize);
+  EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(0), 1.0);
+  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
+    EXPECT_FLOAT_EQ(results.embeddings(0).float_embedding().values(i), 0.0);
+  }
+}
+
+TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
+  // Build graph.
+  proto::EmbedderOptions options;
+  MP_ASSERT_OK_AND_ASSIGN(auto poller, BuildGraph(kMobileNetV3Embedder, options,
+                                                  /*connect_timestamps=*/true));
+  // Build input tensors.
+  std::vector<float> tensor_0(kMobileNetV3EmbedderEmbeddingSize, 0);
+  tensor_0[0] = 1.0;
+  std::vector<float> tensor_1(kMobileNetV3EmbedderEmbeddingSize, 0);
+  tensor_1[0] = 2.0;
+
+  // Send tensors and get results.
+  AddTensor(tensor_0, Tensor::ElementType::kFloat32);
+  MP_ASSERT_OK(Run());
+  AddTensor(tensor_1, Tensor::ElementType::kFloat32);
+  MP_ASSERT_OK(Run(
+      /*aggregation_timestamps=*/std::optional<std::vector<int>>({0, 1000}),
+      /*timestamp=*/1000));
+  MP_ASSERT_OK_AND_ASSIGN(auto results,
+                          GetResult<std::vector<EmbeddingResult>>(poller));
+
+  // Validate results.
+  EXPECT_EQ(results.size(), 2);
+  // First timestamp.
+  EXPECT_EQ(results[0].timestamp_ms(), 0);
+  EXPECT_EQ(results[0].embeddings(0).head_index(), 0);
+  EXPECT_EQ(results[0].embeddings(0).head_name(), "feature");
+  EXPECT_EQ(results[0].embeddings(0).float_embedding().values_size(),
+            kMobileNetV3EmbedderEmbeddingSize);
+  EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(0), 1.0);
+  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
+    EXPECT_FLOAT_EQ(results[0].embeddings(0).float_embedding().values(i), 0.0);
+  }
+  // Second timestamp.
+  EXPECT_EQ(results[1].timestamp_ms(), 1);
+  EXPECT_EQ(results[1].embeddings(0).head_index(), 0);
+  EXPECT_EQ(results[1].embeddings(0).head_name(), "feature");
+  EXPECT_EQ(results[1].embeddings(0).float_embedding().values_size(),
+            kMobileNetV3EmbedderEmbeddingSize);
+  EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(0), 2.0);
+  for (int i = 1; i < kMobileNetV3EmbedderEmbeddingSize; ++i) {
+    EXPECT_FLOAT_EQ(results[1].embeddings(0).float_embedding().values(i), 0.0);
+  }
+}
+
 }  // namespace
 }  // namespace processors
@@ -32,7 +32,4 @@ message EmbeddingPostprocessingGraphOptions {
 
   // Whether output tensors are quantized (kTfLiteUint8) or not (kFloat32).
   optional bool has_quantized_outputs = 2;
-
-  // TODO: add options to control whether timestamp aggregation
-  // should be used or not.
 }
@@ -110,12 +110,22 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "hand_landmarker_result",
+    hdrs = ["hand_landmarker_result.h"],
+    deps = [
+        "//mediapipe/framework/formats:classification_cc_proto",
+        "//mediapipe/framework/formats:landmark_cc_proto",
+    ],
+)
+
 cc_library(
     name = "hand_landmarker",
     srcs = ["hand_landmarker.cc"],
     hdrs = ["hand_landmarker.h"],
     deps = [
         ":hand_landmarker_graph",
+        ":hand_landmarker_result",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:classification_cc_proto",
@@ -124,7 +134,6 @@ cc_library(
         "//mediapipe/framework/formats:rect_cc_proto",
         "//mediapipe/tasks/cc:common",
         "//mediapipe/tasks/cc/components:image_preprocessing",
-        "//mediapipe/tasks/cc/components/containers:hand_landmarks_detection_result",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
         "//mediapipe/tasks/cc/core:base_options",
@@ -22,7 +22,6 @@ limitations under the License.
 #include "mediapipe/framework/formats/landmark.pb.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
 #include "mediapipe/tasks/cc/components/image_preprocessing.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 #include "mediapipe/tasks/cc/core/base_task_api.h"
@@ -34,6 +33,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
 #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
 #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h"
@@ -47,8 +47,6 @@ namespace {
 using HandLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision::
     hand_landmarker::proto::HandLandmarkerGraphOptions;
 
-using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
-
 constexpr char kHandLandmarkerGraphTypeName[] =
     "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
 
@@ -145,7 +143,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
           Packet empty_packet =
               status_or_packets.value()[kHandLandmarksStreamName];
           result_callback(
-              {HandLandmarksDetectionResult()}, image_packet.Get<Image>(),
+              {HandLandmarkerResult()}, image_packet.Get<Image>(),
              empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond);
           return;
         }
@@ -173,7 +171,7 @@ absl::StatusOr<std::unique_ptr<HandLandmarker>> HandLandmarker::Create(
       std::move(packets_callback));
 }
 
-absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
+absl::StatusOr<HandLandmarkerResult> HandLandmarker::Detect(
     mediapipe::Image image,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
@@ -192,7 +190,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
        {kNormRectStreamName,
         MakePacket<NormalizedRect>(std::move(norm_rect))}}));
   if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
-    return {HandLandmarksDetectionResult()};
+    return {HandLandmarkerResult()};
   }
   return {{/* handedness= */
            {output_packets[kHandednessStreamName]
@@ -205,7 +203,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::Detect(
                 .Get<std::vector<mediapipe::LandmarkList>>()}}};
 }
 
-absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
+absl::StatusOr<HandLandmarkerResult> HandLandmarker::DetectForVideo(
     mediapipe::Image image, int64 timestamp_ms,
     std::optional<core::ImageProcessingOptions> image_processing_options) {
   if (image.UsesGpu()) {
@@ -227,7 +225,7 @@ absl::StatusOr<HandLandmarksDetectionResult> HandLandmarker::DetectForVideo(
         MakePacket<NormalizedRect>(std::move(norm_rect))
             .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
   if (output_packets[kHandLandmarksStreamName].IsEmpty()) {
-    return {HandLandmarksDetectionResult()};
+    return {HandLandmarkerResult()};
   }
   return {
       {/* handedness= */
@@ -24,12 +24,12 @@ limitations under the License.
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
-#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
 #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 #include "mediapipe/tasks/cc/vision/core/running_mode.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
 
 namespace mediapipe {
 namespace tasks {
@@ -70,9 +70,7 @@ struct HandLandmarkerOptions {
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
-  std::function<void(
-      absl::StatusOr<components::containers::HandLandmarksDetectionResult>,
-      const Image&, int64)>
+  std::function<void(absl::StatusOr<HandLandmarkerResult>, const Image&, int64)>
       result_callback = nullptr;
 };
 
@@ -92,7 +90,7 @@ struct HandLandmarkerOptions {
 //     'y_center', 'width' and 'height' fields is NOT supported and will
 //     result in an invalid argument error being returned.
 // Outputs:
-//   HandLandmarksDetectionResult
+//   HandLandmarkerResult
 //     - The hand landmarks detection results.
 class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
  public:
@@ -129,7 +127,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
   // The image can be of any size with format RGB or RGBA.
   // TODO: Describes how the input image will be preprocessed
   // after the yuv support is implemented.
-  absl::StatusOr<components::containers::HandLandmarksDetectionResult> Detect(
+  absl::StatusOr<HandLandmarkerResult> Detect(
       Image image,
       std::optional<core::ImageProcessingOptions> image_processing_options =
          std::nullopt);
@@ -147,10 +145,10 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
   // The image can be of any size with format RGB or RGBA. It's required to
   // provide the video frame's timestamp (in milliseconds). The input timestamps
   // must be monotonically increasing.
-  absl::StatusOr<components::containers::HandLandmarksDetectionResult>
-  DetectForVideo(Image image, int64 timestamp_ms,
-                 std::optional<core::ImageProcessingOptions>
-                     image_processing_options = std::nullopt);
+  absl::StatusOr<HandLandmarkerResult> DetectForVideo(
+      Image image, int64 timestamp_ms,
+      std::optional<core::ImageProcessingOptions> image_processing_options =
+          std::nullopt);
 
   // Sends live image data to perform hand landmarks detection, and the results
   // will be available via the "result_callback" provided in the
@@ -169,7 +167,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi {
   //     invalid argument error being returned.
   //
   // The "result_callback" provides
-  //   - A vector of HandLandmarksDetectionResult, each is the detected results
+  //   - A vector of HandLandmarkerResult, each is the detected results
   //     for an input frame.
   //   - The const reference to the corresponding input image that the hand
   //     landmarker runs on. Note that the const reference to the image will no
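A hedged sketch of what the new live-stream callback signature looks like at a call site. The options and result types are from the headers in this commit; the namespace nesting and the lambda body are assumptions for illustration:

    #include "absl/status/statusor.h"
    #include "mediapipe/framework/formats/image.h"
    #include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h"

    // Sets a LIVE_STREAM callback matching the updated std::function type.
    void ConfigureLiveStreamCallback(
        mediapipe::tasks::vision::hand_landmarker::HandLandmarkerOptions&
            options) {
      using mediapipe::tasks::vision::hand_landmarker::HandLandmarkerResult;
      options.result_callback = [](absl::StatusOr<HandLandmarkerResult> result,
                                   const mediapipe::Image& image,
                                   int64_t timestamp_ms) {
        if (!result.ok()) return;  // Illustrative error handling only.
        // result->handedness, result->hand_landmarks and
        // result->hand_world_landmarks hold per-hand outputs; see the
        // hand_landmarker_result.h hunk later in this diff.
      };
    }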
@@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
-#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
+#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
+#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
 
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/framework/formats/landmark.pb.h"
 
 namespace mediapipe {
 namespace tasks {
-namespace components {
-namespace containers {
+namespace vision {
+namespace hand_landmarker {
 
 // The hand landmarks detection result from HandLandmarker, where each vector
 // element represents a single hand detected in the image.
-struct HandLandmarksDetectionResult {
+struct HandLandmarkerResult {
   // Classification of handedness.
   std::vector<mediapipe::ClassificationList> handedness;
   // Detected hand landmarks in normalized image coordinates.
@@ -35,9 +35,9 @@ struct HandLandmarksDetectionResult {
   std::vector<mediapipe::LandmarkList> hand_world_landmarks;
 };
 
-}  // namespace containers
-}  // namespace components
+}  // namespace hand_landmarker
+}  // namespace vision
 }  // namespace tasks
 }  // namespace mediapipe
 
-#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_HAND_LANDMARKS_DETECTION_RESULT_H_
+#endif  // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_HAND_LANDMARKER_RESULT_H_
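For orientation, a minimal sketch of iterating the renamed result struct. The field names are exactly those in the header above; the printing and the assumption of one handedness classification per hand are illustrative:

    #include <cstdio>

    #include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"

    // Prints one line per detected hand: its top handedness label and its
    // landmark count, using only fields defined in the struct above.
    void PrintResult(
        const mediapipe::tasks::vision::hand_landmarker::HandLandmarkerResult&
            result) {
      for (size_t i = 0; i < result.hand_landmarks.size(); ++i) {
        std::printf("hand %zu: %s, %d landmarks\n", i,
                    result.handedness[i].classification(0).label().c_str(),
                    result.hand_landmarks[i].landmark_size());
      }
    }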
@@ -32,12 +32,12 @@ limitations under the License.
 #include "mediapipe/framework/port/gmock.h"
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/hand_landmarks_detection_result.h"
 #include "mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.pb.h"
 #include "mediapipe/tasks/cc/components/containers/rect.h"
 #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 #include "mediapipe/tasks/cc/core/base_options.h"
 #include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
+#include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h"
 #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
 #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@@ -50,7 +50,6 @@ namespace {
 
 using ::file::Defaults;
 using ::mediapipe::file::JoinPath;
-using ::mediapipe::tasks::components::containers::HandLandmarksDetectionResult;
 using ::mediapipe::tasks::components::containers::Rect;
 using ::mediapipe::tasks::containers::proto::LandmarksDetectionResult;
 using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
@@ -95,9 +94,9 @@ LandmarksDetectionResult GetLandmarksDetectionResult(
   return result;
 }
 
-HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
+HandLandmarkerResult GetExpectedHandLandmarkerResult(
     const std::vector<absl::string_view>& landmarks_file_names) {
-  HandLandmarksDetectionResult expected_results;
+  HandLandmarkerResult expected_results;
   for (const auto& file_name : landmarks_file_names) {
     const auto landmarks_detection_result =
         GetLandmarksDetectionResult(file_name);
@@ -109,9 +108,9 @@ HandLandmarksDetectionResult GetExpectedHandLandmarksDetectionResult(
   return expected_results;
 }
 
-void ExpectHandLandmarksDetectionResultsCorrect(
-    const HandLandmarksDetectionResult& actual_results,
-    const HandLandmarksDetectionResult& expected_results) {
+void ExpectHandLandmarkerResultsCorrect(
+    const HandLandmarkerResult& actual_results,
+    const HandLandmarkerResult& expected_results) {
   const auto& actual_landmarks = actual_results.hand_landmarks;
   const auto& actual_handedness = actual_results.handedness;
 
@@ -145,7 +144,7 @@ struct TestParams {
   //   clockwise.
   int rotation;
   // Expected results from the hand landmarker model output.
-  HandLandmarksDetectionResult expected_results;
+  HandLandmarkerResult expected_results;
 };
 
 class ImageModeTest : public testing::TestWithParam<TestParams> {};
@@ -213,7 +212,7 @@ TEST_P(ImageModeTest, Succeeds) {
 
   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
                           HandLandmarker::Create(std::move(options)));
-  HandLandmarksDetectionResult hand_landmarker_results;
+  HandLandmarkerResult hand_landmarker_results;
   if (GetParam().rotation != 0) {
     ImageProcessingOptions image_processing_options;
     image_processing_options.rotation_degrees = GetParam().rotation;
@@ -224,8 +223,8 @@ TEST_P(ImageModeTest, Succeeds) {
     MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
                             hand_landmarker->Detect(image));
   }
-  ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
+  ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
|
||||||
GetParam().expected_results);
|
GetParam().expected_results);
|
||||||
MP_ASSERT_OK(hand_landmarker->Close());
|
MP_ASSERT_OK(hand_landmarker->Close());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,8 +236,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ 0,
|
/* rotation= */ 0,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
|
||||||
{kThumbUpLandmarksFilename}),
|
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
/* test_name= */ "LandmarksPointingUp",
|
/* test_name= */ "LandmarksPointingUp",
|
||||||
|
@ -246,8 +244,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ 0,
|
/* rotation= */ 0,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
|
||||||
{kPointingUpLandmarksFilename}),
|
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
/* test_name= */ "LandmarksPointingUpRotated",
|
/* test_name= */ "LandmarksPointingUpRotated",
|
||||||
|
@ -255,7 +252,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ -90,
|
/* rotation= */ -90,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult(
|
||||||
{kPointingUpRotatedLandmarksFilename}),
|
{kPointingUpRotatedLandmarksFilename}),
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
|
@ -315,7 +312,7 @@ TEST_P(VideoModeTest, Succeeds) {
|
||||||
HandLandmarker::Create(std::move(options)));
|
HandLandmarker::Create(std::move(options)));
|
||||||
const auto expected_results = GetParam().expected_results;
|
const auto expected_results = GetParam().expected_results;
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
HandLandmarksDetectionResult hand_landmarker_results;
|
HandLandmarkerResult hand_landmarker_results;
|
||||||
if (GetParam().rotation != 0) {
|
if (GetParam().rotation != 0) {
|
||||||
ImageProcessingOptions image_processing_options;
|
ImageProcessingOptions image_processing_options;
|
||||||
image_processing_options.rotation_degrees = GetParam().rotation;
|
image_processing_options.rotation_degrees = GetParam().rotation;
|
||||||
|
@ -326,8 +323,8 @@ TEST_P(VideoModeTest, Succeeds) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
|
MP_ASSERT_OK_AND_ASSIGN(hand_landmarker_results,
|
||||||
hand_landmarker->DetectForVideo(image, i));
|
hand_landmarker->DetectForVideo(image, i));
|
||||||
}
|
}
|
||||||
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results,
|
ExpectHandLandmarkerResultsCorrect(hand_landmarker_results,
|
||||||
expected_results);
|
expected_results);
|
||||||
}
|
}
|
||||||
MP_ASSERT_OK(hand_landmarker->Close());
|
MP_ASSERT_OK(hand_landmarker->Close());
|
||||||
}
|
}
|
||||||
|
@ -340,8 +337,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ 0,
|
/* rotation= */ 0,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
|
||||||
{kThumbUpLandmarksFilename}),
|
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
/* test_name= */ "LandmarksPointingUp",
|
/* test_name= */ "LandmarksPointingUp",
|
||||||
|
@ -349,8 +345,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ 0,
|
/* rotation= */ 0,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
|
||||||
{kPointingUpLandmarksFilename}),
|
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
/* test_name= */ "LandmarksPointingUpRotated",
|
/* test_name= */ "LandmarksPointingUpRotated",
|
||||||
|
@ -358,7 +353,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ -90,
|
/* rotation= */ -90,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult(
|
||||||
{kPointingUpRotatedLandmarksFilename}),
|
{kPointingUpRotatedLandmarksFilename}),
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
|
@ -383,9 +378,8 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
|
JoinPath("./", kTestDataDirectory, kHandLandmarkerBundleAsset);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
options->result_callback =
|
options->result_callback = [](absl::StatusOr<HandLandmarkerResult> results,
|
||||||
[](absl::StatusOr<HandLandmarksDetectionResult> results,
|
const Image& image, int64 timestamp_ms) {};
|
||||||
const Image& image, int64 timestamp_ms) {};
|
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||||
HandLandmarker::Create(std::move(options)));
|
HandLandmarker::Create(std::move(options)));
|
||||||
|
@ -416,23 +410,23 @@ TEST_P(LiveStreamModeTest, Succeeds) {
|
||||||
options->base_options.model_asset_path =
|
options->base_options.model_asset_path =
|
||||||
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
|
JoinPath("./", kTestDataDirectory, GetParam().test_model_file);
|
||||||
options->running_mode = core::RunningMode::LIVE_STREAM;
|
options->running_mode = core::RunningMode::LIVE_STREAM;
|
||||||
std::vector<HandLandmarksDetectionResult> hand_landmarker_results;
|
std::vector<HandLandmarkerResult> hand_landmarker_results;
|
||||||
std::vector<std::pair<int, int>> image_sizes;
|
std::vector<std::pair<int, int>> image_sizes;
|
||||||
std::vector<int64> timestamps;
|
std::vector<int64> timestamps;
|
||||||
options->result_callback =
|
options->result_callback = [&hand_landmarker_results, &image_sizes,
|
||||||
[&hand_landmarker_results, &image_sizes, ×tamps](
|
×tamps](
|
||||||
absl::StatusOr<HandLandmarksDetectionResult> results,
|
absl::StatusOr<HandLandmarkerResult> results,
|
||||||
const Image& image, int64 timestamp_ms) {
|
const Image& image, int64 timestamp_ms) {
|
||||||
MP_ASSERT_OK(results.status());
|
MP_ASSERT_OK(results.status());
|
||||||
hand_landmarker_results.push_back(std::move(results.value()));
|
hand_landmarker_results.push_back(std::move(results.value()));
|
||||||
image_sizes.push_back({image.width(), image.height()});
|
image_sizes.push_back({image.width(), image.height()});
|
||||||
timestamps.push_back(timestamp_ms);
|
timestamps.push_back(timestamp_ms);
|
||||||
};
|
};
|
||||||
|
|
||||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HandLandmarker> hand_landmarker,
|
||||||
HandLandmarker::Create(std::move(options)));
|
HandLandmarker::Create(std::move(options)));
|
||||||
for (int i = 0; i < iterations; ++i) {
|
for (int i = 0; i < iterations; ++i) {
|
||||||
HandLandmarksDetectionResult hand_landmarker_results;
|
HandLandmarkerResult hand_landmarker_results;
|
||||||
if (GetParam().rotation != 0) {
|
if (GetParam().rotation != 0) {
|
||||||
ImageProcessingOptions image_processing_options;
|
ImageProcessingOptions image_processing_options;
|
||||||
image_processing_options.rotation_degrees = GetParam().rotation;
|
image_processing_options.rotation_degrees = GetParam().rotation;
|
||||||
|
@ -450,8 +444,8 @@ TEST_P(LiveStreamModeTest, Succeeds) {
|
||||||
|
|
||||||
const auto expected_results = GetParam().expected_results;
|
const auto expected_results = GetParam().expected_results;
|
||||||
for (int i = 0; i < hand_landmarker_results.size(); ++i) {
|
for (int i = 0; i < hand_landmarker_results.size(); ++i) {
|
||||||
ExpectHandLandmarksDetectionResultsCorrect(hand_landmarker_results[i],
|
ExpectHandLandmarkerResultsCorrect(hand_landmarker_results[i],
|
||||||
expected_results);
|
expected_results);
|
||||||
}
|
}
|
||||||
for (const auto& image_size : image_sizes) {
|
for (const auto& image_size : image_sizes) {
|
||||||
EXPECT_EQ(image_size.first, image.width());
|
EXPECT_EQ(image_size.first, image.width());
|
||||||
|
@ -472,8 +466,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ 0,
|
/* rotation= */ 0,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult({kThumbUpLandmarksFilename}),
|
||||||
{kThumbUpLandmarksFilename}),
|
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
/* test_name= */ "LandmarksPointingUp",
|
/* test_name= */ "LandmarksPointingUp",
|
||||||
|
@ -481,8 +474,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ 0,
|
/* rotation= */ 0,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult({kPointingUpLandmarksFilename}),
|
||||||
{kPointingUpLandmarksFilename}),
|
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
/* test_name= */ "LandmarksPointingUpRotated",
|
/* test_name= */ "LandmarksPointingUpRotated",
|
||||||
|
@ -490,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
/* test_model_file= */ kHandLandmarkerBundleAsset,
|
||||||
/* rotation= */ -90,
|
/* rotation= */ -90,
|
||||||
/* expected_results = */
|
/* expected_results = */
|
||||||
GetExpectedHandLandmarksDetectionResult(
|
GetExpectedHandLandmarkerResult(
|
||||||
{kPointingUpRotatedLandmarksFilename}),
|
{kPointingUpRotatedLandmarksFilename}),
|
||||||
},
|
},
|
||||||
TestParams{
|
TestParams{
|
||||||
|
|
|
@ -142,6 +142,36 @@ android_library(
     ],
 )

+android_library(
+    name = "handlandmarker",
+    srcs = [
+        "handlandmarker/HandLandmarker.java",
+        "handlandmarker/HandLandmarkerResult.java",
+    ],
+    javacopts = [
+        "-Xep:AndroidJdkLibsChecker:OFF",
+    ],
+    manifest = "handlandmarker/AndroidManifest.xml",
+    deps = [
+        ":core",
+        "//mediapipe/framework:calculator_options_java_proto_lite",
+        "//mediapipe/framework/formats:classification_java_proto_lite",
+        "//mediapipe/framework/formats:landmark_java_proto_lite",
+        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
+        "//mediapipe/java/com/google/mediapipe/framework/image",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite",
+        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
+        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark",
+        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
+        "//third_party:autovalue",
+        "@maven//:com_google_guava_guava",
+    ],
+)
+
 load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_vision_aar")

 mediapipe_tasks_vision_aar(

@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="com.google.mediapipe.tasks.vision.handlandmarker">
+
+  <uses-sdk android:minSdkVersion="24"
+      android:targetSdkVersion="30" />
+
+</manifest>

@ -0,0 +1,501 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.handlandmarker;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
+import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
+import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
+import com.google.mediapipe.framework.AndroidPacketGetter;
+import com.google.mediapipe.framework.Packet;
+import com.google.mediapipe.framework.PacketGetter;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.core.ErrorListener;
+import com.google.mediapipe.tasks.core.OutputHandler;
+import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
+import com.google.mediapipe.tasks.core.TaskInfo;
+import com.google.mediapipe.tasks.core.TaskOptions;
+import com.google.mediapipe.tasks.core.TaskRunner;
+import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
+import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.handdetector.proto.HandDetectorGraphOptionsProto;
+import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarkerGraphOptionsProto;
+import com.google.mediapipe.tasks.vision.handlandmarker.proto.HandLandmarksDetectorGraphOptionsProto;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Performs hand landmarks detection on images.
+ *
+ * <p>This API expects a pre-trained hand landmarks model asset bundle. See <TODO link
+ * to the DevSite documentation page>.
+ *
+ * <ul>
+ *   <li>Input image {@link MPImage}
+ *       <ul>
+ *         <li>The image that hand landmarks detection runs on.
+ *       </ul>
+ *   <li>Output HandLandmarkerResult {@link HandLandmarkerResult}
+ *       <ul>
+ *         <li>A HandLandmarkerResult containing hand landmarks.
+ *       </ul>
+ * </ul>
+ */
+public final class HandLandmarker extends BaseVisionTaskApi {
+  private static final String TAG = HandLandmarker.class.getSimpleName();
+  private static final String IMAGE_IN_STREAM_NAME = "image_in";
+  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
+  private static final List<String> INPUT_STREAMS =
+      Collections.unmodifiableList(
+          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
+  private static final List<String> OUTPUT_STREAMS =
+      Collections.unmodifiableList(
+          Arrays.asList(
+              "LANDMARKS:hand_landmarks",
+              "WORLD_LANDMARKS:world_hand_landmarks",
+              "HANDEDNESS:handedness",
+              "IMAGE:image_out"));
+  private static final int LANDMARKS_OUT_STREAM_INDEX = 0;
+  private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1;
+  private static final int HANDEDNESS_OUT_STREAM_INDEX = 2;
+  private static final int IMAGE_OUT_STREAM_INDEX = 3;
+  private static final String TASK_GRAPH_NAME =
+      "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph";
+
+  /**
+   * Creates a {@link HandLandmarker} instance from a model file and the default {@link
+   * HandLandmarkerOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelPath path to the hand landmarks model with metadata in the assets.
+   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
+   */
+  public static HandLandmarker createFromFile(Context context, String modelPath) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
+    return createFromOptions(
+        context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates a {@link HandLandmarker} instance from a model file and the default {@link
+   * HandLandmarkerOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelFile the hand landmarks model {@link File} instance.
+   * @throws IOException if an I/O error occurs when opening the tflite model file.
+   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
+   */
+  public static HandLandmarker createFromFile(Context context, File modelFile) throws IOException {
+    try (ParcelFileDescriptor descriptor =
+        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+      BaseOptions baseOptions =
+          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
+      return createFromOptions(
+          context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
+    }
+  }
+
+  /**
+   * Creates a {@link HandLandmarker} instance from a model buffer and the default {@link
+   * HandLandmarkerOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
+   *     model.
+   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
+   */
+  public static HandLandmarker createFromBuffer(Context context, final ByteBuffer modelBuffer) {
+    BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
+    return createFromOptions(
+        context, HandLandmarkerOptions.builder().setBaseOptions(baseOptions).build());
+  }
+
+  /**
+   * Creates a {@link HandLandmarker} instance from a {@link HandLandmarkerOptions}.
+   *
+   * @param context an Android {@link Context}.
+   * @param landmarkerOptions a {@link HandLandmarkerOptions} instance.
+   * @throws MediaPipeException if there is an error during {@link HandLandmarker} creation.
+   */
+  public static HandLandmarker createFromOptions(
+      Context context, HandLandmarkerOptions landmarkerOptions) {
+    // TODO: Consolidate OutputHandler and TaskRunner.
+    OutputHandler<HandLandmarkerResult, MPImage> handler = new OutputHandler<>();
+    handler.setOutputPacketConverter(
+        new OutputHandler.OutputPacketConverter<HandLandmarkerResult, MPImage>() {
+          @Override
+          public HandLandmarkerResult convertToTaskResult(List<Packet> packets) {
+            // If no hands are detected in the image, just return empty lists.
+            if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) {
+              return HandLandmarkerResult.create(
+                  new ArrayList<>(),
+                  new ArrayList<>(),
+                  new ArrayList<>(),
+                  packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp());
+            }
+            return HandLandmarkerResult.create(
+                PacketGetter.getProtoVector(
+                    packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()),
+                PacketGetter.getProtoVector(
+                    packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()),
+                PacketGetter.getProtoVector(
+                    packets.get(HANDEDNESS_OUT_STREAM_INDEX), ClassificationList.parser()),
+                packets.get(LANDMARKS_OUT_STREAM_INDEX).getTimestamp());
+          }
+
+          @Override
+          public MPImage convertToTaskInput(List<Packet> packets) {
+            return new BitmapImageBuilder(
+                    AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
+                .build();
+          }
+        });
+    landmarkerOptions.resultListener().ifPresent(handler::setResultListener);
+    landmarkerOptions.errorListener().ifPresent(handler::setErrorListener);
+    TaskRunner runner =
+        TaskRunner.create(
+            context,
+            TaskInfo.<HandLandmarkerOptions>builder()
+                .setTaskGraphName(TASK_GRAPH_NAME)
+                .setInputStreams(INPUT_STREAMS)
+                .setOutputStreams(OUTPUT_STREAMS)
+                .setTaskOptions(landmarkerOptions)
+                .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM)
+                .build(),
+            handler);
+    return new HandLandmarker(runner, landmarkerOptions.runningMode());
+  }
+
+  /**
+   * Constructor to initialize a {@link HandLandmarker} from a {@link TaskRunner} and a {@link
+   * RunningMode}.
+   *
+   * @param taskRunner a {@link TaskRunner}.
+   * @param runningMode a mediapipe vision task {@link RunningMode}.
+   */
+  private HandLandmarker(TaskRunner taskRunner, RunningMode runningMode) {
+    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
+  }
+
+  /**
+   * Performs hand landmarks detection on the provided single image with default image processing
+   * options, i.e. without any rotation applied. Only use this method when the {@link
+   * HandLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java doc
+   * for input image format.
+   *
+   * <p>{@link HandLandmarker} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public HandLandmarkerResult detect(MPImage image) {
+    return detect(image, ImageProcessingOptions.builder().build());
+  }
+
+  /**
+   * Performs hand landmarks detection on the provided single image. Only use this method when the
+   * {@link HandLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java
+   * doc for input image format.
+   *
+   * <p>{@link HandLandmarker} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
+   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+   *     this method throwing an IllegalArgumentException.
+   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+   *     region-of-interest.
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public HandLandmarkerResult detect(
+      MPImage image, ImageProcessingOptions imageProcessingOptions) {
+    validateImageProcessingOptions(imageProcessingOptions);
+    return (HandLandmarkerResult) processImageData(image, imageProcessingOptions);
+  }
+
+  /**
+   * Performs hand landmarks detection on the provided video frame with default image processing
+   * options, i.e. without any rotation applied. Only use this method when the {@link
+   * HandLandmarker} is created with {@link RunningMode.VIDEO}.
+   *
+   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+   * must be monotonically increasing.
+   *
+   * <p>{@link HandLandmarker} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public HandLandmarkerResult detectForVideo(MPImage image, long timestampMs) {
+    return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
+  }
+
+  /**
+   * Performs hand landmarks detection on the provided video frame. Only use this method when the
+   * {@link HandLandmarker} is created with {@link RunningMode.VIDEO}.
+   *
+   * <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+   * must be monotonically increasing.
+   *
+   * <p>{@link HandLandmarker} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
+   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+   *     this method throwing an IllegalArgumentException.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+   *     region-of-interest.
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public HandLandmarkerResult detectForVideo(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+    validateImageProcessingOptions(imageProcessingOptions);
+    return (HandLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs);
+  }
+
+  /**
+   * Sends live image data to perform hand landmarks detection with default image processing
+   * options, i.e. without any rotation applied, and the results will be available via the {@link
+   * ResultListener} provided in the {@link HandLandmarkerOptions}. Only use this method when the
+   * {@link HandLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
+   *
+   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+   * sent to the hand landmarker. The input timestamps must be monotonically increasing.
+   *
+   * <p>{@link HandLandmarker} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public void detectAsync(MPImage image, long timestampMs) {
+    detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
+  }
+
+  /**
+   * Sends live image data to perform hand landmarks detection, and the results will be available
+   * via the {@link ResultListener} provided in the {@link HandLandmarkerOptions}. Only use this
+   * method when the {@link HandLandmarker} is created with {@link RunningMode.LIVE_STREAM}.
+   *
+   * <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+   * sent to the hand landmarker. The input timestamps must be monotonically increasing.
+   *
+   * <p>{@link HandLandmarker} supports the following color space types:
+   *
+   * <ul>
+   *   <li>{@link Bitmap.Config.ARGB_8888}
+   * </ul>
+   *
+   * @param image a MediaPipe {@link MPImage} object for processing.
+   * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
+   *     input image before running inference. Note that region-of-interest is <b>not</b> supported
+   *     by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
+   *     this method throwing an IllegalArgumentException.
+   * @param timestampMs the input timestamp (in milliseconds).
+   * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
+   *     region-of-interest.
+   * @throws MediaPipeException if there is an internal error.
+   */
+  public void detectAsync(
+      MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
+    validateImageProcessingOptions(imageProcessingOptions);
+    sendLiveStreamData(image, imageProcessingOptions, timestampMs);
+  }
+
+  /** Options for setting up an {@link HandLandmarker}. */
+  @AutoValue
+  public abstract static class HandLandmarkerOptions extends TaskOptions {
+
+    /** Builder for {@link HandLandmarkerOptions}. */
+    @AutoValue.Builder
+    public abstract static class Builder {
+      /** Sets the base options for the hand landmarker task. */
+      public abstract Builder setBaseOptions(BaseOptions value);
+
+      /**
+       * Sets the running mode for the hand landmarker task. Defaults to the image mode. Hand
+       * landmarker has three modes:
+       *
+       * <ul>
+       *   <li>IMAGE: The mode for detecting hand landmarks on single image inputs.
+       *   <li>VIDEO: The mode for detecting hand landmarks on the decoded frames of a video.
+       *   <li>LIVE_STREAM: The mode for detecting hand landmarks on a live stream of input
+       *       data, such as from a camera. In this mode, {@code setResultListener} must be called
+       *       to set up a listener to receive the detection results asynchronously.
+       * </ul>
+       */
+      public abstract Builder setRunningMode(RunningMode value);
+
+      /** Sets the maximum number of hands that can be detected by the HandLandmarker. */
+      public abstract Builder setNumHands(Integer value);
+
+      /** Sets the minimum confidence score for the hand detection to be considered successful. */
+      public abstract Builder setMinHandDetectionConfidence(Float value);
+
+      /** Sets the minimum confidence score of hand presence score in the hand landmark detection. */
+      public abstract Builder setMinHandPresenceConfidence(Float value);
+
+      /** Sets the minimum confidence score for the hand tracking to be considered successful. */
+      public abstract Builder setMinTrackingConfidence(Float value);
+
+      /**
+       * Sets the result listener to receive the detection results asynchronously when the hand
+       * landmarker is in the live stream mode.
+       */
+      public abstract Builder setResultListener(
+          ResultListener<HandLandmarkerResult, MPImage> value);
+
+      /** Sets an optional error listener. */
+      public abstract Builder setErrorListener(ErrorListener value);
+
+      abstract HandLandmarkerOptions autoBuild();
+
+      /**
+       * Validates and builds the {@link HandLandmarkerOptions} instance.
+       *
+       * @throws IllegalArgumentException if the result listener and the running mode are not
+       *     properly configured. The result listener should only be set when the hand landmarker
+       *     is in the live stream mode.
+       */
+      public final HandLandmarkerOptions build() {
+        HandLandmarkerOptions options = autoBuild();
+        if (options.runningMode() == RunningMode.LIVE_STREAM) {
+          if (!options.resultListener().isPresent()) {
+            throw new IllegalArgumentException(
+                "The hand landmarker is in the live stream mode, a user-defined result listener"
+                    + " must be provided in HandLandmarkerOptions.");
+          }
+        } else if (options.resultListener().isPresent()) {
+          throw new IllegalArgumentException(
+              "The hand landmarker is in the image or the video mode, a user-defined result"
+                  + " listener shouldn't be provided in HandLandmarkerOptions.");
+        }
+        return options;
+      }
+    }
+
+    abstract BaseOptions baseOptions();
+
+    abstract RunningMode runningMode();
+
+    abstract Optional<Integer> numHands();
+
+    abstract Optional<Float> minHandDetectionConfidence();
+
+    abstract Optional<Float> minHandPresenceConfidence();
+
+    abstract Optional<Float> minTrackingConfidence();
+
+    abstract Optional<ResultListener<HandLandmarkerResult, MPImage>> resultListener();
+
+    abstract Optional<ErrorListener> errorListener();
+
+    public static Builder builder() {
+      return new AutoValue_HandLandmarker_HandLandmarkerOptions.Builder()
+          .setRunningMode(RunningMode.IMAGE)
+          .setNumHands(1)
+          .setMinHandDetectionConfidence(0.5f)
+          .setMinHandPresenceConfidence(0.5f)
+          .setMinTrackingConfidence(0.5f);
+    }
+
+    /** Converts a {@link HandLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */
+    @Override
+    public CalculatorOptions convertToCalculatorOptionsProto() {
+      HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.Builder taskOptionsBuilder =
+          HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
+              .setBaseOptions(
+                  BaseOptionsProto.BaseOptions.newBuilder()
+                      .setUseStreamMode(runningMode() != RunningMode.IMAGE)
+                      .mergeFrom(convertBaseOptionsToProto(baseOptions()))
+                      .build());
+
+      // Setup HandDetectorGraphOptions.
+      HandDetectorGraphOptionsProto.HandDetectorGraphOptions.Builder
+          handDetectorGraphOptionsBuilder =
+              HandDetectorGraphOptionsProto.HandDetectorGraphOptions.newBuilder();
+      numHands().ifPresent(handDetectorGraphOptionsBuilder::setNumHands);
+      minHandDetectionConfidence()
+          .ifPresent(handDetectorGraphOptionsBuilder::setMinDetectionConfidence);
+
+      // Setup HandLandmarksDetectorGraphOptions.
+      HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.Builder
+          handLandmarksDetectorGraphOptionsBuilder =
+              HandLandmarksDetectorGraphOptionsProto.HandLandmarksDetectorGraphOptions.newBuilder();
+      minHandPresenceConfidence()
+          .ifPresent(handLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence);
+      minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence);
+
+      taskOptionsBuilder
+          .setHandDetectorGraphOptions(handDetectorGraphOptionsBuilder.build())
+          .setHandLandmarksDetectorGraphOptions(handLandmarksDetectorGraphOptionsBuilder.build());
+      return CalculatorOptions.newBuilder()
+          .setExtension(
+              HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.ext,
+              taskOptionsBuilder.build())
+          .build();
+    }
+  }
+
+  /**
+   * Validates that the provided {@link ImageProcessingOptions} doesn't contain a
+   * region-of-interest.
+   */
+  private static void validateImageProcessingOptions(
+      ImageProcessingOptions imageProcessingOptions) {
+    if (imageProcessingOptions.regionOfInterest().isPresent()) {
+      throw new IllegalArgumentException("HandLandmarker doesn't support region-of-interest.");
+    }
+  }
+}
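
The factory methods, running modes, and detect calls above combine as follows. This is a minimal usage sketch rather than part of the change itself: the demo package name, the "hand_landmarker.task" asset name, and the bitmap source are assumptions, and error handling is omitted.

    package com.example.handlandmarkerdemo;  // hypothetical demo package

    import android.content.Context;
    import android.graphics.Bitmap;
    import com.google.mediapipe.framework.image.BitmapImageBuilder;
    import com.google.mediapipe.framework.image.MPImage;
    import com.google.mediapipe.tasks.core.BaseOptions;
    import com.google.mediapipe.tasks.vision.core.RunningMode;
    import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarker;
    import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarker.HandLandmarkerOptions;
    import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarkerResult;

    final class HandLandmarkerUsageSketch {

      /** Synchronous detection: IMAGE is the default running mode. */
      static HandLandmarkerResult detectOnImage(Context context, Bitmap bitmap) {
        HandLandmarker landmarker =
            HandLandmarker.createFromFile(context, "hand_landmarker.task");
        MPImage image = new BitmapImageBuilder(bitmap).build();
        return landmarker.detect(image);
      }

      /** LIVE_STREAM mode: results are delivered to the result listener instead. */
      static HandLandmarker createForLiveStream(Context context) {
        HandLandmarkerOptions options =
            HandLandmarkerOptions.builder()
                .setBaseOptions(
                    BaseOptions.builder().setModelAssetPath("hand_landmarker.task").build())
                .setRunningMode(RunningMode.LIVE_STREAM)
                .setNumHands(2)
                // Required in LIVE_STREAM mode; build() throws without it.
                .setResultListener((result, inputImage) -> { /* consume result here */ })
                .build();
        return HandLandmarker.createFromOptions(context, options);
      }

      /** Feed frames with monotonically increasing timestamps (milliseconds). */
      static void feedFrame(HandLandmarker landmarker, Bitmap frame, long timestampMs) {
        landmarker.detectAsync(new BitmapImageBuilder(frame).build(), timestampMs);
      }

      private HandLandmarkerUsageSketch() {}
    }

Note the mode-to-method pairing enforced by the class: detect for IMAGE, detectForVideo for VIDEO, and detectAsync for LIVE_STREAM.
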
@ -0,0 +1,109 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.handlandmarker;
+
+import com.google.auto.value.AutoValue;
+import com.google.mediapipe.formats.proto.ClassificationProto.Classification;
+import com.google.mediapipe.formats.proto.ClassificationProto.ClassificationList;
+import com.google.mediapipe.formats.proto.LandmarkProto.Landmark;
+import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
+import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
+import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
+import com.google.mediapipe.tasks.components.containers.Category;
+import com.google.mediapipe.tasks.core.TaskResult;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/** Represents the hand landmarks detection results generated by {@link HandLandmarker}. */
+@AutoValue
+public abstract class HandLandmarkerResult implements TaskResult {
+
+  /**
+   * Creates a {@link HandLandmarkerResult} instance from the lists of landmarks and handedness
+   * protobuf messages.
+   *
+   * @param landmarksProto a List of {@link NormalizedLandmarkList}
+   * @param worldLandmarksProto a List of {@link LandmarkList}
+   * @param handednessesProto a List of {@link ClassificationList}
+   * @param timestampMs the timestamp of the result, in milliseconds
+   */
+  static HandLandmarkerResult create(
+      List<NormalizedLandmarkList> landmarksProto,
+      List<LandmarkList> worldLandmarksProto,
+      List<ClassificationList> handednessesProto,
+      long timestampMs) {
+    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandLandmarks =
+        new ArrayList<>();
+    List<List<com.google.mediapipe.tasks.components.containers.Landmark>> multiHandWorldLandmarks =
+        new ArrayList<>();
+    List<List<Category>> multiHandHandednesses = new ArrayList<>();
+    for (NormalizedLandmarkList handLandmarksProto : landmarksProto) {
+      List<com.google.mediapipe.tasks.components.containers.Landmark> handLandmarks =
+          new ArrayList<>();
+      multiHandLandmarks.add(handLandmarks);
+      for (NormalizedLandmark handLandmarkProto : handLandmarksProto.getLandmarkList()) {
+        handLandmarks.add(
+            com.google.mediapipe.tasks.components.containers.Landmark.create(
+                handLandmarkProto.getX(),
+                handLandmarkProto.getY(),
+                handLandmarkProto.getZ(),
+                true));
+      }
+    }
+    for (LandmarkList handWorldLandmarksProto : worldLandmarksProto) {
+      List<com.google.mediapipe.tasks.components.containers.Landmark> handWorldLandmarks =
+          new ArrayList<>();
+      multiHandWorldLandmarks.add(handWorldLandmarks);
+      for (Landmark handWorldLandmarkProto : handWorldLandmarksProto.getLandmarkList()) {
+        handWorldLandmarks.add(
+            com.google.mediapipe.tasks.components.containers.Landmark.create(
+                handWorldLandmarkProto.getX(),
+                handWorldLandmarkProto.getY(),
+                handWorldLandmarkProto.getZ(),
+                false));
+      }
+    }
+    for (ClassificationList handednessProto : handednessesProto) {
+      List<Category> handedness = new ArrayList<>();
+      multiHandHandednesses.add(handedness);
+      for (Classification classification : handednessProto.getClassificationList()) {
+        handedness.add(
+            Category.create(
+                classification.getScore(),
+                classification.getIndex(),
+                classification.getLabel(),
+                classification.getDisplayName()));
+      }
+    }
+    return new AutoValue_HandLandmarkerResult(
+        timestampMs,
+        Collections.unmodifiableList(multiHandLandmarks),
+        Collections.unmodifiableList(multiHandWorldLandmarks),
+        Collections.unmodifiableList(multiHandHandednesses));
+  }
+
+  @Override
+  public abstract long timestampMs();
+
+  /** Hand landmarks of detected hands. */
+  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>> landmarks();
+
+  /** Hand landmarks in world coordinates of detected hands. */
+  public abstract List<List<com.google.mediapipe.tasks.components.containers.Landmark>>
+      worldLandmarks();
+
+  /** Handedness of detected hands. */
+  public abstract List<List<Category>> handednesses();
+}
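
As a hedged illustration of consuming these parallel per-hand lists, the sketch below prints each detected hand's handedness and its wrist landmark (landmark 0). It is not part of the change; the categoryName()/score() accessors on Category and x()/y() on Landmark are assumed from the companion container classes.

    package com.example.handlandmarkerdemo;  // hypothetical demo package

    import com.google.mediapipe.tasks.components.containers.Category;
    import com.google.mediapipe.tasks.components.containers.Landmark;
    import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarkerResult;
    import java.util.List;

    final class HandLandmarkerResultSketch {

      /** Prints one line per detected hand; the three result lists are parallel. */
      static void printResult(HandLandmarkerResult result) {
        for (int hand = 0; hand < result.landmarks().size(); ++hand) {
          List<Landmark> landmarks = result.landmarks().get(hand);
          // The top-scoring handedness category for this hand.
          Category handedness = result.handednesses().get(hand).get(0);
          // Landmark 0 is the wrist; these coordinates are normalized to [0, 1].
          Landmark wrist = landmarks.get(0);
          System.out.printf(
              "hand %d: %s (score %.2f), wrist at (%.3f, %.3f)%n",
              hand, handedness.categoryName(), handedness.score(), wrist.x(), wrist.y());
        }
      }

      private HandLandmarkerResultSketch() {}
    }
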
@ -0,0 +1,24 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+    package="com.google.mediapipe.tasks.vision.handlandmarkertest"
+    android:versionCode="1"
+    android:versionName="1.0" >
+
+  <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
+  <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
+
+  <uses-sdk android:minSdkVersion="24"
+      android:targetSdkVersion="30" />
+
+  <application
+      android:label="handlandmarkertest"
+      android:name="android.support.multidex.MultiDexApplication"
+      android:taskAffinity="">
+    <uses-library android:name="android.test.runner" />
+  </application>
+
+  <instrumentation
+      android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
+      android:targetPackage="com.google.mediapipe.tasks.vision.handlandmarkertest" />
+
+</manifest>

@ -0,0 +1,19 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+# TODO: Enable this in OSS

@ -0,0 +1,424 @@
+// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.mediapipe.tasks.vision.handlandmarker;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+
+import android.content.res.AssetManager;
+import android.graphics.BitmapFactory;
+import android.graphics.RectF;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.google.common.truth.Correspondence;
+import com.google.mediapipe.framework.MediaPipeException;
+import com.google.mediapipe.framework.image.BitmapImageBuilder;
+import com.google.mediapipe.framework.image.MPImage;
+import com.google.mediapipe.tasks.components.containers.Category;
+import com.google.mediapipe.tasks.components.containers.Landmark;
+import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult;
+import com.google.mediapipe.tasks.core.BaseOptions;
+import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
+import com.google.mediapipe.tasks.vision.core.RunningMode;
+import com.google.mediapipe.tasks.vision.handlandmarker.HandLandmarker.HandLandmarkerOptions;
+import java.io.InputStream;
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+import org.junit.runners.Suite.SuiteClasses;
+
+/** Test for {@link HandLandmarker}. */
+@RunWith(Suite.class)
+@SuiteClasses({HandLandmarkerTest.General.class, HandLandmarkerTest.RunningModeTest.class})
+public class HandLandmarkerTest {
+  private static final String HAND_LANDMARKER_BUNDLE_ASSET_FILE = "hand_landmarker.task";
+  private static final String TWO_HANDS_IMAGE = "right_hands.jpg";
+  private static final String THUMB_UP_IMAGE = "thumb_up.jpg";
+  private static final String POINTING_UP_ROTATED_IMAGE = "pointing_up_rotated.jpg";
+  private static final String NO_HANDS_IMAGE = "cats_and_dogs.jpg";
+  private static final String THUMB_UP_LANDMARKS = "thumb_up_landmarks.pb";
+  private static final String POINTING_UP_ROTATED_LANDMARKS = "pointing_up_rotated_landmarks.pb";
+  private static final String TAG = "Hand Landmarker Test";
+  private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f;
+  private static final int IMAGE_WIDTH = 382;
+  private static final int IMAGE_HEIGHT = 406;
+
+  @RunWith(AndroidJUnit4.class)
+  public static final class General extends HandLandmarkerTest {
+
+    @Test
+    public void detect_successWithValidModels() throws Exception {
+      HandLandmarkerOptions options =
+          HandLandmarkerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                      .build())
+              .build();
+      HandLandmarker handLandmarker =
+          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      HandLandmarkerResult actualResult =
+          handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE));
+      HandLandmarkerResult expectedResult =
+          getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
+      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
+    }
+
+    @Test
+    public void detect_successWithEmptyResult() throws Exception {
+      HandLandmarkerOptions options =
+          HandLandmarkerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                      .build())
+              .build();
+      HandLandmarker handLandmarker =
+          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      HandLandmarkerResult actualResult =
+          handLandmarker.detect(getImageFromAsset(NO_HANDS_IMAGE));
+      assertThat(actualResult.landmarks()).isEmpty();
+      assertThat(actualResult.worldLandmarks()).isEmpty();
+      assertThat(actualResult.handednesses()).isEmpty();
+    }
+
+    @Test
+    public void detect_successWithNumHands() throws Exception {
+      HandLandmarkerOptions options =
+          HandLandmarkerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                      .build())
+              .setNumHands(2)
+              .build();
+      HandLandmarker handLandmarker =
+          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      HandLandmarkerResult actualResult =
+          handLandmarker.detect(getImageFromAsset(TWO_HANDS_IMAGE));
+      assertThat(actualResult.handednesses()).hasSize(2);
+    }
+
+    @Test
+    public void recognize_successWithRotation() throws Exception {
+      HandLandmarkerOptions options =
+          HandLandmarkerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                      .build())
+              .setNumHands(1)
+              .build();
+      HandLandmarker handLandmarker =
+          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      ImageProcessingOptions imageProcessingOptions =
+          ImageProcessingOptions.builder().setRotationDegrees(-90).build();
+      HandLandmarkerResult actualResult =
+          handLandmarker.detect(
+              getImageFromAsset(POINTING_UP_ROTATED_IMAGE), imageProcessingOptions);
+      HandLandmarkerResult expectedResult =
+          getExpectedHandLandmarkerResult(POINTING_UP_ROTATED_LANDMARKS);
+      assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
+    }
+
+    @Test
+    public void recognize_failsWithRegionOfInterest() throws Exception {
+      HandLandmarkerOptions options =
+          HandLandmarkerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder()
+                      .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                      .build())
+              .setNumHands(1)
+              .build();
+      HandLandmarker handLandmarker =
+          HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
+      ImageProcessingOptions imageProcessingOptions =
+          ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build();
+      IllegalArgumentException exception =
+          assertThrows(
+              IllegalArgumentException.class,
+              () ->
+                  handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE), imageProcessingOptions));
+      assertThat(exception)
+          .hasMessageThat()
+          .contains("HandLandmarker doesn't support region-of-interest");
+    }
+  }
+
+  @RunWith(AndroidJUnit4.class)
+  public static final class RunningModeTest extends HandLandmarkerTest {
+    @Test
+    public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
+      for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
+        IllegalArgumentException exception =
+            assertThrows(
+                IllegalArgumentException.class,
+                () ->
+                    HandLandmarkerOptions.builder()
+                        .setBaseOptions(
+                            BaseOptions.builder()
+                                .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                                .build())
+                        .setRunningMode(mode)
+                        .setResultListener((handLandmarkerResults, inputImage) -> {})
+                        .build());
+        assertThat(exception)
+            .hasMessageThat()
+            .contains("a user-defined result listener shouldn't be provided");
+      }
+    }
+
+    @Test
+    public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
+      IllegalArgumentException exception =
+          assertThrows(
+              IllegalArgumentException.class,
+              () ->
+                  HandLandmarkerOptions.builder()
+                      .setBaseOptions(
+                          BaseOptions.builder()
+                              .setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE)
+                              .build())
+                      .setRunningMode(RunningMode.LIVE_STREAM)
+                      .build());
+      assertThat(exception)
+          .hasMessageThat()
+          .contains("a user-defined result listener must be provided");
+    }
+
+    @Test
+    public void recognize_failsWithCallingWrongApiInImageMode() throws Exception {
+      HandLandmarkerOptions options =
+          HandLandmarkerOptions.builder()
+              .setBaseOptions(
+                  BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
handLandmarker.detectForVideo(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
|
exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
handLandmarker.detectAsync(getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||||
|
HandLandmarkerOptions options =
|
||||||
|
HandLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE)));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||||
|
exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
handLandmarker.detectAsync(getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||||
|
HandLandmarkerOptions options =
|
||||||
|
HandLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.setResultListener((HandLandmarkerResults, inputImage) -> {})
|
||||||
|
.build();
|
||||||
|
|
||||||
|
HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE)));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||||
|
exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() ->
|
||||||
|
handLandmarker.detectForVideo(
|
||||||
|
getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ 0));
|
||||||
|
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_successWithImageMode() throws Exception {
|
||||||
|
HandLandmarkerOptions options =
|
||||||
|
HandLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.IMAGE)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
HandLandmarkerResult actualResult =
|
||||||
|
handLandmarker.detect(getImageFromAsset(THUMB_UP_IMAGE));
|
||||||
|
HandLandmarkerResult expectedResult =
|
||||||
|
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
|
||||||
|
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_successWithVideoMode() throws Exception {
|
||||||
|
HandLandmarkerOptions options =
|
||||||
|
HandLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.VIDEO)
|
||||||
|
.build();
|
||||||
|
HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||||
|
HandLandmarkerResult expectedResult =
|
||||||
|
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
HandLandmarkerResult actualResult =
|
||||||
|
handLandmarker.detectForVideo(getImageFromAsset(THUMB_UP_IMAGE), /*timestampsMs=*/ i);
|
||||||
|
assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||||
|
MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
|
||||||
|
HandLandmarkerResult expectedResult =
|
||||||
|
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
|
||||||
|
HandLandmarkerOptions options =
|
||||||
|
HandLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.setResultListener(
|
||||||
|
(actualResult, inputImage) -> {
|
||||||
|
assertActualResultApproximatelyEqualsToExpectedResult(
|
||||||
|
actualResult, expectedResult);
|
||||||
|
assertImageSizeIsExpected(inputImage);
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
try (HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
|
handLandmarker.detectAsync(image, /*timestampsMs=*/ 1);
|
||||||
|
MediaPipeException exception =
|
||||||
|
assertThrows(
|
||||||
|
MediaPipeException.class,
|
||||||
|
() -> handLandmarker.detectAsync(image, /*timestampsMs=*/ 0));
|
||||||
|
assertThat(exception)
|
||||||
|
.hasMessageThat()
|
||||||
|
.contains("having a smaller timestamp than the processed timestamp");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void recognize_successWithLiveSteamMode() throws Exception {
|
||||||
|
MPImage image = getImageFromAsset(THUMB_UP_IMAGE);
|
||||||
|
HandLandmarkerResult expectedResult =
|
||||||
|
getExpectedHandLandmarkerResult(THUMB_UP_LANDMARKS);
|
||||||
|
HandLandmarkerOptions options =
|
||||||
|
HandLandmarkerOptions.builder()
|
||||||
|
.setBaseOptions(
|
||||||
|
BaseOptions.builder().setModelAssetPath(HAND_LANDMARKER_BUNDLE_ASSET_FILE).build())
|
||||||
|
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||||
|
.setResultListener(
|
||||||
|
(actualResult, inputImage) -> {
|
||||||
|
assertActualResultApproximatelyEqualsToExpectedResult(
|
||||||
|
actualResult, expectedResult);
|
||||||
|
assertImageSizeIsExpected(inputImage);
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
try (HandLandmarker handLandmarker =
|
||||||
|
HandLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
handLandmarker.detectAsync(image, /*timestampsMs=*/ i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||||
|
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||||
|
InputStream istr = assetManager.open(filePath);
|
||||||
|
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static HandLandmarkerResult getExpectedHandLandmarkerResult(
|
||||||
|
String filePath) throws Exception {
|
||||||
|
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||||
|
InputStream istr = assetManager.open(filePath);
|
||||||
|
LandmarksDetectionResult landmarksDetectionResultProto =
|
||||||
|
LandmarksDetectionResult.parser().parseFrom(istr);
|
||||||
|
return HandLandmarkerResult.create(
|
||||||
|
Arrays.asList(landmarksDetectionResultProto.getLandmarks()),
|
||||||
|
Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()),
|
||||||
|
Arrays.asList(landmarksDetectionResultProto.getClassifications()),
|
||||||
|
/*timestampMs=*/ 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void assertActualResultApproximatelyEqualsToExpectedResult(
|
||||||
|
HandLandmarkerResult actualResult, HandLandmarkerResult expectedResult) {
|
||||||
|
// Expects to have the same number of hands detected.
|
||||||
|
assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size());
|
||||||
|
assertThat(actualResult.worldLandmarks()).hasSize(expectedResult.worldLandmarks().size());
|
||||||
|
assertThat(actualResult.handednesses()).hasSize(expectedResult.handednesses().size());
|
||||||
|
|
||||||
|
// Actual landmarks match expected landmarks.
|
||||||
|
assertThat(actualResult.landmarks().get(0))
|
||||||
|
.comparingElementsUsing(
|
||||||
|
Correspondence.from(
|
||||||
|
(Correspondence.BinaryPredicate<Landmark, Landmark>)
|
||||||
|
(actual, expected) -> {
|
||||||
|
return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
|
||||||
|
.compare(actual.x(), expected.x())
|
||||||
|
&& Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE)
|
||||||
|
.compare(actual.y(), expected.y());
|
||||||
|
},
|
||||||
|
"landmarks approximately equal to"))
|
||||||
|
.containsExactlyElementsIn(expectedResult.landmarks().get(0));
|
||||||
|
|
||||||
|
// Actual handedness matches expected handedness.
|
||||||
|
Category actualTopHandedness = actualResult.handednesses().get(0).get(0);
|
||||||
|
Category expectedTopHandedness = expectedResult.handednesses().get(0).get(0);
|
||||||
|
assertThat(actualTopHandedness.index()).isEqualTo(expectedTopHandedness.index());
|
||||||
|
assertThat(actualTopHandedness.categoryName()).isEqualTo(expectedTopHandedness.categoryName());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void assertImageSizeIsExpected(MPImage inputImage) {
|
||||||
|
assertThat(inputImage).isNotNull();
|
||||||
|
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
|
||||||
|
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
|
||||||
|
}
|
||||||
|
}
|
|
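All three running modes exercised above share one contract: a task built for a given mode accepts exactly one of the three entry points (detect, detectForVideo, detectAsync) and rejects the other two with a MediaPipeException. A minimal sketch of the same contract through the Python task API added later in this commit; the module paths come from the Python test below, while the snake_case method names are an assumption mirroring the Java API:

    from mediapipe.python._framework_bindings import image as image_module
    from mediapipe.tasks.python.core import base_options as base_options_module
    from mediapipe.tasks.python.vision import hand_landmarker
    from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

    options = hand_landmarker.HandLandmarkerOptions(
        base_options=base_options_module.BaseOptions(
            model_asset_path='hand_landmarker.task'),
        running_mode=running_mode_module.VisionTaskRunningMode.VIDEO)
    image = image_module.Image.create_from_file('thumb_up.jpg')
    with hand_landmarker.HandLandmarker.create_from_options(options) as landmarker:
      # Only the video entry point is valid in VIDEO mode; detect() and
      # detect_async() would raise because the task was not initialized
      # with the matching mode.
      result = landmarker.detect_for_video(image, 0)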
@@ -53,7 +53,13 @@ class LandmarksDetectionResult:
   def to_pb2(self) -> _LandmarksDetectionResultProto:
     """Generates a LandmarksDetectionResult protobuf object."""
+
+    landmarks = _NormalizedLandmarkListProto()
     classifications = _ClassificationListProto()
+    world_landmarks = _LandmarkListProto()
+
+    for landmark in self.landmarks:
+      landmarks.landmark.append(landmark.to_pb2())
+
     for category in self.categories:
       classifications.classification.append(
           _ClassificationProto(
@@ -63,9 +69,9 @@ class LandmarksDetectionResult:
               display_name=category.display_name))

     return _LandmarksDetectionResultProto(
-        landmarks=_NormalizedLandmarkListProto(self.landmarks),
+        landmarks=landmarks,
         classifications=classifications,
-        world_landmarks=_LandmarkListProto(self.world_landmarks),
+        world_landmarks=world_landmarks,
         rect=self.rect.to_pb2())

   @classmethod
@@ -73,9 +79,11 @@ class LandmarksDetectionResult:
   def create_from_pb2(
       cls,
       pb2_obj: _LandmarksDetectionResultProto) -> 'LandmarksDetectionResult':
-    """Creates a `LandmarksDetectionResult` object from the given protobuf object.
-    """
+    """Creates a `LandmarksDetectionResult` object from the given protobuf object."""
     categories = []
+    landmarks = []
+    world_landmarks = []
+
     for classification in pb2_obj.classifications.classification:
       categories.append(
           category_module.Category(
@@ -83,14 +91,14 @@ class LandmarksDetectionResult:
               index=classification.index,
               category_name=classification.label,
               display_name=classification.display_name))

+    for landmark in pb2_obj.landmarks.landmark:
+      landmarks.append(_NormalizedLandmark.create_from_pb2(landmark))
+
+    for landmark in pb2_obj.world_landmarks.landmark:
+      world_landmarks.append(_Landmark.create_from_pb2(landmark))
     return LandmarksDetectionResult(
-        landmarks=[
-            _NormalizedLandmark.create_from_pb2(landmark)
-            for landmark in pb2_obj.landmarks.landmark
-        ],
+        landmarks=landmarks,
         categories=categories,
-        world_landmarks=[
-            _Landmark.create_from_pb2(landmark)
-            for landmark in pb2_obj.world_landmarks.landmark
-        ],
+        world_landmarks=world_landmarks,
         rect=_NormalizedRect.create_from_pb2(pb2_obj.rect))
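The rewrite above replaces `_NormalizedLandmarkListProto(self.landmarks)` — which hands wrapper objects straight to the generated proto constructor — with explicit loops that convert each wrapper via `to_pb2()` before appending. A minimal standalone sketch of the append pattern, assuming the generated `landmark_pb2` module from the MediaPipe framework formats:

    from mediapipe.framework.formats import landmark_pb2

    def to_landmark_list(landmarks) -> landmark_pb2.NormalizedLandmarkList:
      """Packs wrapper landmarks (each exposing to_pb2()) into a proto list."""
      proto = landmark_pb2.NormalizedLandmarkList()
      for landmark in landmarks:
        # Repeated message fields are populated by append/add, not by
        # passing Python wrapper objects to the message constructor.
        proto.landmark.append(landmark.to_pb2())
      return proto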
@@ -12,9 +12,9 @@ py_library(
     srcs = [
         "metadata_info.py",
     ],
-    srcs_version = "PY3",
     visibility = ["//visibility:public"],
     deps = [
+        ":writer_utils",
         "//mediapipe/tasks/metadata:metadata_schema_py",
         "//mediapipe/tasks/metadata:schema_py",
     ],
@@ -14,12 +14,14 @@
 # ==============================================================================
 """Helper classes for common model metadata information."""

+import collections
 import csv
 import os
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union

 from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
 from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
+from mediapipe.tasks.python.metadata.metadata_writers import writer_utils

 # Min and max values for UINT8 tensors.
 _MIN_UINT8 = 0
@@ -267,6 +269,86 @@ class RegexTokenizerMd:
     return tokenizer


+class BertTokenizerMd:
+  """A container for the Bert tokenizer [1] metadata information.
+
+  [1]:
+  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+  """
+
+  def __init__(self, vocab_file_path: str):
+    """Initializes a BertTokenizerMd object.
+
+    Args:
+      vocab_file_path: path to the vocabulary file.
+    """
+    self._vocab_file_path = vocab_file_path
+
+  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
+    """Creates the Bert tokenizer metadata based on the information.
+
+    Returns:
+      A Flatbuffers Python object of the Bert tokenizer metadata.
+    """
+    vocab = _metadata_fb.AssociatedFileT()
+    vocab.name = self._vocab_file_path
+    vocab.description = _VOCAB_FILE_DESCRIPTION
+    vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
+    tokenizer = _metadata_fb.ProcessUnitT()
+    tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
+    tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
+    tokenizer.options.vocabFile = [vocab]
+    return tokenizer
+
+
+class SentencePieceTokenizerMd:
+  """A container for the sentence piece tokenizer [1] metadata information.
+
+  [1]:
+  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+  """
+
+  _SP_MODEL_DESCRIPTION = "The sentence piece model file."
+  _SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
+      " This file is optional during tokenization, while the sentence piece "
+      "model is mandatory.")
+
+  def __init__(self,
+               sentence_piece_model_path: str,
+               vocab_file_path: Optional[str] = None):
+    """Initializes a SentencePieceTokenizerMd object.
+
+    Args:
+      sentence_piece_model_path: path to the sentence piece model file.
+      vocab_file_path: path to the vocabulary file.
+    """
+    self._sentence_piece_model_path = sentence_piece_model_path
+    self._vocab_file_path = vocab_file_path
+
+  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
+    """Creates the sentence piece tokenizer metadata based on the information.
+
+    Returns:
+      A Flatbuffers Python object of the sentence piece tokenizer metadata.
+    """
+    tokenizer = _metadata_fb.ProcessUnitT()
+    tokenizer.optionsType = (
+        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
+    tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()
+
+    sp_model = _metadata_fb.AssociatedFileT()
+    sp_model.name = self._sentence_piece_model_path
+    sp_model.description = self._SP_MODEL_DESCRIPTION
+    tokenizer.options.sentencePieceModel = [sp_model]
+    if self._vocab_file_path:
+      vocab = _metadata_fb.AssociatedFileT()
+      vocab.name = self._vocab_file_path
+      vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
+      vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
+      tokenizer.options.vocabFile = [vocab]
+    return tokenizer
+
+
 class TensorMd:
   """A container for common tensor metadata information.

@@ -486,6 +568,145 @@ class InputTextTensorMd(TensorMd):
     return tensor_metadata


+def _get_file_paths(files: List[_metadata_fb.AssociatedFileT]) -> List[str]:
+  """Gets file paths from a list of associated files."""
+  if not files:
+    return []
+  return [file.name for file in files]
+
+
+def _get_tokenizer_associated_files(
+    tokenizer_options: Optional[
+        Union[_metadata_fb.BertTokenizerOptionsT,
+              _metadata_fb.SentencePieceTokenizerOptionsT]]
+) -> List[str]:
+  """Gets a list of associated files packed in the tokenizer_options.
+
+  Args:
+    tokenizer_options: a tokenizer metadata object. Supports the following
+      tokenizer types:
+      1. BertTokenizerOptions:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+      2. SentencePieceTokenizerOptions:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+
+  Returns:
+    A list of associated files included in tokenizer_options.
+  """
+
+  if not tokenizer_options:
+    return []
+
+  if isinstance(tokenizer_options, _metadata_fb.BertTokenizerOptionsT):
+    return _get_file_paths(tokenizer_options.vocabFile)
+  elif isinstance(tokenizer_options,
+                  _metadata_fb.SentencePieceTokenizerOptionsT):
+    return _get_file_paths(tokenizer_options.vocabFile) + _get_file_paths(
+        tokenizer_options.sentencePieceModel)
+  else:
+    return []
+
+
+class BertInputTensorsMd:
+  """A container for the input tensor metadata information of Bert models."""
+
+  _IDS_NAME = "ids"
+  _IDS_DESCRIPTION = "Tokenized ids of the input text."
+  _MASK_NAME = "mask"
+  _MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
+                       "tokens.")
+  _SEGMENT_IDS_NAME = "segment_ids"
+  _SEGMENT_IDS_DESCRIPTION = (
+      "0 for the first sequence, 1 for the second sequence if exists.")
+
+  def __init__(self,
+               model_buffer: bytearray,
+               ids_name: str,
+               mask_name: str,
+               segment_name: str,
+               tokenizer_md: Union[None, BertTokenizerMd,
+                                   SentencePieceTokenizerMd] = None):
+    """Initializes a BertInputTensorsMd object.
+
+    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
+    in the TFLite schema, which help to determine the tensor order when
+    populating metadata.
+
+    Args:
+      model_buffer: valid buffer of the model file.
+      ids_name: name of the ids tensor, which represents the tokenized ids of
+        the input text.
+      mask_name: name of the mask tensor, which represents the mask with `1`
+        for real tokens and `0` for padding tokens.
+      segment_name: name of the segment ids tensor, where `0` stands for the
+        first sequence, and `1` stands for the second sequence if it exists.
+      tokenizer_md: information of the tokenizer used to process the input
+        string, if any. Supported tokenizers are: `BertTokenizer` [1] and
+        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
+        refer to `InputTensorsMd`.
+      [1]:
+        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
+      [2]:
+        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
+      [3]:
+        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
+    """
+    # Verify that tflite_input_names (read from the model) and
+    # input_name (collected from users) are aligned.
+    tflite_input_names = writer_utils.get_input_tensor_names(model_buffer)
+    input_names = [ids_name, mask_name, segment_name]
+    if collections.Counter(tflite_input_names) != collections.Counter(
+        input_names):
+      raise ValueError(
+          f"The input tensor names ({input_names}) do not match the tensor "
+          f"names read from the model ({tflite_input_names}).")
+
+    ids_md = TensorMd(
+        name=self._IDS_NAME,
+        description=self._IDS_DESCRIPTION,
+        tensor_name=ids_name)
+    mask_md = TensorMd(
+        name=self._MASK_NAME,
+        description=self._MASK_DESCRIPTION,
+        tensor_name=mask_name)
+    segment_ids_md = TensorMd(
+        name=self._SEGMENT_IDS_NAME,
+        description=self._SEGMENT_IDS_DESCRIPTION,
+        tensor_name=segment_name)
+
+    self._input_md = [ids_md, mask_md, segment_ids_md]
+
+    if not isinstance(tokenizer_md,
+                      (type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
+      raise ValueError(
+          f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported"
+      )
+
+    self._tokenizer_md = tokenizer_md
+
+  def create_input_process_unit_metadata(
+      self) -> List[_metadata_fb.ProcessUnitT]:
+    """Creates the input process unit metadata."""
+    if self._tokenizer_md:
+      return [self._tokenizer_md.create_metadata()]
+    else:
+      return []
+
+  def get_tokenizer_associated_files(self) -> List[str]:
+    """Gets the associated files that are packed in the tokenizer."""
+    if self._tokenizer_md:
+      return _get_tokenizer_associated_files(
+          self._tokenizer_md.create_metadata().options)
+    else:
+      return []
+
+  @property
+  def input_md(self) -> List[TensorMd]:
+    return self._input_md
+
+
 class ClassificationTensorMd(TensorMd):
   """A container for the classification tensor metadata information.
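The two new *Md containers follow the same shape as RegexTokenizerMd: each collects file paths and emits a flatbuffers ProcessUnitT that later lands in the subgraph's inputProcessUnits. A short usage sketch of the Bert variant (the vocabulary file name is a placeholder):

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_info

    tokenizer_md = metadata_info.BertTokenizerMd(vocab_file_path="vocab.txt")
    process_unit = tokenizer_md.create_metadata()
    # The vocabulary travels as an AssociatedFileT attached to the options.
    assert process_unit.options.vocabFile[0].name == "vocab.txt"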
@@ -19,7 +19,7 @@ import csv
 import dataclasses
 import os
 import tempfile
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import flatbuffers
 from mediapipe.tasks.metadata import metadata_schema_py_generated as metadata_fb
@@ -101,6 +101,34 @@ class RegexTokenizer:
   vocab_file_path: str


+@dataclasses.dataclass
+class BertTokenizer:
+  """Parameters of the Bert tokenizer [1] metadata information.
+
+  [1]:
+  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+
+  Attributes:
+    vocab_file_path: path to the vocabulary file.
+  """
+  vocab_file_path: str
+
+
+@dataclasses.dataclass
+class SentencePieceTokenizer:
+  """Parameters of the sentence piece tokenizer [1] metadata information.
+
+  [1]:
+  https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+
+  Attributes:
+    sentence_piece_model_path: path to the sentence piece model file.
+    vocab_file_path: path to the vocabulary file.
+  """
+  sentence_piece_model_path: str
+  vocab_file_path: Optional[str] = None
+
+
 class Labels(object):
   """Simple container holding classification labels of a particular tensor.

@@ -282,7 +310,9 @@ def _create_metadata_buffer(
     model_buffer: bytearray,
     general_md: Optional[metadata_info.GeneralMd] = None,
     input_md: Optional[List[metadata_info.TensorMd]] = None,
-    output_md: Optional[List[metadata_info.TensorMd]] = None) -> bytearray:
+    output_md: Optional[List[metadata_info.TensorMd]] = None,
+    input_process_units: Optional[List[metadata_fb.ProcessUnitT]] = None
+) -> bytearray:
   """Creates a buffer of the metadata.

   Args:
@@ -290,7 +320,9 @@ def _create_metadata_buffer(
     general_md: general information about the model.
     input_md: metadata information of the input tensors.
     output_md: metadata information of the output tensors.
+    input_process_units: a list of metadata of the input process units [1].
+    [1]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L655
   Returns:
     A buffer of the metadata.

@@ -325,6 +357,8 @@ def _create_metadata_buffer(
   subgraph_metadata = metadata_fb.SubGraphMetadataT()
   subgraph_metadata.inputTensorMetadata = input_metadata
   subgraph_metadata.outputTensorMetadata = output_metadata
+  if input_process_units:
+    subgraph_metadata.inputProcessUnits = input_process_units

   # Create the whole model metadata.
   if general_md is None:
@@ -366,6 +400,7 @@ class MetadataWriter(object):
     self._model_buffer = model_buffer
     self._general_md = None
     self._input_mds = []
+    self._input_process_units = []
     self._output_mds = []
     self._associated_files = []
     self._temp_folder = tempfile.TemporaryDirectory()
@@ -416,7 +451,7 @@ class MetadataWriter(object):
       description: Description of the input tensor.

     Returns:
-      The MetaWriter instance, can be used for chained operation.
+      The MetadataWriter instance, can be used for chained operation.

     [1]:
       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389
@@ -448,7 +483,7 @@ class MetadataWriter(object):
       description: Description of the input tensor.

     Returns:
-      The MetaWriter instance, can be used for chained operation.
+      The MetadataWriter instance, can be used for chained operation.

     [1]:
       https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
@@ -462,6 +497,63 @@ class MetadataWriter(object):
     self._associated_files.append(regex_tokenizer.vocab_file_path)
     return self

+  def add_bert_text_input(self, tokenizer: Union[BertTokenizer,
+                                                 SentencePieceTokenizer],
+                          ids_name: str, mask_name: str,
+                          segment_name: str) -> 'MetadataWriter':
+    """Adds metadata for the text input with a Bert / SentencePiece tokenizer.
+
+    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
+    in the TFLite schema, which help to determine the tensor order when
+    populating metadata.
+
+    Args:
+      tokenizer: information of the tokenizer used to process the input string,
+        if any. Supported tokenizers are: `BertTokenizer` [1] and
+        `SentencePieceTokenizer` [2].
+      ids_name: name of the ids tensor, which represents the tokenized ids of
+        the input text.
+      mask_name: name of the mask tensor, which represents the mask with `1`
+        for real tokens and `0` for padding tokens.
+      segment_name: name of the segment ids tensor, where `0` stands for the
+        first sequence, and `1` stands for the second sequence if it exists.
+
+    Returns:
+      The MetadataWriter instance, can be used for chained operation.
+
+    Raises:
+      ValueError: if the type of tokenizer is not BertTokenizer or
+        SentencePieceTokenizer.
+
+    [1]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+    [2]:
+      https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+    """
+    if isinstance(tokenizer, BertTokenizer):
+      tokenizer_md = metadata_info.BertTokenizerMd(
+          vocab_file_path=tokenizer.vocab_file_path)
+    elif isinstance(tokenizer, SentencePieceTokenizer):
+      tokenizer_md = metadata_info.SentencePieceTokenizerMd(
+          sentence_piece_model_path=tokenizer.sentence_piece_model_path,
+          vocab_file_path=tokenizer.vocab_file_path)
+    else:
+      raise ValueError(
+          f'The type of tokenizer, {type(tokenizer)}, is unsupported')
+    bert_input_md = metadata_info.BertInputTensorsMd(
+        self._model_buffer,
+        ids_name,
+        mask_name,
+        segment_name,
+        tokenizer_md=tokenizer_md)
+
+    self._input_mds.extend(bert_input_md.input_md)
+    self._associated_files.extend(
+        bert_input_md.get_tokenizer_associated_files())
+    self._input_process_units.extend(
+        bert_input_md.create_input_process_unit_metadata())
+    return self
+
   def add_classification_output(
       self,
       labels: Optional[Labels] = None,
@@ -546,7 +638,8 @@ class MetadataWriter(object):
         model_buffer=self._model_buffer,
         general_md=self._general_md,
         input_md=self._input_mds,
-        output_md=self._output_mds)
+        output_md=self._output_mds,
+        input_process_units=self._input_process_units)
     populator.load_metadata_buffer(metadata_buffer)
     if self._associated_files:
       populator.load_associated_files(self._associated_files)
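add_bert_text_input() dispatches on the dataclass type: a BertTokenizer maps to BertTokenizerMd, a SentencePieceTokenizer to SentencePieceTokenizerMd, and anything else raises ValueError. A sketch of the chained-writer usage; the model path and tensor names are placeholders (the Model Maker defaults appear in the text_classifier change below):

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

    with open("model.tflite", "rb") as f:  # placeholder model path
      writer = metadata_writer.MetadataWriter(bytearray(f.read()))
    writer.add_bert_text_input(
        metadata_writer.SentencePieceTokenizer(
            sentence_piece_model_path="sp.model", vocab_file_path="vocab.txt"),
        # The three names must match the model's input tensor names exactly;
        # BertInputTensorsMd verifies this with a Counter comparison.
        ids_name="ids", mask_name="mask", segment_name="segment_ids")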
@@ -14,11 +14,18 @@
 # ==============================================================================
 """Writes metadata and label file to the Text classifier models."""

+from typing import Union
+
 from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer

 _MODEL_NAME = "TextClassifier"
 _MODEL_DESCRIPTION = ("Classify the input text into a set of known categories.")

+# The input tensor names of models created by Model Maker.
+_DEFAULT_ID_NAME = "serving_default_input_word_ids:0"
+_DEFAULT_MASK_NAME = "serving_default_input_mask:0"
+_DEFAULT_SEGMENT_ID_NAME = "serving_default_input_type_ids:0"
+
+
 class MetadataWriter(metadata_writer.MetadataWriterBase):
   """MetadataWriter to write the metadata into the text classifier."""

@@ -62,3 +69,51 @@ class MetadataWriter(metadata_writer.MetadataWriterBase):
     writer.add_regex_text_input(regex_tokenizer)
     writer.add_classification_output(labels)
     return cls(writer)
+
+  @classmethod
+  def create_for_bert_model(
+      cls,
+      model_buffer: bytearray,
+      tokenizer: Union[metadata_writer.BertTokenizer,
+                       metadata_writer.SentencePieceTokenizer],
+      labels: metadata_writer.Labels,
+      ids_name: str = _DEFAULT_ID_NAME,
+      mask_name: str = _DEFAULT_MASK_NAME,
+      segment_name: str = _DEFAULT_SEGMENT_ID_NAME,
+  ) -> "MetadataWriter":
+    """Creates MetadataWriter for models with {Bert/SentencePiece}Tokenizer.
+
+    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
+    in the TFLite schema, which help to determine the tensor order when
+    populating metadata. The default values come from Model Maker.
+
+    Args:
+      model_buffer: valid buffer of the model file.
+      tokenizer: information of the tokenizer used to process the input string,
+        if any. Supported tokenizers are: `BertTokenizer` [1] and
+        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer` [3],
+        refer to `create_for_regex_model`.
+      labels: an instance of Labels helper class used in the output
+        classification tensor [4].
+      ids_name: name of the ids tensor, which represents the tokenized ids of
+        the input text.
+      mask_name: name of the mask tensor, which represents the mask with `1`
+        for real tokens and `0` for padding tokens.
+      segment_name: name of the segment ids tensor, where `0` stands for the
+        first sequence, and `1` stands for the second sequence if it exists.
+
+      [1]:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L477
+      [2]:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L485
+      [3]:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L500
+      [4]:
+        https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L99
+
+    Returns:
+      A MetadataWriter object.
+    """
+    writer = metadata_writer.MetadataWriter(model_buffer)
+    writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION)
+    writer.add_bert_text_input(tokenizer, ids_name, mask_name, segment_name)
+    writer.add_classification_output(labels)
+    return cls(writer)
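The new classmethod mirrors create_for_regex_model, so populating a Bert text classifier reduces to a few lines end to end; the file paths in this sketch are placeholders, matching the shape of the tests added below:

    from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
    from mediapipe.tasks.python.metadata.metadata_writers import text_classifier

    with open("bert_classifier.tflite", "rb") as f:  # placeholder path
      model_buffer = f.read()
    writer = text_classifier.MetadataWriter.create_for_bert_model(
        model_buffer,
        tokenizer=metadata_writer.BertTokenizer("vocab.txt"),
        labels=metadata_writer.Labels().add_from_file("labels.txt"))
    # Returns the model buffer with metadata packed in, plus its JSON rendering.
    tflite_with_metadata, metadata_json = writer.populate()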
@@ -367,6 +367,42 @@ class ScoreThresholdingMdTest(absltest.TestCase):
     self.assertEqual(metadata_json, expected_json)


+class BertTokenizerMdTest(absltest.TestCase):
+
+  _VOCAB_FILE = "vocab.txt"
+  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
+      os.path.join(_TEST_DATA_DIR, "bert_tokenizer_meta.json"))
+
+  def test_create_metadata_should_succeed(self):
+    tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
+    tokenizer_metadata = tokenizer_md.create_metadata()
+
+    metadata_json = _metadata.convert_to_json(
+        _create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
+    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+
+class SentencePieceTokenizerMdTest(absltest.TestCase):
+
+  _VOCAB_FILE = "vocab.txt"
+  _SP_MODEL = "sp.model"
+  _EXPECTED_TENSOR_JSON = test_utils.get_test_data_path(
+      os.path.join(_TEST_DATA_DIR, "sentence_piece_tokenizer_meta.json"))
+
+  def test_create_metadata_should_succeed(self):
+    tokenizer_md = metadata_info.SentencePieceTokenizerMd(
+        self._SP_MODEL, self._VOCAB_FILE)
+    tokenizer_metadata = tokenizer_md.create_metadata()
+
+    metadata_json = _metadata.convert_to_json(
+        _create_dummy_model_metadata_with_process_uint(tokenizer_metadata))
+    with open(self._EXPECTED_TENSOR_JSON, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+
 def _create_dummy_model_metadata_with_tensor(
     tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
   # Create a dummy model using the tensor metadata.
@@ -21,28 +21,64 @@ from mediapipe.tasks.python.metadata.metadata_writers import text_classifier
 from mediapipe.tasks.python.test import test_utils

 _TEST_DIR = "mediapipe/tasks/testdata/metadata/"
-_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
+_REGEX_MODEL = test_utils.get_test_data_path(_TEST_DIR + "movie_review.tflite")
 _LABEL_FILE = test_utils.get_test_data_path(_TEST_DIR +
                                             "movie_review_labels.txt")
-_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
+_REGEX_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR + "regex_vocab.txt")
 _DELIM_REGEX_PATTERN = r"[^\w\']+"
-_JSON_FILE = test_utils.get_test_data_path("movie_review.json")
+_REGEX_JSON_FILE = test_utils.get_test_data_path("movie_review.json")

+_BERT_MODEL = test_utils.get_test_data_path(
+    _TEST_DIR + "bert_text_classifier_no_metadata.tflite")
+_BERT_VOCAB_FILE = test_utils.get_test_data_path(_TEST_DIR +
+                                                 "mobilebert_vocab.txt")
+_SP_MODEL_FILE = test_utils.get_test_data_path(_TEST_DIR + "30k-clean.model")
+_BERT_JSON_FILE = test_utils.get_test_data_path(
+    _TEST_DIR + "bert_text_classifier_with_bert_tokenizer.json")
+_SENTENCE_PIECE_JSON_FILE = test_utils.get_test_data_path(
+    _TEST_DIR + "bert_text_classifier_with_sentence_piece.json")
+

 class TextClassifierTest(absltest.TestCase):

-  def test_write_metadata(self,):
+  def test_write_metadata_for_regex_model(self):
-    with open(_MODEL, "rb") as f:
+    with open(_REGEX_MODEL, "rb") as f:
       model_buffer = f.read()
     writer = text_classifier.MetadataWriter.create_for_regex_model(
         model_buffer,
         regex_tokenizer=metadata_writer.RegexTokenizer(
             delim_regex_pattern=_DELIM_REGEX_PATTERN,
-            vocab_file_path=_VOCAB_FILE),
+            vocab_file_path=_REGEX_VOCAB_FILE),
         labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
     _, metadata_json = writer.populate()

-    with open(_JSON_FILE, "r") as f:
+    with open(_REGEX_JSON_FILE, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+  def test_write_metadata_for_bert(self):
+    with open(_BERT_MODEL, "rb") as f:
+      model_buffer = f.read()
+    writer = text_classifier.MetadataWriter.create_for_bert_model(
+        model_buffer,
+        tokenizer=metadata_writer.BertTokenizer(_BERT_VOCAB_FILE),
+        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
+    _, metadata_json = writer.populate()
+
+    with open(_BERT_JSON_FILE, "r") as f:
+      expected_json = f.read()
+    self.assertEqual(metadata_json, expected_json)
+
+  def test_write_metadata_for_sentence_piece(self):
+    with open(_BERT_MODEL, "rb") as f:
+      model_buffer = f.read()
+    writer = text_classifier.MetadataWriter.create_for_bert_model(
+        model_buffer,
+        tokenizer=metadata_writer.SentencePieceTokenizer(_SP_MODEL_FILE),
+        labels=metadata_writer.Labels().add_from_file(_LABEL_FILE))
+    _, metadata_json = writer.populate()
+
+    with open(_SENTENCE_PIECE_JSON_FILE, "r") as f:
       expected_json = f.read()
     self.assertEqual(metadata_json, expected_json)
@@ -94,3 +94,26 @@ py_test(
         "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
     ],
 )
+
+py_test(
+    name = "hand_landmarker_test",
+    srcs = ["hand_landmarker_test.py"],
+    data = [
+        "//mediapipe/tasks/testdata/vision:test_images",
+        "//mediapipe/tasks/testdata/vision:test_models",
+        "//mediapipe/tasks/testdata/vision:test_protos",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/tasks/cc/components/containers/proto:landmarks_detection_result_py_pb2",
+        "//mediapipe/tasks/python/components/containers:landmark",
+        "//mediapipe/tasks/python/components/containers:landmark_detection_result",
+        "//mediapipe/tasks/python/components/containers:rect",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/test:test_utils",
+        "//mediapipe/tasks/python/vision:hand_landmarker",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+        "@com_google_protobuf//:protobuf_python",
+    ],
+)
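With the target registered, the new test runs like any other Bazel py_test; a typical invocation would be:

    bazel test //mediapipe/tasks/python/test/vision:hand_landmarker_test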
mediapipe/tasks/python/test/vision/hand_landmarker_test.py (new file, 428 lines)
@@ -0,0 +1,428 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for hand landmarker."""

import enum
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from google.protobuf import text_format
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.cc.components.containers.proto import landmarks_detection_result_pb2
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.components.containers import landmark_detection_result as landmark_detection_result_module
from mediapipe.tasks.python.components.containers import rect as rect_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import hand_landmarker
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_LandmarksDetectionResultProto = landmarks_detection_result_pb2.LandmarksDetectionResult
_BaseOptions = base_options_module.BaseOptions
_Rect = rect_module.Rect
_Landmark = landmark_module.Landmark
_NormalizedLandmark = landmark_module.NormalizedLandmark
_LandmarksDetectionResult = landmark_detection_result_module.LandmarksDetectionResult
_Image = image_module.Image
_HandLandmarker = hand_landmarker.HandLandmarker
_HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
_HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
_RUNNING_MODE = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions

_HAND_LANDMARKER_BUNDLE_ASSET_FILE = 'hand_landmarker.task'
_NO_HANDS_IMAGE = 'cats_and_dogs.jpg'
_TWO_HANDS_IMAGE = 'right_hands.jpg'
_THUMB_UP_IMAGE = 'thumb_up.jpg'
_THUMB_UP_LANDMARKS = 'thumb_up_landmarks.pbtxt'
_POINTING_UP_IMAGE = 'pointing_up.jpg'
_POINTING_UP_LANDMARKS = 'pointing_up_landmarks.pbtxt'
_POINTING_UP_ROTATED_IMAGE = 'pointing_up_rotated.jpg'
_POINTING_UP_ROTATED_LANDMARKS = 'pointing_up_rotated_landmarks.pbtxt'
_LANDMARKS_ERROR_TOLERANCE = 0.03
_HANDEDNESS_MARGIN = 0.05


def _get_expected_hand_landmarker_result(
    file_path: str) -> _HandLandmarkerResult:
  landmarks_detection_result_file_path = test_utils.get_test_data_path(
      file_path)
  with open(landmarks_detection_result_file_path, 'rb') as f:
    landmarks_detection_result_proto = _LandmarksDetectionResultProto()
    # Use this if a .pb file is available.
    # landmarks_detection_result_proto.ParseFromString(f.read())
    text_format.Parse(f.read(), landmarks_detection_result_proto)
    landmarks_detection_result = _LandmarksDetectionResult.create_from_pb2(
        landmarks_detection_result_proto)
  return _HandLandmarkerResult(
      handedness=[landmarks_detection_result.categories],
      hand_landmarks=[landmarks_detection_result.landmarks],
      hand_world_landmarks=[landmarks_detection_result.world_landmarks])


class ModelFileType(enum.Enum):
  FILE_CONTENT = 1
  FILE_NAME = 2


class HandLandmarkerTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.test_image = _Image.create_from_file(
        test_utils.get_test_data_path(_THUMB_UP_IMAGE))
    self.model_path = test_utils.get_test_data_path(
        _HAND_LANDMARKER_BUNDLE_ASSET_FILE)

  def _assert_actual_result_approximately_matches_expected_result(
      self, actual_result: _HandLandmarkerResult,
      expected_result: _HandLandmarkerResult):
    # Expects to have the same number of hands detected.
    self.assertLen(actual_result.hand_landmarks,
                   len(expected_result.hand_landmarks))
    self.assertLen(actual_result.hand_world_landmarks,
                   len(expected_result.hand_world_landmarks))
    self.assertLen(actual_result.handedness, len(expected_result.handedness))
    # Actual landmarks match expected landmarks.
    self.assertLen(actual_result.hand_landmarks[0],
                   len(expected_result.hand_landmarks[0]))
    actual_landmarks = actual_result.hand_landmarks[0]
    expected_landmarks = expected_result.hand_landmarks[0]
    for i, actual_landmark in enumerate(actual_landmarks):
      self.assertAlmostEqual(
          actual_landmark.x,
          expected_landmarks[i].x,
          delta=_LANDMARKS_ERROR_TOLERANCE)
      self.assertAlmostEqual(
          actual_landmark.y,
          expected_landmarks[i].y,
          delta=_LANDMARKS_ERROR_TOLERANCE)
    # Actual handedness matches expected handedness.
    actual_top_handedness = actual_result.handedness[0][0]
    expected_top_handedness = expected_result.handedness[0][0]
    self.assertEqual(actual_top_handedness.index, expected_top_handedness.index)
    self.assertEqual(actual_top_handedness.category_name,
                     expected_top_handedness.category_name)
    self.assertAlmostEqual(
        actual_top_handedness.score,
        expected_top_handedness.score,
        delta=_HANDEDNESS_MARGIN)

  def test_create_from_file_succeeds_with_valid_model_path(self):
    # Creates with default option and valid model file successfully.
    with _HandLandmarker.create_from_model_path(self.model_path) as landmarker:
      self.assertIsInstance(landmarker, _HandLandmarker)

  def test_create_from_options_succeeds_with_valid_model_path(self):
    # Creates with options containing model file successfully.
    base_options = _BaseOptions(model_asset_path=self.model_path)
    options = _HandLandmarkerOptions(base_options=base_options)
    with _HandLandmarker.create_from_options(options) as landmarker:
      self.assertIsInstance(landmarker, _HandLandmarker)

  def test_create_from_options_fails_with_invalid_model_path(self):
    # Invalid empty model path.
    with self.assertRaisesRegex(
        RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
      base_options = _BaseOptions(
          model_asset_path='/path/to/invalid/model.tflite')
      options = _HandLandmarkerOptions(base_options=base_options)
      _HandLandmarker.create_from_options(options)

  def test_create_from_options_succeeds_with_valid_model_content(self):
    # Creates with options containing model content successfully.
    with open(self.model_path, 'rb') as f:
      base_options = _BaseOptions(model_asset_buffer=f.read())
      options = _HandLandmarkerOptions(base_options=base_options)
      landmarker = _HandLandmarker.create_from_options(options)
      self.assertIsInstance(landmarker, _HandLandmarker)

  @parameterized.parameters(
      (ModelFileType.FILE_NAME,
       _get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
      (ModelFileType.FILE_CONTENT,
|
||||||
|
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
|
||||||
|
def test_detect(self, model_file_type, expected_detection_result):
|
||||||
|
# Creates hand landmarker.
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _HandLandmarkerOptions(base_options=base_options)
|
||||||
|
landmarker = _HandLandmarker.create_from_options(options)
|
||||||
|
|
||||||
|
# Performs hand landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(self.test_image)
|
||||||
|
# Comparing results.
|
||||||
|
self._assert_actual_result_approximately_matches_expected_result(
|
||||||
|
detection_result, expected_detection_result)
|
||||||
|
# Closes the hand landmarker explicitly when the hand landmarker is not used
|
||||||
|
# in a context.
|
||||||
|
landmarker.close()
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(ModelFileType.FILE_NAME,
|
||||||
|
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
|
||||||
|
(ModelFileType.FILE_CONTENT,
|
||||||
|
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)))
|
||||||
|
def test_detect_in_context(self, model_file_type, expected_detection_result):
|
||||||
|
# Creates hand landmarker.
|
||||||
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||||
|
with open(self.model_path, 'rb') as f:
|
||||||
|
model_content = f.read()
|
||||||
|
base_options = _BaseOptions(model_asset_buffer=model_content)
|
||||||
|
else:
|
||||||
|
# Should never happen
|
||||||
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
|
options = _HandLandmarkerOptions(base_options=base_options)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Performs hand landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(self.test_image)
|
||||||
|
# Comparing results.
|
||||||
|
self._assert_actual_result_approximately_matches_expected_result(
|
||||||
|
detection_result, expected_detection_result)
|
||||||
|
|
||||||
|
def test_detect_succeeds_with_num_hands(self):
|
||||||
|
# Creates hand landmarker.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _HandLandmarkerOptions(base_options=base_options, num_hands=2)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Load the two hands image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(_TWO_HANDS_IMAGE))
|
||||||
|
# Performs hand landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(test_image)
|
||||||
|
# Comparing results.
|
||||||
|
self.assertLen(detection_result.handedness, 2)
|
||||||
|
|
||||||
|
def test_detect_succeeds_with_rotation(self):
|
||||||
|
# Creates hand landmarker.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _HandLandmarkerOptions(base_options=base_options)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Load the pointing up rotated image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(_POINTING_UP_ROTATED_IMAGE))
|
||||||
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
|
||||||
|
# Performs hand landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(test_image, image_processing_options)
|
||||||
|
expected_detection_result = _get_expected_hand_landmarker_result(
|
||||||
|
_POINTING_UP_ROTATED_LANDMARKS)
|
||||||
|
# Comparing results.
|
||||||
|
self._assert_actual_result_approximately_matches_expected_result(
|
||||||
|
detection_result, expected_detection_result)
|
||||||
|
|
||||||
|
def test_detect_fails_with_region_of_interest(self):
|
||||||
|
# Creates hand landmarker.
|
||||||
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
options = _HandLandmarkerOptions(base_options=base_options)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "This task doesn't support region-of-interest."):
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Set the `region_of_interest` parameter using `ImageProcessingOptions`.
|
||||||
|
image_processing_options = _ImageProcessingOptions(
|
||||||
|
region_of_interest=_Rect(0, 0, 1, 1))
|
||||||
|
# Attempt to perform hand landmarks detection on the cropped input.
|
||||||
|
landmarker.detect(self.test_image, image_processing_options)
|
||||||
|
|
||||||
|
def test_empty_detection_outputs(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path))
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
# Load the image with no hands.
|
||||||
|
no_hands_test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(_NO_HANDS_IMAGE))
|
||||||
|
# Performs hand landmarks detection on the input.
|
||||||
|
detection_result = landmarker.detect(no_hands_test_image)
|
||||||
|
self.assertEmpty(detection_result.hand_landmarks)
|
||||||
|
self.assertEmpty(detection_result.hand_world_landmarks)
|
||||||
|
self.assertEmpty(detection_result.handedness)
|
||||||
|
|
||||||
|
def test_missing_result_callback(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'result callback must be provided'):
|
||||||
|
with _HandLandmarker.create_from_options(options) as unused_landmarker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO))
|
||||||
|
def test_illegal_result_callback(self, running_mode):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=running_mode,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'result callback should not be provided'):
|
||||||
|
with _HandLandmarker.create_from_options(options) as unused_landmarker:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_calling_detect_for_video_in_image_mode(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.IMAGE)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the video mode'):
|
||||||
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_detect_async_in_image_mode(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.IMAGE)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the live stream mode'):
|
||||||
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_calling_detect_in_video_mode(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the image mode'):
|
||||||
|
landmarker.detect(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_detect_async_in_video_mode(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the live stream mode'):
|
||||||
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_detect_for_video_with_out_of_order_timestamp(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
unused_result = landmarker.detect_for_video(self.test_image, 1)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(_THUMB_UP_IMAGE, 0,
|
||||||
|
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
|
||||||
|
(_POINTING_UP_IMAGE, 0,
|
||||||
|
_get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
|
||||||
|
(_POINTING_UP_ROTATED_IMAGE, -90,
|
||||||
|
_get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
|
||||||
|
(_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
|
||||||
|
def test_detect_for_video(self, image_path, rotation, expected_result):
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(image_path))
|
||||||
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
image_processing_options = _ImageProcessingOptions(
|
||||||
|
rotation_degrees=rotation)
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
result = landmarker.detect_for_video(test_image, timestamp,
|
||||||
|
image_processing_options)
|
||||||
|
if result.hand_landmarks and result.hand_world_landmarks and result.handedness:
|
||||||
|
self._assert_actual_result_approximately_matches_expected_result(
|
||||||
|
result, expected_result)
|
||||||
|
else:
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
|
||||||
|
def test_calling_detect_in_live_stream_mode(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the image mode'):
|
||||||
|
landmarker.detect(self.test_image)
|
||||||
|
|
||||||
|
def test_calling_detect_for_video_in_live_stream_mode(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
r'not initialized with the video mode'):
|
||||||
|
landmarker.detect_for_video(self.test_image, 0)
|
||||||
|
|
||||||
|
def test_detect_async_calls_with_illegal_timestamp(self):
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=mock.MagicMock())
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
landmarker.detect_async(self.test_image, 100)
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||||
|
landmarker.detect_async(self.test_image, 0)
|
||||||
|
|
||||||
|
@parameterized.parameters(
|
||||||
|
(_THUMB_UP_IMAGE, 0,
|
||||||
|
_get_expected_hand_landmarker_result(_THUMB_UP_LANDMARKS)),
|
||||||
|
(_POINTING_UP_IMAGE, 0,
|
||||||
|
_get_expected_hand_landmarker_result(_POINTING_UP_LANDMARKS)),
|
||||||
|
(_POINTING_UP_ROTATED_IMAGE, -90,
|
||||||
|
_get_expected_hand_landmarker_result(_POINTING_UP_ROTATED_LANDMARKS)),
|
||||||
|
(_NO_HANDS_IMAGE, 0, _HandLandmarkerResult([], [], [])))
|
||||||
|
def test_detect_async_calls(self, image_path, rotation, expected_result):
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(image_path))
|
||||||
|
# Set rotation parameters using ImageProcessingOptions.
|
||||||
|
image_processing_options = _ImageProcessingOptions(
|
||||||
|
rotation_degrees=rotation)
|
||||||
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
|
def check_result(result: _HandLandmarkerResult, output_image: _Image,
|
||||||
|
timestamp_ms: int):
|
||||||
|
if result.hand_landmarks and result.hand_world_landmarks and result.handedness:
|
||||||
|
self._assert_actual_result_approximately_matches_expected_result(
|
||||||
|
result, expected_result)
|
||||||
|
else:
|
||||||
|
self.assertEqual(result, expected_result)
|
||||||
|
self.assertTrue(
|
||||||
|
np.array_equal(output_image.numpy_view(), test_image.numpy_view()))
|
||||||
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
|
options = _HandLandmarkerOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
result_callback=check_result)
|
||||||
|
with _HandLandmarker.create_from_options(options) as landmarker:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
landmarker.detect_async(test_image, timestamp, image_processing_options)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
absltest.main()
|
|
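The suite above exercises every running mode of the new task. To run it locally, the Bazel label below is an assumption based on where MediaPipe keeps its Python vision tests, not something this diff states; adjust it to whichever BUILD file actually registers the test:

bazel test //mediapipe/tasks/python/test/vision:hand_landmarker_test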
mediapipe/tasks/python/vision/BUILD

@@ -79,6 +79,28 @@ py_library(
     ],
 )
 
+py_library(
+    name = "image_embedder",
+    srcs = [
+        "image_embedder.py",
+    ],
+    deps = [
+        "//mediapipe/python:_framework_bindings",
+        "//mediapipe/python:packet_creator",
+        "//mediapipe/python:packet_getter",
+        "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
+        "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
+        "//mediapipe/tasks/python/components/containers:embedding_result",
+        "//mediapipe/tasks/python/components/processors:classifier_options",
+        "//mediapipe/tasks/python/core:base_options",
+        "//mediapipe/tasks/python/core:optional_dependencies",
+        "//mediapipe/tasks/python/core:task_info",
+        "//mediapipe/tasks/python/vision/core:base_vision_task_api",
+        "//mediapipe/tasks/python/vision/core:image_processing_options",
+        "//mediapipe/tasks/python/vision/core:vision_task_running_mode",
+    ],
+)
+
 py_library(
     name = "gesture_recognizer",
     srcs = [
@@ -104,18 +126,19 @@ py_library(
 )
 
 py_library(
-    name = "image_embedder",
+    name = "hand_landmarker",
     srcs = [
-        "image_embedder.py",
+        "hand_landmarker.py",
     ],
     deps = [
+        "//mediapipe/framework/formats:classification_py_pb2",
+        "//mediapipe/framework/formats:landmark_py_pb2",
         "//mediapipe/python:_framework_bindings",
         "//mediapipe/python:packet_creator",
         "//mediapipe/python:packet_getter",
-        "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_py_pb2",
-        "//mediapipe/tasks/cc/components/containers/proto:embeddings_py_pb2",
-        "//mediapipe/tasks/python/components/containers:embedding_result",
-        "//mediapipe/tasks/python/components/processors:classifier_options",
+        "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_py_pb2",
+        "//mediapipe/tasks/python/components/containers:category",
+        "//mediapipe/tasks/python/components/containers:landmark",
         "//mediapipe/tasks/python/core:base_options",
         "//mediapipe/tasks/python/core:optional_dependencies",
         "//mediapipe/tasks/python/core:task_info",
mediapipe/tasks/python/vision/__init__.py

@@ -16,12 +16,17 @@
 
 import mediapipe.tasks.python.vision.core
 import mediapipe.tasks.python.vision.gesture_recognizer
+import mediapipe.tasks.python.vision.hand_landmarker
 import mediapipe.tasks.python.vision.image_classifier
 import mediapipe.tasks.python.vision.image_segmenter
 import mediapipe.tasks.python.vision.object_detector
 
 GestureRecognizer = gesture_recognizer.GestureRecognizer
 GestureRecognizerOptions = gesture_recognizer.GestureRecognizerOptions
+GestureRecognizerResult = gesture_recognizer.GestureRecognizerResult
+HandLandmarker = hand_landmarker.HandLandmarker
+HandLandmarkerOptions = hand_landmarker.HandLandmarkerOptions
+HandLandmarkerResult = hand_landmarker.HandLandmarkerResult
 ImageClassifier = image_classifier.ImageClassifier
 ImageClassifierOptions = image_classifier.ImageClassifierOptions
 ImageSegmenter = image_segmenter.ImageSegmenter
@@ -33,6 +38,7 @@ RunningMode = core.vision_task_running_mode.VisionTaskRunningMode
 # Remove unnecessary modules to avoid duplication in API docs.
 del core
 del gesture_recognizer
+del hand_landmarker
 del image_classifier
 del image_segmenter
 del object_detector
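With these re-exports in place, callers can reach the new task through the vision package namespace. A minimal sketch of the intended import style (the model path below is a placeholder, not part of this change):

from mediapipe.tasks.python import vision
from mediapipe.tasks.python.core import base_options as base_options_module

options = vision.HandLandmarkerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='hand_landmarker.task'),  # placeholder path
    running_mode=vision.RunningMode.IMAGE)
with vision.HandLandmarker.create_from_options(options) as landmarker:
  # `image` is assumed to be a previously loaded mediapipe Image.
  result = landmarker.detect(image)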
mediapipe/tasks/python/vision/gesture_recognizer.py

@@ -59,7 +59,7 @@ _GESTURE_DEFAULT_INDEX = -1
 
 
 @dataclasses.dataclass
-class GestureRecognitionResult:
+class GestureRecognizerResult:
   """The gesture recognition result from GestureRecognizer, where each vector element represents a single hand detected in the image.
 
   Attributes:
@@ -79,8 +79,8 @@ class GestureRecognitionResult:
 
 def _build_recognition_result(
     output_packets: Mapping[str,
-                            packet_module.Packet]) -> GestureRecognitionResult:
-  """Constructs a `GestureRecognitionResult` from output packets."""
+                            packet_module.Packet]) -> GestureRecognizerResult:
+  """Constructs a `GestureRecognizerResult` from output packets."""
   gestures_proto_list = packet_getter.get_proto_list(
       output_packets[_HAND_GESTURE_STREAM_NAME])
   handedness_proto_list = packet_getter.get_proto_list(
@@ -122,23 +122,25 @@ def _build_recognition_result(
   for proto in hand_landmarks_proto_list:
     hand_landmarks = landmark_pb2.NormalizedLandmarkList()
     hand_landmarks.MergeFrom(proto)
-    hand_landmarks_results.append([
-        landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark)
-        for hand_landmark in hand_landmarks.landmark
-    ])
+    hand_landmarks_list = []
+    for hand_landmark in hand_landmarks.landmark:
+      hand_landmarks_list.append(
+          landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
+    hand_landmarks_results.append(hand_landmarks_list)
 
   hand_world_landmarks_results = []
   for proto in hand_world_landmarks_proto_list:
     hand_world_landmarks = landmark_pb2.LandmarkList()
     hand_world_landmarks.MergeFrom(proto)
-    hand_world_landmarks_results.append([
-        landmark_module.Landmark.create_from_pb2(hand_world_landmark)
-        for hand_world_landmark in hand_world_landmarks.landmark
-    ])
+    hand_world_landmarks_list = []
+    for hand_world_landmark in hand_world_landmarks.landmark:
+      hand_world_landmarks_list.append(
+          landmark_module.Landmark.create_from_pb2(hand_world_landmark))
+    hand_world_landmarks_results.append(hand_world_landmarks_list)
 
-  return GestureRecognitionResult(gesture_results, handedness_results,
+  return GestureRecognizerResult(gesture_results, handedness_results,
                                  hand_landmarks_results,
                                  hand_world_landmarks_results)
@@ -183,7 +185,7 @@ class GestureRecognizerOptions:
   custom_gesture_classifier_options: Optional[
       _ClassifierOptions] = _ClassifierOptions()
   result_callback: Optional[Callable[
-      [GestureRecognitionResult, image_module.Image, int], None]] = None
+      [GestureRecognizerResult, image_module.Image, int], None]] = None
 
   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _GestureRecognizerGraphOptionsProto:
@@ -264,7 +266,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
     if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
       empty_packet = output_packets[_HAND_GESTURE_STREAM_NAME]
       options.result_callback(
-          GestureRecognitionResult([], [], [], []), image,
+          GestureRecognizerResult([], [], [], []), image,
           empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
       return
 
@@ -299,7 +301,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
       self,
       image: image_module.Image,
       image_processing_options: Optional[_ImageProcessingOptions] = None
-  ) -> GestureRecognitionResult:
+  ) -> GestureRecognizerResult:
     """Performs hand gesture recognition on the given image.
 
     Only use this method when the GestureRecognizer is created with the image
@@ -330,7 +332,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
     })
 
     if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
-      return GestureRecognitionResult([], [], [], [])
+      return GestureRecognizerResult([], [], [], [])
 
     return _build_recognition_result(output_packets)
 
@@ -339,7 +341,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
       image: image_module.Image,
       timestamp_ms: int,
       image_processing_options: Optional[_ImageProcessingOptions] = None
-  ) -> GestureRecognitionResult:
+  ) -> GestureRecognizerResult:
     """Performs gesture recognition on the provided video frame.
 
     Only use this method when the GestureRecognizer is created with the video
@@ -374,7 +376,7 @@ class GestureRecognizer(base_vision_task_api.BaseVisionTaskApi):
     })
 
     if output_packets[_HAND_GESTURE_STREAM_NAME].is_empty():
-      return GestureRecognitionResult([], [], [], [])
+      return GestureRecognizerResult([], [], [], [])
 
     return _build_recognition_result(output_packets)
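The rename from `GestureRecognitionResult` to `GestureRecognizerResult` is a breaking change for any caller that referenced the old class by name; only the name changes, the fields and their order stay the same. A migration sketch (the `recognize` call and the `recognizer` object are assumed for illustration; the image-mode method's name is outside the hunks shown here):

from mediapipe.tasks.python.vision import gesture_recognizer

# Before: result: gesture_recognizer.GestureRecognitionResult = ...
# After:
result: gesture_recognizer.GestureRecognizerResult = recognizer.recognize(image)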
379 mediapipe/tasks/python/vision/hand_landmarker.py Normal file

@@ -0,0 +1,379 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe hand landmarker task."""

import dataclasses
from typing import Callable, Mapping, Optional, List

from mediapipe.framework.formats import classification_pb2
from mediapipe.framework.formats import landmark_pb2
from mediapipe.python import packet_creator
from mediapipe.python import packet_getter
from mediapipe.python._framework_bindings import image as image_module
from mediapipe.python._framework_bindings import packet as packet_module
from mediapipe.tasks.cc.vision.hand_landmarker.proto import hand_landmarker_graph_options_pb2
from mediapipe.tasks.python.components.containers import category as category_module
from mediapipe.tasks.python.components.containers import landmark as landmark_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.core import task_info as task_info_module
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

_BaseOptions = base_options_module.BaseOptions
_HandLandmarkerGraphOptionsProto = hand_landmarker_graph_options_pb2.HandLandmarkerGraphOptions
_RunningMode = running_mode_module.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo

_IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE'
_NORM_RECT_STREAM_NAME = 'norm_rect_in'
_NORM_RECT_TAG = 'NORM_RECT'
_HANDEDNESS_STREAM_NAME = 'handedness'
_HANDEDNESS_TAG = 'HANDEDNESS'
_HAND_LANDMARKS_STREAM_NAME = 'landmarks'
_HAND_LANDMARKS_TAG = 'LANDMARKS'
_HAND_WORLD_LANDMARKS_STREAM_NAME = 'world_landmarks'
_HAND_WORLD_LANDMARKS_TAG = 'WORLD_LANDMARKS'
_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000


@dataclasses.dataclass
class HandLandmarkerResult:
  """The hand landmarks detection result from HandLandmarker, where each vector element represents a single hand detected in the image.

  Attributes:
    handedness: Classification of handedness.
    hand_landmarks: Detected hand landmarks in normalized image coordinates.
    hand_world_landmarks: Detected hand landmarks in world coordinates.
  """

  handedness: List[List[category_module.Category]]
  hand_landmarks: List[List[landmark_module.NormalizedLandmark]]
  hand_world_landmarks: List[List[landmark_module.Landmark]]


def _build_landmarker_result(
    output_packets: Mapping[str, packet_module.Packet]) -> HandLandmarkerResult:
  """Constructs a `HandLandmarkerResult` from output packets."""
  handedness_proto_list = packet_getter.get_proto_list(
      output_packets[_HANDEDNESS_STREAM_NAME])
  hand_landmarks_proto_list = packet_getter.get_proto_list(
      output_packets[_HAND_LANDMARKS_STREAM_NAME])
  hand_world_landmarks_proto_list = packet_getter.get_proto_list(
      output_packets[_HAND_WORLD_LANDMARKS_STREAM_NAME])

  handedness_results = []
  for proto in handedness_proto_list:
    handedness_categories = []
    handedness_classifications = classification_pb2.ClassificationList()
    handedness_classifications.MergeFrom(proto)
    for handedness in handedness_classifications.classification:
      handedness_categories.append(
          category_module.Category(
              index=handedness.index,
              score=handedness.score,
              display_name=handedness.display_name,
              category_name=handedness.label))
    handedness_results.append(handedness_categories)

  hand_landmarks_results = []
  for proto in hand_landmarks_proto_list:
    hand_landmarks = landmark_pb2.NormalizedLandmarkList()
    hand_landmarks.MergeFrom(proto)
    hand_landmarks_list = []
    for hand_landmark in hand_landmarks.landmark:
      hand_landmarks_list.append(
          landmark_module.NormalizedLandmark.create_from_pb2(hand_landmark))
    hand_landmarks_results.append(hand_landmarks_list)

  hand_world_landmarks_results = []
  for proto in hand_world_landmarks_proto_list:
    hand_world_landmarks = landmark_pb2.LandmarkList()
    hand_world_landmarks.MergeFrom(proto)
    hand_world_landmarks_list = []
    for hand_world_landmark in hand_world_landmarks.landmark:
      hand_world_landmarks_list.append(
          landmark_module.Landmark.create_from_pb2(hand_world_landmark))
    hand_world_landmarks_results.append(hand_world_landmarks_list)

  return HandLandmarkerResult(handedness_results, hand_landmarks_results,
                              hand_world_landmarks_results)


@dataclasses.dataclass
class HandLandmarkerOptions:
  """Options for the hand landmarker task.

  Attributes:
    base_options: Base options for the hand landmarker task.
    running_mode: The running mode of the task. Defaults to the image mode.
      HandLandmarker has three running modes: 1) The image mode for detecting
      hand landmarks on single image inputs. 2) The video mode for detecting
      hand landmarks on the decoded frames of a video. 3) The live stream mode
      for detecting hand landmarks on a live stream of input data, such as
      from a camera. In this mode, the "result_callback" below must be
      specified to receive the detection results asynchronously.
    num_hands: The maximum number of hands that can be detected by the hand
      landmarker.
    min_hand_detection_confidence: The minimum confidence score for the hand
      detection to be considered successful.
    min_hand_presence_confidence: The minimum confidence score of the hand
      presence score in the hand landmark detection.
    min_tracking_confidence: The minimum confidence score for the hand tracking
      to be considered successful.
    result_callback: The user-defined result callback for processing live
      stream data. The result callback should only be specified when the
      running mode is set to the live stream mode.
  """
  base_options: _BaseOptions
  running_mode: _RunningMode = _RunningMode.IMAGE
  num_hands: Optional[int] = 1
  min_hand_detection_confidence: Optional[float] = 0.5
  min_hand_presence_confidence: Optional[float] = 0.5
  min_tracking_confidence: Optional[float] = 0.5
  result_callback: Optional[Callable[
      [HandLandmarkerResult, image_module.Image, int], None]] = None

  @doc_controls.do_not_generate_docs
  def to_pb2(self) -> _HandLandmarkerGraphOptionsProto:
    """Generates a HandLandmarkerGraphOptions protobuf object."""
    base_options_proto = self.base_options.to_pb2()
    base_options_proto.use_stream_mode = self.running_mode != _RunningMode.IMAGE

    # Initialize the hand landmarker options from base options.
    hand_landmarker_options_proto = _HandLandmarkerGraphOptionsProto(
        base_options=base_options_proto)
    hand_landmarker_options_proto.min_tracking_confidence = self.min_tracking_confidence
    hand_landmarker_options_proto.hand_detector_graph_options.num_hands = self.num_hands
    hand_landmarker_options_proto.hand_detector_graph_options.min_detection_confidence = self.min_hand_detection_confidence
    hand_landmarker_options_proto.hand_landmarks_detector_graph_options.min_detection_confidence = self.min_hand_presence_confidence
    return hand_landmarker_options_proto


class HandLandmarker(base_vision_task_api.BaseVisionTaskApi):
  """Class that performs hand landmarks detection on images."""

  @classmethod
  def create_from_model_path(cls, model_path: str) -> 'HandLandmarker':
    """Creates a `HandLandmarker` object from a TensorFlow Lite model and the default `HandLandmarkerOptions`.

    Note that the created `HandLandmarker` instance is in image mode, for
    detecting hand landmarks on single image inputs.

    Args:
      model_path: Path to the model.

    Returns:
      `HandLandmarker` object that's created from the model file and the
      default `HandLandmarkerOptions`.

    Raises:
      ValueError: If the `HandLandmarker` object failed to be created from the
        provided file, e.g., because of an invalid file path.
      RuntimeError: If other types of errors occurred.
    """
    base_options = _BaseOptions(model_asset_path=model_path)
    options = HandLandmarkerOptions(
        base_options=base_options, running_mode=_RunningMode.IMAGE)
    return cls.create_from_options(options)

  @classmethod
  def create_from_options(cls,
                          options: HandLandmarkerOptions) -> 'HandLandmarker':
    """Creates the `HandLandmarker` object from hand landmarker options.

    Args:
      options: Options for the hand landmarker task.

    Returns:
      `HandLandmarker` object that's created from `options`.

    Raises:
      ValueError: If the `HandLandmarker` object failed to be created from
        `HandLandmarkerOptions`, e.g., because the model is missing.
      RuntimeError: If other types of errors occurred.
    """

    def packets_callback(output_packets: Mapping[str, packet_module.Packet]):
      if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
        return

      image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])

      if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
        empty_packet = output_packets[_HAND_LANDMARKS_STREAM_NAME]
        options.result_callback(
            HandLandmarkerResult([], [], []), image,
            empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)
        return

      hand_landmarks_detection_result = _build_landmarker_result(output_packets)
      timestamp = output_packets[_HAND_LANDMARKS_STREAM_NAME].timestamp
      options.result_callback(hand_landmarks_detection_result, image,
                              timestamp.value // _MICRO_SECONDS_PER_MILLISECOND)

    task_info = _TaskInfo(
        task_graph=_TASK_GRAPH_NAME,
        input_streams=[
            ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
            ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
        ],
        output_streams=[
            ':'.join([_HANDEDNESS_TAG, _HANDEDNESS_STREAM_NAME]),
            ':'.join([_HAND_LANDMARKS_TAG, _HAND_LANDMARKS_STREAM_NAME]),
            ':'.join([
                _HAND_WORLD_LANDMARKS_TAG, _HAND_WORLD_LANDMARKS_STREAM_NAME
            ]), ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME])
        ],
        task_options=options)
    return cls(
        task_info.generate_graph_config(
            enable_flow_limiting=options.running_mode ==
            _RunningMode.LIVE_STREAM), options.running_mode,
        packets_callback if options.result_callback else None)

  def detect(
      self,
      image: image_module.Image,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> HandLandmarkerResult:
    """Performs hand landmarks detection on the given image.

    Only use this method when the HandLandmarker is created with the image
    running mode.

    The image can be of any size with format RGB or RGBA.
    TODO: Describe how the input image will be preprocessed once YUV support
    is implemented.

    Args:
      image: MediaPipe Image.
      image_processing_options: Options for image processing.

    Returns:
      The hand landmarks detection results.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If hand landmarker detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False)
    output_packets = self._process_image_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image),
        _NORM_RECT_STREAM_NAME:
            packet_creator.create_proto(normalized_rect.to_pb2())
    })

    if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
      return HandLandmarkerResult([], [], [])

    return _build_landmarker_result(output_packets)

  def detect_for_video(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> HandLandmarkerResult:
    """Performs hand landmarks detection on the provided video frame.

    Only use this method when the HandLandmarker is created with the video
    running mode. It's required to provide the video frame's timestamp (in
    milliseconds) along with the video frame. The input timestamps should be
    monotonically increasing for adjacent calls of this method.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input video frame in milliseconds.
      image_processing_options: Options for image processing.

    Returns:
      The hand landmarks detection results.

    Raises:
      ValueError: If any of the input arguments is invalid.
      RuntimeError: If hand landmarker detection failed to run.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False)
    output_packets = self._process_video_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_STREAM_NAME:
            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })

    if output_packets[_HAND_LANDMARKS_STREAM_NAME].is_empty():
      return HandLandmarkerResult([], [], [])

    return _build_landmarker_result(output_packets)

  def detect_async(
      self,
      image: image_module.Image,
      timestamp_ms: int,
      image_processing_options: Optional[_ImageProcessingOptions] = None
  ) -> None:
    """Sends live image data to perform hand landmarks detection.

    Only use this method when the HandLandmarker is created with the live
    stream running mode. The input timestamps should be monotonically
    increasing for adjacent calls of this method. This method will return
    immediately after the input image is accepted. The results will be
    available via the `result_callback` provided in the
    `HandLandmarkerOptions`. The `detect_async` method is designed to process
    live stream data such as camera input. To lower the overall latency, the
    hand landmarker may drop the input images if needed. In other words, it's
    not guaranteed to have output per input image.

    The `result_callback` provides:
      - The hand landmarks detection results.
      - The input image that the hand landmarker runs on.
      - The input timestamp in milliseconds.

    Args:
      image: MediaPipe Image.
      timestamp_ms: The timestamp of the input image in milliseconds.
      image_processing_options: Options for image processing.

    Raises:
      ValueError: If the current input timestamp is smaller than what the
        hand landmarker has already processed.
    """
    normalized_rect = self.convert_to_normalized_rect(
        image_processing_options, roi_allowed=False)
    self._send_live_stream_data({
        _IMAGE_IN_STREAM_NAME:
            packet_creator.create_image(image).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
        _NORM_RECT_STREAM_NAME:
            packet_creator.create_proto(normalized_rect.to_pb2()).at(
                timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND)
    })
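For orientation, a minimal end-to-end sketch of the API this new file defines. It is illustrative only; the model and image paths are placeholders, and the loop stands in for a real camera feed:

from mediapipe.python._framework_bindings import image as image_module
from mediapipe.tasks.python.core import base_options as base_options_module
from mediapipe.tasks.python.vision import hand_landmarker
from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module

# Image mode: synchronous, one result per call.
options = hand_landmarker.HandLandmarkerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='/path/to/hand_landmarker.task'),  # placeholder
    num_hands=2)
with hand_landmarker.HandLandmarker.create_from_options(options) as landmarker:
  image = image_module.Image.create_from_file('/path/to/image.jpg')  # placeholder
  result = landmarker.detect(image)
  for categories, landmarks in zip(result.handedness, result.hand_landmarks):
    print(categories[0].category_name, 'with', len(landmarks), 'landmarks')

# Live stream mode: asynchronous; results arrive via the callback, timestamps
# must be monotonically increasing, and frames may be dropped under load.
def print_result(result: hand_landmarker.HandLandmarkerResult,
                 output_image: image_module.Image, timestamp_ms: int):
  print(timestamp_ms, len(result.hand_landmarks), 'hand(s)')

stream_options = hand_landmarker.HandLandmarkerOptions(
    base_options=base_options_module.BaseOptions(
        model_asset_path='/path/to/hand_landmarker.task'),  # placeholder
    running_mode=running_mode_module.VisionTaskRunningMode.LIVE_STREAM,
    result_callback=print_result)
with hand_landmarker.HandLandmarker.create_from_options(stream_options) as landmarker:
  for timestamp_ms in range(0, 300, 30):  # stand-in for a real camera loop
    landmarker.detect_async(image, timestamp_ms)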
14 mediapipe/tasks/testdata/metadata/BUILD vendored

@@ -23,10 +23,13 @@ package(
 )
 
 mediapipe_files(srcs = [
+    "30k-clean.model",
+    "bert_text_classifier_no_metadata.tflite",
     "mobile_ica_8bit-with-metadata.tflite",
     "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
     "mobile_ica_8bit-without-model-metadata.tflite",
     "mobile_object_classifier_v0_2_3-metadata-no-name.tflite",
+    "mobilebert_vocab.txt",
     "mobilenet_v1_0.25_224_1_default_1.tflite",
     "mobilenet_v2_1.0_224_quant.tflite",
     "mobilenet_v2_1.0_224_quant_without_metadata.tflite",
@@ -60,11 +63,17 @@ exports_files([
     "movie_review_labels.txt",
     "regex_vocab.txt",
     "movie_review.json",
+    "bert_tokenizer_meta.json",
+    "bert_text_classifier_with_sentence_piece.json",
+    "sentence_piece_tokenizer_meta.json",
+    "bert_text_classifier_with_bert_tokenizer.json",
 ])
 
 filegroup(
     name = "model_files",
     srcs = [
+        "30k-clean.model",
+        "bert_text_classifier_no_metadata.tflite",
         "mobile_ica_8bit-with-metadata.tflite",
         "mobile_ica_8bit-with-unsupported-metadata-version.tflite",
         "mobile_ica_8bit-without-model-metadata.tflite",
@@ -81,6 +90,9 @@ filegroup(
     name = "data_files",
     srcs = [
         "associated_file_meta.json",
+        "bert_text_classifier_with_bert_tokenizer.json",
+        "bert_text_classifier_with_sentence_piece.json",
+        "bert_tokenizer_meta.json",
         "bounding_box_tensor_meta.json",
         "classification_tensor_float_meta.json",
         "classification_tensor_uint8_meta.json",
@@ -96,6 +108,7 @@ filegroup(
         "input_text_tensor_default_meta.json",
         "input_text_tensor_meta.json",
         "labels.txt",
+        "mobilebert_vocab.txt",
         "mobilenet_v2_1.0_224.json",
         "mobilenet_v2_1.0_224_quant.json",
         "movie_review.json",
@@ -105,5 +118,6 @@ filegroup(
         "score_calibration_file_meta.json",
         "score_calibration_tensor_meta.json",
         "score_thresholding_meta.json",
+        "sentence_piece_tokenizer_meta.json",
     ],
 )
84 mediapipe/tasks/testdata/metadata/bert_text_classifier_with_bert_tokenizer.json vendored Normal file

@@ -0,0 +1,84 @@
{
  "name": "TextClassifier",
  "description": "Classify the input text into a set of known categories.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "ids",
          "description": "Tokenized ids of the input text.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "segment_ids",
          "description": "0 for the first sequence, 1 for the second sequence if exists.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "mask",
          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "score",
          "description": "Score of the labels respectively.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
            "max": [
              1.0
            ],
            "min": [
              0.0
            ]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ],
      "input_process_units": [
        {
          "options_type": "BertTokenizerOptions",
          "options": {
            "vocab_file": [
              {
                "name": "mobilebert_vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ],
  "min_parser_version": "1.1.0"
}
83 mediapipe/tasks/testdata/metadata/bert_text_classifier_with_sentence_piece.json vendored Normal file

@@ -0,0 +1,83 @@
{
  "name": "TextClassifier",
  "description": "Classify the input text into a set of known categories.",
  "subgraph_metadata": [
    {
      "input_tensor_metadata": [
        {
          "name": "ids",
          "description": "Tokenized ids of the input text.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "segment_ids",
          "description": "0 for the first sequence, 1 for the second sequence if exists.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        },
        {
          "name": "mask",
          "description": "Mask with 1 for real tokens and 0 for padding tokens.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
          }
        }
      ],
      "output_tensor_metadata": [
        {
          "name": "score",
          "description": "Score of the labels respectively.",
          "content": {
            "content_properties_type": "FeatureProperties",
            "content_properties": {
            }
          },
          "stats": {
            "max": [
              1.0
            ],
            "min": [
              0.0
            ]
          },
          "associated_files": [
            {
              "name": "labels.txt",
              "description": "Labels for categories that the model can recognize.",
              "type": "TENSOR_AXIS_LABELS"
            }
          ]
        }
      ],
      "input_process_units": [
        {
          "options_type": "SentencePieceTokenizerOptions",
          "options": {
            "sentencePiece_model": [
              {
                "name": "30k-clean.model",
                "description": "The sentence piece model file."
              }
            ]
          }
        }
      ]
    }
  ],
  "min_parser_version": "1.1.0"
}
20 mediapipe/tasks/testdata/metadata/bert_tokenizer_meta.json vendored Normal file
@@ -0,0 +1,20 @@
{
  "subgraph_metadata": [
    {
      "input_process_units": [
        {
          "options_type": "BertTokenizerOptions",
          "options": {
            "vocab_file": [
              {
                "name": "vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ]
}
30522 mediapipe/tasks/testdata/metadata/mobilebert_vocab.txt vendored Normal file
File diff suppressed because it is too large
26 mediapipe/tasks/testdata/metadata/sentence_piece_tokenizer_meta.json vendored Normal file
@@ -0,0 +1,26 @@
{
  "subgraph_metadata": [
    {
      "input_process_units": [
        {
          "options_type": "SentencePieceTokenizerOptions",
          "options": {
            "sentencePiece_model": [
              {
                "name": "sp.model",
                "description": "The sentence piece model file."
              }
            ],
            "vocab_file": [
              {
                "name": "vocab.txt",
                "description": "Vocabulary file to convert natural language words to embedding vectors. This file is optional during tokenization, while the sentence piece model is mandatory.",
                "type": "VOCABULARY"
              }
            ]
          }
        }
      ]
    }
  ]
}
14 mediapipe/tasks/testdata/vision/BUILD vendored
@@ -143,20 +143,6 @@ filegroup(
     ],
 )
 
-# Gestures related models. Visible to model_maker.
-# TODO: Upload canned gesture model and gesture embedding model to GCS after Model Card approval
-filegroup(
-    name = "test_gesture_models",
-    srcs = [
-        "hand_landmark_full.tflite",
-        "palm_detection_full.tflite",
-    ],
-    visibility = [
-        "//mediapipe/model_maker:__subpackages__",
-        "//mediapipe/tasks:internal",
-    ],
-)
-
 filegroup(
     name = "test_protos",
     srcs = [
@@ -3,9 +3,22 @@
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_library")
 load("@build_bazel_rules_nodejs//:index.bzl", "pkg_npm")
 load("@npm//@bazel/rollup:index.bzl", "rollup_bundle")
+load(
+    "//mediapipe/framework/tool:mediapipe_files.bzl",
+    "mediapipe_files",
+)
 
 package(default_visibility = ["//mediapipe/tasks:internal"])
 
+mediapipe_files(srcs = [
+    "wasm/audio_wasm_internal.js",
+    "wasm/audio_wasm_internal.wasm",
+    "wasm/text_wasm_internal.js",
+    "wasm/text_wasm_internal.wasm",
+    "wasm/vision_wasm_internal.js",
+    "wasm/vision_wasm_internal.wasm",
+])
+
 # Audio
 
 mediapipe_ts_library(
@@ -28,15 +41,18 @@ rollup_bundle(
 
 pkg_npm(
     name = "audio_pkg",
-    package_name = "__PACKAGE_NAME__",
+    package_name = "@mediapipe/tasks-__NAME__",
     srcs = ["package.json"],
     substitutions = {
-        "__PACKAGE_NAME__": "@mediapipe/tasks-audio",
+        "__NAME__": "audio",
         "__DESCRIPTION__": "MediaPipe Audio Tasks",
-        "__BUNDLE__": "audio_bundle.js",
     },
     tgz = "audio.tgz",
-    deps = [":audio_bundle"],
+    deps = [
+        "wasm/audio_wasm_internal.js",
+        "wasm/audio_wasm_internal.wasm",
+        ":audio_bundle",
+    ],
 )
 
 # Text
@@ -61,15 +77,18 @@ rollup_bundle(
 
 pkg_npm(
     name = "text_pkg",
-    package_name = "__PACKAGE_NAME__",
+    package_name = "@mediapipe/tasks-__NAME__",
     srcs = ["package.json"],
     substitutions = {
-        "__PACKAGE_NAME__": "@mediapipe/tasks-text",
+        "__NAME__": "text",
         "__DESCRIPTION__": "MediaPipe Text Tasks",
-        "__BUNDLE__": "text_bundle.js",
     },
     tgz = "text.tgz",
-    deps = [":text_bundle"],
+    deps = [
+        "wasm/text_wasm_internal.js",
+        "wasm/text_wasm_internal.wasm",
+        ":text_bundle",
+    ],
 )
 
 # Vision
@@ -94,13 +113,16 @@ rollup_bundle(
 
 pkg_npm(
     name = "vision_pkg",
-    package_name = "__PACKAGE_NAME__",
+    package_name = "@mediapipe/tasks-__NAME__",
     srcs = ["package.json"],
     substitutions = {
-        "__PACKAGE_NAME__": "@mediapipe/tasks-vision",
+        "__NAME__": "vision",
         "__DESCRIPTION__": "MediaPipe Vision Tasks",
-        "__BUNDLE__": "vision_bundle.js",
     },
-    tgz = "vision.tgz",
+    tgz = "vision_pkg.tgz",
-    deps = [":vision_bundle"],
+    deps = [
+        "wasm/vision_wasm_internal.js",
+        "wasm/vision_wasm_internal.wasm",
+        ":vision_bundle",
+    ],
 )
@@ -21,7 +21,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_jspb_proto",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
         "//mediapipe/tasks/web/components/containers:category",
-        "//mediapipe/tasks/web/components/containers:classifications",
+        "//mediapipe/tasks/web/components/containers:classification_result",
         "//mediapipe/tasks/web/components/processors:base_options",
         "//mediapipe/tasks/web/components/processors:classifier_options",
         "//mediapipe/tasks/web/components/processors:classifier_result",
@@ -27,7 +27,7 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
 // Placeholder for internal dependency on trusted resource url
 
 import {AudioClassifierOptions} from './audio_classifier_options';
-import {Classifications} from './audio_classifier_result';
+import {AudioClassifierResult} from './audio_classifier_result';
 
 const MEDIAPIPE_GRAPH =
     'mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph';
@@ -38,14 +38,14 @@ const MEDIAPIPE_GRAPH =
 // implementation
 const AUDIO_STREAM = 'input_audio';
 const SAMPLE_RATE_STREAM = 'sample_rate';
-const CLASSIFICATION_RESULT_STREAM = 'classification_result';
+const TIMESTAMPED_CLASSIFICATIONS_STREAM = 'timestamped_classifications';
 
 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern
 
 /** Performs audio classification. */
 export class AudioClassifier extends TaskRunner {
-  private classifications: Classifications[] = [];
+  private classificationResults: AudioClassifierResult[] = [];
   private defaultSampleRate = 48000;
   private readonly options = new AudioClassifierGraphOptions();
 
@@ -150,7 +150,8 @@ export class AudioClassifier extends TaskRunner {
    * `48000` if no custom default was set.
    * @return The classification result of the audio datas
    */
-  classify(audioData: Float32Array, sampleRate?: number): Classifications[] {
+  classify(audioData: Float32Array, sampleRate?: number):
+      AudioClassifierResult[] {
     sampleRate = sampleRate ?? this.defaultSampleRate;
 
     // Configures the number of samples in the WASM layer. We re-configure the
@@ -164,20 +165,22 @@ export class AudioClassifier extends TaskRunner {
     this.addDoubleToStream(sampleRate, SAMPLE_RATE_STREAM, timestamp);
     this.addAudioToStream(audioData, timestamp);
 
-    this.classifications = [];
+    this.classificationResults = [];
     this.finishProcessing();
-    return [...this.classifications];
+    return [...this.classificationResults];
   }
 
   /**
-   * Internal function for converting raw data into a classification, and
-   * adding it to our classfications list.
+   * Internal function for converting raw data into classification results, and
+   * adding them to our classfication results list.
    **/
-  private addJsAudioClassification(binaryProto: Uint8Array): void {
-    const classificationResult =
-        ClassificationResult.deserializeBinary(binaryProto);
-    this.classifications.push(
-        ...convertFromClassificationResultProto(classificationResult));
+  private addJsAudioClassificationResults(binaryProtos: Uint8Array[]): void {
+    binaryProtos.forEach(binaryProto => {
+      const classificationResult =
+          ClassificationResult.deserializeBinary(binaryProto);
+      this.classificationResults.push(
+          convertFromClassificationResultProto(classificationResult));
+    });
   }
 
   /** Updates the MediaPipe graph configuration. */
@@ -185,7 +188,7 @@ export class AudioClassifier extends TaskRunner {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(AUDIO_STREAM);
     graphConfig.addInputStream(SAMPLE_RATE_STREAM);
-    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
+    graphConfig.addOutputStream(TIMESTAMPED_CLASSIFICATIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(
@@ -198,14 +201,15 @@ export class AudioClassifier extends TaskRunner {
     classifierNode.addInputStream('AUDIO:' + AUDIO_STREAM);
     classifierNode.addInputStream('SAMPLE_RATE:' + SAMPLE_RATE_STREAM);
     classifierNode.addOutputStream(
-        'CLASSIFICATIONS:' + CLASSIFICATION_RESULT_STREAM);
+        'TIMESTAMPED_CLASSIFICATIONS:' + TIMESTAMPED_CLASSIFICATIONS_STREAM);
     classifierNode.setOptions(calculatorOptions);
 
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
-      this.addJsAudioClassification(binaryProto);
-    });
+    this.attachProtoVectorListener(
+        TIMESTAMPED_CLASSIFICATIONS_STREAM, binaryProtos => {
+          this.addJsAudioClassificationResults(binaryProtos);
+        });
 
     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
@@ -15,4 +15,4 @@
  */
 
 export {Category} from '../../../../tasks/web/components/containers/category';
-export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
+export {ClassificationResult as AudioClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
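With the vector listener in place, classify() returns one AudioClassifierResult per timestamped chunk of input audio instead of a flat Classifications array. A minimal consumer sketch; the audioClassifier instance and the audioSamples Float32Array are assumed for illustration, only the result shape comes from this diff:

const results: AudioClassifierResult[] =
    audioClassifier.classify(audioSamples, /* sampleRate= */ 16000);
for (const result of results) {
  // Each result covers one chunk of the input; timestampMs is optional.
  for (const head of result.classifications) {
    const top = head.categories[0];
    console.log(
        `head ${head.headName} @ ${result.timestampMs ?? 0}ms: ` +
        `${top.categoryName} (${top.score.toFixed(2)})`);
  }
}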
@@ -10,8 +10,8 @@ mediapipe_ts_library(
 )
 
 mediapipe_ts_library(
-    name = "classifications",
-    srcs = ["classifications.d.ts"],
+    name = "classification_result",
+    srcs = ["classification_result.d.ts"],
     deps = [":category"],
 )
 
@@ -16,27 +16,14 @@
 
 import {Category} from '../../../../tasks/web/components/containers/category';
 
-/** List of predicted categories with an optional timestamp. */
-export interface ClassificationEntry {
+/** Classification results for a given classifier head. */
+export interface Classifications {
   /**
    * The array of predicted categories, usually sorted by descending scores,
    * e.g., from high to low probability.
    */
   categories: Category[];
-
-  /**
-   * The optional timestamp (in milliseconds) associated to the classification
-   * entry. This is useful for time series use cases, e.g., audio
-   * classification.
-   */
-  timestampMs?: number;
-}
-
-/** Classifications for a given classifier head. */
-export interface Classifications {
-  /** A list of classification entries. */
-  entries: ClassificationEntry[];
 
   /**
    * The index of the classifier head these categories refer to. This is
    * useful for multi-head models.
@@ -45,7 +32,24 @@ export interface Classifications {
 
   /**
    * The name of the classifier head, which is the corresponding tensor
-   * metadata name.
+   * metadata name. Defaults to an empty string if there is no such metadata.
    */
   headName: string;
 }
+
+/** Classification results of a model. */
+export interface ClassificationResult {
+  /** The classification results for each head of the model. */
+  classifications: Classifications[];
+
+  /**
+   * The optional timestamp (in milliseconds) of the start of the chunk of data
+   * corresponding to these results.
+   *
+   * This is only used for classification on time series (e.g. audio
+   * classification). In these use cases, the amount of data to process might
+   * exceed the maximum size that the model can process: to solve this, the
+   * input data is split into multiple chunks starting at different timestamps.
+   */
+  timestampMs?: number;
+}
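For reference, a value conforming to the new interfaces looks like the following. This is a hand-written illustration; the category values and head name are invented, only the field names come from the interfaces above:

const example: ClassificationResult = {
  classifications: [{
    categories: [
      {index: 0, score: 0.95, categoryName: 'positive', displayName: ''},
      {index: 1, score: 0.05, categoryName: 'negative', displayName: ''},
    ],
    headIndex: 0,
    headName: 'probability',
  }],
  // timestampMs omitted: a one-shot (non-streaming) classification.
};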
@@ -17,8 +17,9 @@ mediapipe_ts_library(
     name = "classifier_result",
     srcs = ["classifier_result.ts"],
     deps = [
+        "//mediapipe/framework/formats:classification_jspb_proto",
         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
-        "//mediapipe/tasks/web/components/containers:classifications",
+        "//mediapipe/tasks/web/components/containers:classification_result",
     ],
 )
 
@@ -14,48 +14,46 @@
  * limitations under the License.
  */
 
-import {ClassificationEntry as ClassificationEntryProto, ClassificationResult} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
-import {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
+import {ClassificationResult as ClassificationResultProto, Classifications as ClassificationsProto} from '../../../../tasks/cc/components/containers/proto/classifications_pb';
+import {ClassificationResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
 
 const DEFAULT_INDEX = -1;
 const DEFAULT_SCORE = 0.0;
 
 /**
- * Converts a ClassificationEntry proto to the ClassificationEntry result
- * type.
+ * Converts a Classifications proto to a Classifications object.
  */
-function convertFromClassificationEntryProto(source: ClassificationEntryProto):
-    ClassificationEntry {
-  const categories = source.getCategoriesList().map(category => {
-    return {
-      index: category.getIndex() ?? DEFAULT_INDEX,
-      score: category.getScore() ?? DEFAULT_SCORE,
-      displayName: category.getDisplayName() ?? '',
-      categoryName: category.getCategoryName() ?? '',
-    };
-  });
+function convertFromClassificationsProto(source: ClassificationsProto):
+    Classifications {
+  const categories =
+      source.getClassificationList()?.getClassificationList().map(
+          classification => {
+            return {
+              index: classification.getIndex() ?? DEFAULT_INDEX,
+              score: classification.getScore() ?? DEFAULT_SCORE,
+              categoryName: classification.getLabel() ?? '',
+              displayName: classification.getDisplayName() ?? '',
+            };
+          }) ??
+      [];
   return {
     categories,
-    timestampMs: source.getTimestampMs(),
+    headIndex: source.getHeadIndex() ?? DEFAULT_INDEX,
+    headName: source.getHeadName() ?? '',
   };
 }
 
 /**
- * Converts a ClassificationResult proto to a list of classifications.
+ * Converts a ClassificationResult proto to a ClassificationResult object.
  */
 export function convertFromClassificationResultProto(
-    classificationResult: ClassificationResult) : Classifications[] {
-  const result: Classifications[] = [];
-  for (const classificationsProto of
-       classificationResult.getClassificationsList()) {
-    const classifications: Classifications = {
-      entries: classificationsProto.getEntriesList().map(
-          entry => convertFromClassificationEntryProto(entry)),
-      headIndex: classificationsProto.getHeadIndex() ?? DEFAULT_INDEX,
-      headName: classificationsProto.getHeadName() ?? '',
-    };
-    result.push(classifications);
-  }
+    source: ClassificationResultProto): ClassificationResult {
+  const result: ClassificationResult = {
+    classifications: source.getClassificationsList().map(
+        classififications => convertFromClassificationsProto(classififications))
+  };
+  if (source.hasTimestampMs()) {
+    result.timestampMs = source.getTimestampMs();
+  }
  return result;
 }
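One detail worth noting in the converter above: timestampMs is only assigned when the proto actually carries the field (the hasTimestampMs() guard), so a missing timestamp surfaces as undefined rather than 0 and consumers can branch on presence. An assumed consumer-side sketch, not part of this diff:

const converted = convertFromClassificationResultProto(resultProto);
if (converted.timestampMs !== undefined) {
  // Time-series input: associate the result with its input chunk.
}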
@@ -1,9 +1,14 @@
 {
-  "name": "__PACKAGE_NAME__",
+  "name": "@mediapipe/tasks-__NAME__",
   "version": "__VERSION__",
   "description": "__DESCRIPTION__",
-  "main": "__BUNDLE__",
-  "module": "__BUNDLE__",
+  "main": "__NAME__bundle.js",
+  "module": "__NAME__bundle.js",
+  "exports": {
+    ".": "./__NAME__bundle.js",
+    "./loader": "./wasm/__NAME__wasm_internal.js",
+    "./wasm": "./wasm/__NAME__wasm_internal.wasm"
+  },
   "author": "mediapipe@google.com",
   "license": "Apache-2.0",
   "type": "module",
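The new exports map lets consumers resolve the rolled-up bundle and the wasm assets from the package root. A hypothetical consumer import for the audio package after substitution; the substituted file names are assumptions based on the substitutions shown in the BUILD hunks above:

// Hypothetical: resolves via the "." export to the rolled-up bundle.
import * as mediapipeTasksAudio from '@mediapipe/tasks-audio';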
@@ -22,7 +22,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
         "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_jspb_proto",
         "//mediapipe/tasks/web/components/containers:category",
-        "//mediapipe/tasks/web/components/containers:classifications",
+        "//mediapipe/tasks/web/components/containers:classification_result",
         "//mediapipe/tasks/web/components/processors:base_options",
         "//mediapipe/tasks/web/components/processors:classifier_options",
         "//mediapipe/tasks/web/components/processors:classifier_result",
@@ -27,10 +27,10 @@ import {createMediaPipeLib, FileLocator} from '../../../../web/graph_runner/wasm
 // Placeholder for internal dependency on trusted resource url
 
 import {TextClassifierOptions} from './text_classifier_options';
-import {Classifications} from './text_classifier_result';
+import {TextClassifierResult} from './text_classifier_result';
 
 const INPUT_STREAM = 'text_in';
-const CLASSIFICATION_RESULT_STREAM = 'classification_result_out';
+const CLASSIFICATIONS_STREAM = 'classifications_out';
 const TEXT_CLASSIFIER_GRAPH =
     'mediapipe.tasks.text.text_classifier.TextClassifierGraph';
@@ -39,7 +39,7 @@ const TEXT_CLASSIFIER_GRAPH =
 
 /** Performs Natural Language classification. */
 export class TextClassifier extends TaskRunner {
-  private classifications: Classifications[] = [];
+  private classificationResult: TextClassifierResult = {classifications: []};
   private readonly options = new TextClassifierGraphOptions();
 
   /**
@@ -129,30 +129,20 @@ export class TextClassifier extends TaskRunner {
    * @param text The text to process.
    * @return The classification result of the text
    */
-  classify(text: string): Classifications[] {
-    // Get classification classes by running our MediaPipe graph.
-    this.classifications = [];
+  classify(text: string): TextClassifierResult {
+    // Get classification result by running our MediaPipe graph.
+    this.classificationResult = {classifications: []};
     this.addStringToStream(
         text, INPUT_STREAM, /* timestamp= */ performance.now());
     this.finishProcessing();
-    return [...this.classifications];
-  }
-
-  // Internal function for converting raw data into a classification, and
-  // adding it to our classifications list.
-  private addJsTextClassification(binaryProto: Uint8Array): void {
-    const classificationResult =
-        ClassificationResult.deserializeBinary(binaryProto);
-    console.log(classificationResult.toObject());
-    this.classifications.push(
-        ...convertFromClassificationResultProto(classificationResult));
-  }
+    return this.classificationResult;
+  }
 
   /** Updates the MediaPipe graph configuration. */
   private refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
-    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
+    graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(
@@ -161,14 +151,14 @@ export class TextClassifier extends TaskRunner {
     const classifierNode = new CalculatorGraphConfig.Node();
     classifierNode.setCalculator(TEXT_CLASSIFIER_GRAPH);
     classifierNode.addInputStream('TEXT:' + INPUT_STREAM);
-    classifierNode.addOutputStream(
-        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
+    classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
     classifierNode.setOptions(calculatorOptions);
 
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
-      this.addJsTextClassification(binaryProto);
-    });
+    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
+      this.classificationResult = convertFromClassificationResultProto(
+          ClassificationResult.deserializeBinary(binaryProto));
+    });
 
     const binaryGraph = graphConfig.serializeBinary();
@@ -15,4 +15,4 @@
  */
 
 export {Category} from '../../../../tasks/web/components/containers/category';
-export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
+export {ClassificationResult as TextClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
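Net effect for callers: classify() now returns a single TextClassifierResult object rather than a Classifications array. A usage sketch; the textClassifier instance is assumed, only the result shape comes from this diff:

const result: TextClassifierResult = textClassifier.classify('great movie!');
for (const head of result.classifications) {
  const top = head.categories[0];
  console.log(`${head.headName}: ${top.categoryName} (${top.score})`);
}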
@@ -21,7 +21,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/cc/components/containers/proto:classifications_jspb_proto",
         "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_jspb_proto",
         "//mediapipe/tasks/web/components/containers:category",
-        "//mediapipe/tasks/web/components/containers:classifications",
+        "//mediapipe/tasks/web/components/containers:classification_result",
         "//mediapipe/tasks/web/components/processors:base_options",
         "//mediapipe/tasks/web/components/processors:classifier_options",
         "//mediapipe/tasks/web/components/processors:classifier_result",
@@ -27,12 +27,12 @@ import {createMediaPipeLib, FileLocator, ImageSource} from '../../../../web/grap
 // Placeholder for internal dependency on trusted resource url
 
 import {ImageClassifierOptions} from './image_classifier_options';
-import {Classifications} from './image_classifier_result';
+import {ImageClassifierResult} from './image_classifier_result';
 
 const IMAGE_CLASSIFIER_GRAPH =
     'mediapipe.tasks.vision.image_classifier.ImageClassifierGraph';
 const INPUT_STREAM = 'input_image';
-const CLASSIFICATION_RESULT_STREAM = 'classification_result';
+const CLASSIFICATIONS_STREAM = 'classifications';
 
 export {ImageSource};  // Used in the public API
@@ -41,7 +41,7 @@ export {ImageSource}; // Used in the public API
 
 /** Performs classification on images. */
 export class ImageClassifier extends TaskRunner {
-  private classifications: Classifications[] = [];
+  private classificationResult: ImageClassifierResult = {classifications: []};
   private readonly options = new ImageClassifierGraphOptions();
 
   /**
@@ -133,31 +133,21 @@ export class ImageClassifier extends TaskRunner {
    * provided, defaults to `performance.now()`.
    * @return The classification result of the image
    */
-  classify(imageSource: ImageSource, timestamp?: number): Classifications[] {
-    // Get classification classes by running our MediaPipe graph.
-    this.classifications = [];
+  classify(imageSource: ImageSource, timestamp?: number):
+      ImageClassifierResult {
+    // Get classification result by running our MediaPipe graph.
+    this.classificationResult = {classifications: []};
     this.addGpuBufferAsImageToStream(
         imageSource, INPUT_STREAM, timestamp ?? performance.now());
     this.finishProcessing();
-    return [...this.classifications];
+    return this.classificationResult;
   }
-
-  /**
-   * Internal function for converting raw data into a classification, and
-   * adding it to our classfications list.
-   **/
-  private addJsImageClassification(binaryProto: Uint8Array): void {
-    const classificationResult =
-        ClassificationResult.deserializeBinary(binaryProto);
-    this.classifications.push(
-        ...convertFromClassificationResultProto(classificationResult));
-  }
 
   /** Updates the MediaPipe graph configuration. */
   private refreshGraph(): void {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(INPUT_STREAM);
-    graphConfig.addOutputStream(CLASSIFICATION_RESULT_STREAM);
+    graphConfig.addOutputStream(CLASSIFICATIONS_STREAM);
 
     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(
@@ -168,14 +158,14 @@ export class ImageClassifier extends TaskRunner {
     const classifierNode = new CalculatorGraphConfig.Node();
     classifierNode.setCalculator(IMAGE_CLASSIFIER_GRAPH);
     classifierNode.addInputStream('IMAGE:' + INPUT_STREAM);
-    classifierNode.addOutputStream(
-        'CLASSIFICATION_RESULT:' + CLASSIFICATION_RESULT_STREAM);
+    classifierNode.addOutputStream('CLASSIFICATIONS:' + CLASSIFICATIONS_STREAM);
     classifierNode.setOptions(calculatorOptions);
 
     graphConfig.addNode(classifierNode);
 
-    this.attachProtoListener(CLASSIFICATION_RESULT_STREAM, binaryProto => {
-      this.addJsImageClassification(binaryProto);
-    });
+    this.attachProtoListener(CLASSIFICATIONS_STREAM, binaryProto => {
+      this.classificationResult = convertFromClassificationResultProto(
+          ClassificationResult.deserializeBinary(binaryProto));
+    });
 
     const binaryGraph = graphConfig.serializeBinary();
@@ -15,4 +15,4 @@
  */
 
 export {Category} from '../../../../tasks/web/components/containers/category';
-export {ClassificationEntry, Classifications} from '../../../../tasks/web/components/containers/classifications';
+export {ClassificationResult as ImageClassifierResult, Classifications} from '../../../../tasks/web/components/containers/classification_result';
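The image path mirrors the text path: one ImageClassifierResult per call. A usage sketch; the imageClassifier instance and the input imageElement (any ImageSource, e.g. an HTMLImageElement) are assumed:

const result: ImageClassifierResult =
    imageClassifier.classify(imageElement, performance.now());
const top = result.classifications[0].categories[0];
console.log(top.categoryName, top.score);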
38 third_party/external_files.bzl vendored
@@ -28,12 +28,36 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/associated_file_meta.json?generation=1665422792304395"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_bert_text_classifier_no_metadata_tflite",
+        sha256 = "9b4554f6e28a72a3f40511964eed1ccf4e74cc074f81543cacca4faf169a173e",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_no_metadata.tflite?generation=1667948360250899"],
+    )
+
     http_file(
         name = "com_google_mediapipe_bert_text_classifier_tflite",
         sha256 = "1e5a550c09bff0a13e61858bcfac7654d7fcc6d42106b4f15e11117695069600",
         urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier.tflite?generation=1666144699858747"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_bert_text_classifier_with_bert_tokenizer_json",
+        sha256 = "49f148a13a4e3b486b1d3c2400e46e5ebd0d375674c0154278b835760e873a95",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_bert_tokenizer.json?generation=1667948363241334"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_bert_text_classifier_with_sentence_piece_json",
+        sha256 = "113091f3892691de57e379387256b2ce0cc18a1b5185af866220a46da8221f26",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_text_classifier_with_sentence_piece.json?generation=1667948366009530"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_bert_tokenizer_meta_json",
+        sha256 = "116d70c7c3ef413a8bff54ab758f9ed3d6e51fdc5621d8c920ad2f0035831804",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/bert_tokenizer_meta.json?generation=1667948368809108"],
+    )
+
     http_file(
         name = "com_google_mediapipe_bounding_box_tensor_meta_json",
         sha256 = "cc019cee86529955a24a3d43ca3d778fa366bcb90d67c8eaf55696789833841a",
@@ -403,7 +427,7 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_labels_txt",
         sha256 = "536feacc519de3d418de26b2effb4d75694a8c4c0063e36499a46fa8061e2da9",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667888034706429"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/labels.txt?generation=1667892497527642"],
     )
 
     http_file(
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_movie_review_json",
|
name = "com_google_mediapipe_movie_review_json",
|
||||||
sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e",
|
sha256 = "c09b88af05844cad5133b49744fed3a0bd514d4a1c75b9d2f23e9a40bd7bc04e",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667888039053188"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review.json?generation=1667892501695336"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_movie_review_labels_txt",
|
name = "com_google_mediapipe_movie_review_labels_txt",
|
||||||
sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
|
sha256 = "4b9b26392f765e7a872372131cd4cee8ad7c02e496b5a1228279619b138c4b7a",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667888041670721"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/movie_review_labels.txt?generation=1667892504334882"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -703,7 +727,7 @@ def external_files():
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_regex_vocab_txt",
|
name = "com_google_mediapipe_regex_vocab_txt",
|
||||||
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
|
sha256 = "b1134b10927a53ce4224bbc30ccf075c9969c94ebf40c368966d1dcf445ca923",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667888047885461"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/regex_vocab.txt?generation=1667892507770551"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@@ -790,6 +814,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_sentence_piece_tokenizer_meta_json",
+        sha256 = "416bfe231710502e4a93e1b1950c0c6e5db49cffb256d241ef3d3f2d0d57718b",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/sentence_piece_tokenizer_meta.json?generation=1667948375508564"],
+    )
+
     http_file(
         name = "com_google_mediapipe_speech_16000_hz_mono_wav",
         sha256 = "71caf50b8757d6ab9cad5eae4d36669d3c20c225a51660afd7fe0dc44cdb74f6",
47 third_party/wasm_files.bzl vendored Normal file
@@ -0,0 +1,47 @@
"""
WASM dependencies for MediaPipe.

This file is auto-generated.
"""

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file")

# buildifier: disable=unnamed-macro
def wasm_files():
    """WASM dependencies for MediaPipe."""

    http_file(
        name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
        sha256 = "9419766229f24790388805d891af907cf11fe8e2cdacabcf016feb054b720c82",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1667934266184984"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_text_wasm_internal_js",
        sha256 = "39d9445ab3b90f625a3332251fe82e59b40cd0501a5657475f3b115b7c6122c8",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1667934268229056"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
        sha256 = "b43c7078fe5da72990394af4fefd798bd844b4ac47849a49067bd68c3c910a3d",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1667934270239845"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
        sha256 = "9f2abe2a51d1ebc854859f620759cec1cc643773f3748d0d19e0868578c3d746",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1667934272818542"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
        sha256 = "8334caec5fb10cd1f936f6ee41f8853771c7bf3a421f5c15c39ee41aa503ca54",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1667934275451198"],
    )

    http_file(
        name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
        sha256 = "b996eaa324da151359ad8e16edad27d9768505f1fd073625bc50dbb0f252e098",
        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1667934277855507"],
    )