Merge branch 'google:master' into image-segmenter-python-impl

Kinar R 2022-10-22 02:04:57 +05:30 committed by GitHub
commit 1748663a5a
97 changed files with 3826 additions and 350 deletions


@@ -1410,3 +1410,45 @@ cc_library(
    ],
    alwayslink = 1,
)
mediapipe_proto_library(
name = "bypass_calculator_proto",
srcs = ["bypass_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bypass_calculator",
srcs = ["bypass_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":bypass_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "bypass_calculator_test",
srcs = ["bypass_calculator_test.cc"],
deps = [
":bypass_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)


@@ -0,0 +1,161 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/bypass_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
using mediapipe::BypassCalculatorOptions;
// Defines a "bypass" channel to use in place of a disabled feature subgraph.
// By default, all inputs are discarded and all outputs are ignored.
// Certain input streams can be passed to corresponding output streams
// by specifying them in "pass_input_stream" and "pass_output_stream" options.
// All output streams are updated with timestamp bounds indicating completed
// output.
//
// Note that this calculator is designed for use as a contained_node in a
// SwitchContainer. For this reason, any input and output tags are accepted,
// and stream semantics are specified through BypassCalculatorOptions.
//
// Example config:
// node {
// calculator: "BypassCalculator"
// input_stream: "APPEARANCES:appearances_post_facenet"
// input_stream: "VIDEO:video_frame"
// input_stream: "FEATURE_CONFIG:feature_config"
// input_stream: "ENABLE:gaze_enabled"
// output_stream: "APPEARANCES:analyzed_appearances"
// output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
// node_options: {
// [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
// pass_input_stream: "APPEARANCES"
// pass_output_stream: "APPEARANCES"
// }
// }
// }
//
class BypassCalculator : public Node {
public:
static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
using IdMap = std::map<CollectionItemId, CollectionItemId>;
// Returns the map of passthrough input and output stream ids.
static absl::StatusOr<IdMap> GetPassMap(
const BypassCalculatorOptions& options, const tool::TagMap& input_map,
const tool::TagMap& output_map) {
IdMap result;
auto& input_streams = options.pass_input_stream();
auto& output_streams = options.pass_output_stream();
int size = std::min(input_streams.size(), output_streams.size());
for (int i = 0; i < size; ++i) {
std::pair<std::string, int> in_tag, out_tag;
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
&in_tag.first, &in_tag.second));
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
&out_tag.first, &out_tag.second));
auto input_id = input_map.GetId(in_tag.first, in_tag.second);
auto output_id = output_map.GetId(out_tag.first, out_tag.second);
result[input_id] = output_id;
}
return result;
}
// Identifies all specified streams as "Any" packet type.
// Identifies passthrough streams as "Same" packet type.
static absl::Status UpdateContract(CalculatorContract* cc) {
auto options = cc->Options<BypassCalculatorOptions>();
RET_CHECK_EQ(options.pass_input_stream().size(),
options.pass_output_stream().size());
ASSIGN_OR_RETURN(
auto pass_streams,
GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams) {
pass_out.insert(entry.second);
cc->Inputs().Get(entry.first).SetAny();
cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
}
for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
if (pass_streams.count(id) == 0) {
cc->Inputs().Get(id).SetAny();
}
}
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetAny();
}
}
return absl::OkStatus();
}
// Saves the map of passthrough input and output stream ids.
absl::Status Open(CalculatorContext* cc) override {
auto options = cc->Options<BypassCalculatorOptions>();
ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
*cc->Outputs().TagMap()));
return absl::OkStatus();
}
// Copies packets between passthrough input and output streams.
// Updates timestamp bounds on all output streams.
absl::Status Process(CalculatorContext* cc) override {
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams_) {
pass_out.insert(entry.second);
auto& packet = cc->Inputs().Get(entry.first).Value();
if (packet.Timestamp() == cc->InputTimestamp()) {
cc->Outputs().Get(entry.first).AddPacket(packet);
}
}
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetNextTimestampBound(
std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
}
}
return absl::OkStatus();
}
// Closes all output streams.
absl::Status Close(CalculatorContext* cc) override {
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).Close();
}
return absl::OkStatus();
}
private:
IdMap pass_streams_;
};
MEDIAPIPE_REGISTER_NODE(BypassCalculator);
} // namespace api2
} // namespace mediapipe


@@ -0,0 +1,31 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BypassCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BypassCalculatorOptions ext = 481259677;
}
// Names an input stream or streams to pass through, by "TAG:index".
repeated string pass_input_stream = 1;
// Names an output stream or streams to pass through, by "TAG:index".
repeated string pass_output_stream = 2;
}


@@ -0,0 +1,302 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A graph using a BypassCalculator to pass through and ignore
// most of its inputs and outputs.
constexpr char kTestGraphConfig1[] = R"pb(
type: "AppearancesPassThroughSubgraph"
input_stream: "APPEARANCES:appearances"
input_stream: "VIDEO:video_frame"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "APPEARANCES:passthrough_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output"
node {
calculator: "BypassCalculator"
input_stream: "PASS:appearances"
input_stream: "TRUNCATE:0:video_frame"
input_stream: "TRUNCATE:1:feature_config"
output_stream: "PASS:passthrough_appearances"
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "PASS"
pass_output_stream: "PASS"
}
}
}
)pb";
// A graph using AppearancesPassThroughSubgraph as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig2[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
}
}
}
)pb";
// A graph using BypassCalculator as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig3[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: {
calculator: "BypassCalculator"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "APPEARANCES"
pass_output_stream: "APPEARANCES"
}
}
}
}
}
}
)pb";
// A graph using BypassCalculator as a disabled gate
// for input frames and appearances.
constexpr char kTestGraphConfig4[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "ENABLE:gaze_enabled"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "VIDEO:video_frame_out"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEATURE_CONFIG:feature_config_out"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "BypassCalculator" }
contained_node: { calculator: "PassThroughCalculator" }
}
}
}
)pb";
// Reports packet timestamp and string contents, or "<empty>".
std::string DebugString(Packet p) {
return absl::StrCat(p.Timestamp().DebugString(), ":",
p.IsEmpty() ? "<empty>" : p.Get<std::string>());
}
// Shows a bypass subgraph that passes through one stream.
TEST(BypassCalculatorTest, SubgraphChannel) {
CalculatorGraphConfig config_1 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1);
CalculatorGraphConfig config_2 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that passes through one stream.
TEST(BypassCalculatorTest, CalculatorChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that discards all inputs when ENABLED is false.
TEST(BypassCalculatorTest, GatedChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> video_frame;
MP_ASSERT_OK(graph.ObserveOutputStream(
"video_frame_out",
[&](const Packet& p) {
video_frame.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
// Close the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(false).At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 200.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Only timestamps arrive from the BypassCalculator.
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:<empty>"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>"));
// Open the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(true).At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 300.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f2").At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Packets arrive from the PassThroughCalculator.
EXPECT_THAT(analyzed_appearances,
testing::ElementsAre("200:<empty>", "300:a2"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>", "300:v2"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe


@@ -16,6 +16,9 @@ syntax = "proto2";

package mediapipe;

+option java_package = "com.google.mediapipe.calculator.proto";
+option java_outer_classname = "RotationModeProto";
+
// Counterclockwise rotation.
message RotationMode {
  enum Mode {


@@ -253,6 +253,60 @@ cc_library(
    alwayslink = 1,
)
cc_library(
name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "text_to_tensor_calculator_test",
srcs = ["text_to_tensor_calculator_test.cc"],
deps = [
":text_to_tensor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:options_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "universal_sentence_encoder_preprocessor_calculator",
srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
mediapipe_proto_library(
    name = "inference_calculator_proto",
    srcs = ["inference_calculator.proto"],


@@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter {
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
+        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
-          "Only 4-channel texture input formats are supported, passed format: ",
-          static_cast<uint32_t>(input.format())));
+          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter {
    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
        [this, &output_tensor, &input, &roi, &output_shape, range_min,
         range_max, tensor_buffer_offset]() -> absl::Status {
-          constexpr int kRgbaNumChannels = 4;
+          const int input_num_channels = input.channels();
          auto source_texture = gl_helper_.CreateSourceTexture(input);
          tflite::gpu::gl::GlTexture input_texture(
-              GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
+              GL_TEXTURE_2D, source_texture.name(),
+              input_num_channels == 4 ? GL_RGB : GL_RGBA,
              source_texture.width() * source_texture.height() *
-                  kRgbaNumChannels * sizeof(uint8_t),
+                  input_num_channels * sizeof(uint8_t),
              /*layer=*/0,
              /*owned=*/false);


@@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter {
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
+        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
-          "Only 4-channel texture input formats are supported, passed format: ",
-          static_cast<uint32_t>(input.format())));
+          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    // TODO: support tensor_buffer_offset > 0 scenario.
    RET_CHECK_EQ(tensor_buffer_offset, 0)


@@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kMaxSeqLen = 256;
constexpr char kTestModelPath[] =
"mediapipe/tasks/testdata/text/"
"test_model_text_classifier_with_regex_tokenizer.tflite";
absl::StatusOr<std::vector<int>> RunRegexPreprocessorCalculator(
absl::string_view text) {
auto graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "RegexPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
options {
[mediapipe.RegexPreprocessorCalculatorOptions.ext] {
max_seq_len: $0
}
}
}
)pb",
kMaxSeqLen));
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath);
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(), 1));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32");
}
auto* buffer = tensor_vec[0].GetCpuReadView().buffer<int>();
std::vector<int> result(buffer, buffer + kMaxSeqLen);
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return result;
}
TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) {
MP_ASSERT_OK_AND_ASSIGN(
std::vector<int> processed_tensor_values,
RunRegexPreprocessorCalculator("This is the best movie Ive seen in "
"recent years. Strongly recommend it!"));
static const int expected_result[kMaxSeqLen] = {
1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12};
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
TEST(RegexPreprocessorCalculatorTest, LongInput) {
std::stringstream long_input;
long_input << "This is the best";
for (int i = 0; i < kMaxSeqLen; ++i) {
long_input << " best";
}
long_input << "movie Ive seen in recent years. Strongly recommend it!";
MP_ASSERT_OK_AND_ASSIGN(std::vector<int> processed_tensor_values,
RunRegexPreprocessorCalculator(long_input.str()));
std::vector<int> expected_result = {1, 2, 9, 4, 118};
// "best" id
expected_result.resize(kMaxSeqLen, 118);
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
} // namespace
} // namespace mediapipe


@@ -67,9 +67,7 @@ absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
        "tensor_vec has size $0, expected 1", tensor_vec.size()));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
-    return absl::InvalidArgumentError(absl::Substitute(
-        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
-        Tensor::ElementType::kChar));
+    return absl::InvalidArgumentError("Expected tensor element type kChar");
  }
  const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
  return std::string(buffer, text.length());


@@ -88,9 +88,7 @@ RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
        kNumInputTensorsForUniversalSentenceEncoder));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
-    return absl::InvalidArgumentError(absl::Substitute(
-        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
-        Tensor::ElementType::kChar));
+    return absl::InvalidArgumentError("Expected tensor element type kChar");
  }
  std::vector<std::string> results;
  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {


@@ -91,9 +91,9 @@ class TfLiteModelCalculator : public CalculatorBase {
        tflite::DefaultErrorReporter());
    model = tflite::FlatBufferModel::BuildFromAllocation(
        std::move(model_allocation), tflite::DefaultErrorReporter());
-#elif
+#else
    return absl::FailedPreconditionError(
-        "Loading by file descriptor is not supported on this platform.")
+        "Loading by file descriptor is not supported on this platform.");
#endif
  }


@@ -378,8 +378,11 @@ cc_library(
        ],
    }),
    deps = [
+        ":gl_texture_buffer",
        ":gpu_buffer_format",
        ":gpu_buffer_storage",
+        ":image_frame_view",
+        "//mediapipe/framework/formats:image_frame",
        "@com_google_absl//absl/strings:str_format",
    ],
)


@@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) {
        16,
        0};
-    pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs];
+    pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1];
  }
  if (!pixel_format_) {
    // On several Forge machines, the default config fails. For now let's do


@@ -65,6 +65,7 @@ class GlTextureView {
  friend class GpuBuffer;
  friend class GlTextureBuffer;
  friend class GpuBufferStorageCvPixelBuffer;
+  friend class GpuBufferStorageAhwb;
  GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
                int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
                DetachFn detach, DoneWritingFn done_writing)


@@ -18,6 +18,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
+#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "stb_image.h"


@@ -23,13 +23,17 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
  """DataLoader for classification models."""

-  def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
+  def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
    super().__init__(dataset, size)
-    self.index_to_label = index_to_label
+    self._index_by_label = index_by_label

  @property
  def num_classes(self: ds._DatasetT) -> int:
-    return len(self.index_to_label)
+    return len(self._index_by_label)
+
+  @property
+  def index_by_label(self: ds._DatasetT) -> Any:
+    return self._index_by_label

  def split(self: ds._DatasetT,
            fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
@@ -44,4 +48,4 @@ class ClassificationDataset(ds.Dataset):
    Returns:
      The splitted two sub datasets.
    """
-    return self._split(fraction, self.index_to_label)
+    return self._split(fraction, self._index_by_label)
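
A minimal usage sketch of the renamed API (hypothetical labels and data; it assumes the base class can be constructed directly with keyword arguments, the same way its subclasses in this commit do):

import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset

# Four dummy examples and two label names; index 0 maps to 'neg', 1 to 'pos'.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = classification_dataset.ClassificationDataset(
    dataset=ds, size=4, index_by_label=['neg', 'pos'])
assert data.num_classes == 2
assert data.index_by_label == ['neg', 'pos']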


@@ -12,45 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Any, Tuple, TypeVar
+
# Dependency imports
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

+_DatasetT = TypeVar(
+    '_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')

-class ClassificationDataLoaderTest(tf.test.TestCase):
+
+class ClassificationDatasetTest(tf.test.TestCase):

  def test_split(self):

-    class MagicClassificationDataLoader(
+    class MagicClassificationDataset(
        classification_dataset.ClassificationDataset):
+      """A mock classification dataset class for testing purpose.

-      def __init__(self, dataset, size, index_to_label, value):
-        super(MagicClassificationDataLoader,
-              self).__init__(dataset, size, index_to_label)
+      Attributes:
+        value: A value variable stored by the mock dataset class for testing.
+      """
+
+      def __init__(self, dataset: tf.data.Dataset, size: int,
+                   index_by_label: Any, value: Any):
+        super().__init__(
+            dataset=dataset, size=size, index_by_label=index_by_label)
        self.value = value

-      def split(self, fraction):
-        return self._split(fraction, self.index_to_label, self.value)
+      def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
+        return self._split(fraction, self.index_by_label, self.value)

    # Some dummy inputs.
    magic_value = 42
    num_classes = 2
-    index_to_label = (False, True)
+    index_by_label = (False, True)

    # Create data loader from sample data.
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
-                                         magic_value)
+    data = MagicClassificationDataset(
+        dataset=ds,
+        size=len(ds),
+        index_by_label=index_by_label,
+        value=magic_value)

    # Train/Test data split.
    fraction = .25
-    train_data, test_data = data.split(fraction)
+    train_data, test_data = data.split(fraction=fraction)

    # `split` should return instances of child DataLoader.
-    self.assertIsInstance(train_data, MagicClassificationDataLoader)
-    self.assertIsInstance(test_data, MagicClassificationDataLoader)
+    self.assertIsInstance(train_data, MagicClassificationDataset)
+    self.assertIsInstance(test_data, MagicClassificationDataset)

    # Make sure number of entries are right.
    self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
@@ -59,7 +73,7 @@ class ClassificationDatasetTest(tf.test.TestCase):
    # Make sure attributes propagated correctly.
    self.assertEqual(train_data.num_classes, num_classes)
-    self.assertEqual(test_data.index_to_label, index_to_label)
+    self.assertEqual(test_data.index_by_label, index_by_label)
    self.assertEqual(train_data.value, magic_value)
    self.assertEqual(test_data.value, magic_value)


@@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
  """An abstract base class that represents a TensorFlow classifier."""

-  def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
+  def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
               full_train: bool):
    """Initilizes a classifier with its specifications.

    Args:
      model_spec: Specification for the model.
-      index_to_label: A list that map from index to label class name.
+      index_by_label: A list that map from index to label class name.
      shuffle: Whether the dataset should be shuffled.
      full_train: If true, train the model end-to-end including the backbone
        and the classification layers on top. Otherwise, only train the top
        classification layers.
    """
    super(Classifier, self).__init__(model_spec, shuffle)
-    self._index_to_label = index_to_label
+    self._index_by_label = index_by_label
    self._full_train = full_train
-    self._num_classes = len(index_to_label)
+    self._num_classes = len(index_by_label)

  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
    """Evaluates the classifier with the provided evaluation dataset.
@@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
    label_filepath = os.path.join(export_dir, label_filename)
    tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
    with tf.io.gfile.GFile(label_filepath, 'w') as f:
-      f.write('\n'.join(self._index_to_label))
+      f.write('\n'.join(self._index_by_label))
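
For intuition about index_by_label and the label file written above, a small assumed example (not part of the commit):

index_by_label = ['cat', 'dog']   # class index 0 -> 'cat', class index 1 -> 'dog'
print('\n'.join(index_by_label))  # the exact text that the export above writes:
# cat
# dog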


@@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):

  def setUp(self):
    super(ClassifierTest, self).setUp()
-    index_to_label = ['cat', 'dog']
+    index_by_label = ['cat', 'dog']
    self.model = MockClassifier(
        model_spec=None,
-        index_to_label=index_to_label,
+        index_by_label=index_by_label,
        shuffle=False,
        full_train=False)
    self.model.model = test_util.build_model(input_shape=[4], num_classes=2)


@@ -21,8 +21,6 @@ import abc
import os
from typing import Any, Callable, Optional

-# Dependency imports
-
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset
@@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
    tflite_filepath = os.path.join(export_dir, tflite_filename)
    # TODO: Populate metadata to the exported TFLite model.
    model_util.export_tflite(
-        self._model,
-        tflite_filepath,
-        quantization_config,
+        model=self._model,
+        tflite_filepath=tflite_filepath,
+        quantization_config=quantization_config,
        preprocess=preprocess)
    tf.compat.v1.logging.info(
        'TensorFlow Lite model exported successfully: %s' % tflite_filepath)


@@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):

  def setUp(self):
    super(CustomModelTest, self).setUp()
-    self.model = MockCustomModel(model_spec=None, shuffle=False)
-    self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
+    self._model = MockCustomModel(model_spec=None, shuffle=False)
+    self._model._model = test_util.build_model(input_shape=[4], num_classes=2)

  def _check_nonempty_file(self, filepath):
    self.assertTrue(os.path.isfile(filepath))
@@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):
  def test_export_tflite(self):
    export_path = os.path.join(self.get_temp_dir(), 'export/')
-    self.model.export_tflite(export_dir=export_path)
+    self._model.export_tflite(export_dir=export_path)
    self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))

if __name__ == '__main__':


@@ -31,20 +31,6 @@ py_library(
    ],
)

-py_library(
-    name = "image_preprocessing",
-    srcs = ["image_preprocessing.py"],
-    srcs_version = "PY3",
-)
-
-py_test(
-    name = "image_preprocessing_test",
-    srcs = ["image_preprocessing_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
-    deps = [":image_preprocessing"],
-)
-
py_library(
    name = "model_util",
    srcs = ["model_util.py"],


@@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss):
      class_weight: A weight to apply to the loss, one for each class. The
        weight is applied for each input where the ground truth label matches.
    """
-    super(tf.keras.losses.Loss, self).__init__()
+    super().__init__()
    # Used for clipping min/max values of probability values in y_pred to avoid
    # NaNs and Infs in computation.
    self._epsilon = 1e-7


@@ -104,8 +104,8 @@ def export_tflite(
    quantization_config: Configuration for post-training quantization.
    supported_ops: A list of supported ops in the converted TFLite file.
    preprocess: A callable to preprocess the representative dataset for
-      quantization. The callable takes three arguments in order: feature,
-      label, and is_training.
+      quantization. The callable takes three arguments in order: feature, label,
+      and is_training.
  """
  if tflite_filepath is None:
    raise ValueError(
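
The docstring above describes the preprocess hook only in prose. A short sketch of a callable matching the documented three-argument order (the resizing logic and the (feature, label) return shape are assumptions for illustration, not spelled out in this hunk):

import tensorflow as tf

def preprocess(feature, label, is_training):
  # Assumed example: prepare one representative-dataset element for
  # post-training quantization; the training flag is unused here.
  del is_training
  feature = tf.image.resize(feature, [224, 224]) / 255.0
  return feature, label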


@@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
    model_util.export_tflite(model, tflite_file)
-    self._test_tflite(model, tflite_file, input_dim)
+    test_util.test_tflite(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])

  @parameterized.named_parameters(
      dict(
@@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
    input_dim = 16
    num_classes = 2
    max_input_value = 5
-    model = test_util.build_model([input_dim], num_classes)
+    model = test_util.build_model(
+        input_shape=[input_dim], num_classes=num_classes)
    tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')
-    model_util.export_tflite(model, tflite_file, config)
-    self._test_tflite(
-        model, tflite_file, input_dim, max_input_value, atol=1e-00)
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
-
-  def _test_tflite(self,
-                   keras_model: tf.keras.Model,
-                   tflite_model_file: str,
-                   input_dim: int,
-                   max_input_value: int = 1000,
-                   atol: float = 1e-04):
-    random_input = test_util.create_random_sample(
-        size=[1, input_dim], high=max_input_value)
-    random_input = tf.convert_to_tensor(random_input)
+    model_util.export_tflite(
+        model=model, tflite_filepath=tflite_file, quantization_config=config)
    self.assertTrue(
-        test_util.is_same_output(
-            tflite_model_file, keras_model, random_input, atol=atol))
+        test_util.test_tflite(
+            keras_model=model,
+            tflite_file=tflite_file,
+            size=[1, input_dim],
+            high=max_input_value,
+            atol=1e-00))
+    self.assertNear(os.path.getsize(tflite_file), model_size, 300)

if __name__ == '__main__':


@@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
  keras_output = keras_model.predict_on_batch(input_tensors)
  return np.allclose(lite_output, keras_output, atol=atol)
def test_tflite(keras_model: tf.keras.Model,
tflite_file: str,
size: Union[int, List[int]],
high: float = 1,
atol: float = 1e-04) -> bool:
"""Verifies if the output of TFLite model and TF Keras model are identical.
Args:
keras_model: Input TensorFlow Keras model.
tflite_file: Input TFLite model file.
size: Size of the input tensor.
high: Higher boundary of the values in input tensors.
atol: Absolute tolerance of the difference between the outputs of Keras
model and TFLite model.
Returns:
True if the output of TFLite model and TF Keras model are identical.
Otherwise, False.
"""
random_input = create_random_sample(size=size, high=high)
random_input = tf.convert_to_tensor(random_input)
return is_same_output(
tflite_file=tflite_file,
keras_model=keras_model,
input_tensors=random_input,
atol=atol)
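
A usage sketch for the new helper, mirroring the ModelUtilTest change earlier in this commit (it assumes it runs inside a tf.test.TestCase method with model_util and test_util imported, and uses an arbitrary 16-dimensional example model):

model = test_util.build_model(input_shape=[16], num_classes=2)
tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
model_util.export_tflite(model=model, tflite_filepath=tflite_file)
self.assertTrue(
    test_util.test_tflite(
        keras_model=model, tflite_file=tflite_file, size=[1, 16]))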


@@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "image_preprocessing",
srcs = ["image_preprocessing.py"],
)
py_test(
name = "image_preprocessing_test",
srcs = ["image_preprocessing_test.py"],
deps = [":image_preprocessing"],
)


@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -13,11 +13,7 @@
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing."""

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Dependency imports
import tensorflow as tf

IMAGE_SIZE = 224


@@ -12,15 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Dependency imports
import numpy as np
import tensorflow as tf

-from mediapipe.model_maker.python.core.utils import image_preprocessing
+from mediapipe.model_maker.python.vision.core import image_preprocessing


def _get_preprocessed_image(preprocessor, is_training=False):


@@ -78,9 +78,9 @@ py_library(
        ":train_image_classifier_lib",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/core/tasks:classifier",
-        "//mediapipe/model_maker/python/core/utils:image_preprocessing",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/model_maker/python/core/utils:quantization",
+        "//mediapipe/model_maker/python/vision/core:image_preprocessing",
    ],
)


@@ -16,7 +16,7 @@

import os
import random
-from typing import List, Optional, Tuple
+from typing import List, Optional

import tensorflow as tf
import tensorflow_datasets as tfds
@@ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset


def _load_image(path: str) -> tf.Tensor:
-  """Loads image."""
+  """Loads a jpeg/png image and returns an image tensor."""
  image_raw = tf.io.read_file(path)
  image_tensor = tf.cond(
-      tf.image.is_jpeg(image_raw),
-      lambda: tf.image.decode_jpeg(image_raw, channels=3),
-      lambda: tf.image.decode_png(image_raw, channels=3))
+      tf.io.is_jpeg(image_raw),
+      lambda: tf.io.decode_jpeg(image_raw, channels=3),
+      lambda: tf.io.decode_png(image_raw, channels=3))
  return image_tensor
@@ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset):
    Args:
      dirname: Name of the directory containing the data files.
-      shuffle: boolean, if shuffle, random shuffle data.
+      shuffle: boolean, if true, random shuffle data.

    Returns:
      Dataset containing images and labels and other related info.

    Raises:
      ValueError: if the input data directory is empty.
    """
@@ -85,55 +84,26 @@ class Dataset(classification_dataset.ClassificationDataset):
        name for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name)))
    all_label_size = len(label_names)
-    label_to_index = dict(
+    index_by_label = dict(
        (name, index) for index, name in enumerate(label_names))
    all_image_labels = [
-        label_to_index[os.path.basename(os.path.dirname(path))]
+        index_by_label[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

-    autotune = tf.data.AUTOTUNE
-    image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
+    image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)

-    # Loads label.
+    # Load label
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64))

-    # Creates a dataset if (image, label) pairs.
+    # Create a dataset if (image, label) pairs
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

    tf.compat.v1.logging.info(
        'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
        all_label_size, ', '.join(label_names))
-    return Dataset(image_label_ds, all_image_size, label_names)
+    return Dataset(
+        dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
-
-  @classmethod
-  def load_tf_dataset(
-      cls, name: str
-  ) -> Tuple[Optional[classification_dataset.ClassificationDataset],
-             Optional[classification_dataset.ClassificationDataset],
-             Optional[classification_dataset.ClassificationDataset]]:
-    """Loads data from tensorflow_datasets.
-
-    Args:
-      name: the registered name of the tfds.core.DatasetBuilder. Refer to the
-        documentation of tfds.load for more details.
-
-    Returns:
-      A tuple of Datasets for the train/validation/test.
-
-    Raises:
-      ValueError: if the input tf dataset does not have train/validation/test
-        labels.
-    """
-    data, info = tfds.load(name, with_info=True)
-    if 'label' not in info.features:
-      raise ValueError('info.features need to contain \'label\' key.')
-    label_names = info.features['label'].names
-
-    train_data = _create_data('train', data, info, label_names)
-    validation_data = _create_data('validation', data, info, label_names)
-    test_data = _create_data('test', data, info, label_names)
-    return train_data, validation_data, test_data


@@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):

  def test_split(self):
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
-    train_data, test_data = data.split(0.5)
+    data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
+    train_data, test_data = data.split(fraction=0.5)

    self.assertLen(train_data, 2)
    for i, elem in enumerate(train_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 1])).all())
    self.assertEqual(train_data.num_classes, 2)
-    self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
+    self.assertEqual(train_data.index_by_label, ['pos', 'neg'])

    self.assertLen(test_data, 2)
    for i, elem in enumerate(test_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 0])).all())
    self.assertEqual(test_data.num_classes, 2)
-    self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
+    self.assertEqual(test_data.index_by_label, ['pos', 'neg'])

  def test_from_folder(self):
-    data = dataset.Dataset.from_folder(self.image_path)
+    data = dataset.Dataset.from_folder(dirname=self.image_path)

    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 2)
-    self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
+    self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
    for image, label in data.gen_tf_dataset():
      self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
      if label.numpy() == 0:
@@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):
    self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(train_data, 1034)
    self.assertEqual(train_data.num_classes, 3)
-    self.assertEqual(train_data.index_to_label,
+    self.assertEqual(train_data.index_by_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(validation_data, 133)
    self.assertEqual(validation_data.num_classes, 3)
-    self.assertEqual(validation_data.index_to_label,
+    self.assertEqual(validation_data.index_by_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(test_data, 128)
    self.assertEqual(test_data.num_classes, 3)
-    self.assertEqual(test_data.index_to_label,
+    self.assertEqual(test_data.index_by_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

View File

@ -20,9 +20,9 @@ import tensorflow_hub as hub
from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib
@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier): class ImageClassifier(classifier.Classifier):
"""ImageClassifier for building image classification model.""" """ImageClassifier for building image classification model."""
def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any], def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
hparams: hp.HParams): hparams: hp.HParams):
"""Initializes ImageClassifier class. """Initializes ImageClassifier class.
Args: Args:
model_spec: Specification for the model. model_spec: Specification for the model.
index_to_label: A list that maps from index to label class name. index_by_label: A list that maps from index to label class name.
hparams: The hyperparameters for training image classifier. hparams: The hyperparameters for training image classifier.
""" """
super(ImageClassifier, self).__init__( super().__init__(
model_spec=model_spec, model_spec=model_spec,
index_to_label=index_to_label, index_by_label=index_by_label,
shuffle=hparams.shuffle, shuffle=hparams.shuffle,
full_train=hparams.do_fine_tuning) full_train=hparams.do_fine_tuning)
self._hparams = hparams self._hparams = hparams
@ -81,7 +81,7 @@ class ImageClassifier(classifier.Classifier):
spec = ms.SupportedModels.get(model_spec) spec = ms.SupportedModels.get(model_spec)
image_classifier = cls( image_classifier = cls(
model_spec=spec, model_spec=spec,
index_to_label=train_data.index_to_label, index_by_label=train_data.index_by_label,
hparams=hparams) hparams=hparams)
image_classifier._create_model() image_classifier._create_model()

View File

@ -60,31 +60,16 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
model_spec=image_classifier.SupportedModels.MOBILENET_V2, model_spec=image_classifier.SupportedModels.MOBILENET_V2,
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)), train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='resnet_50',
model_spec=image_classifier.SupportedModels.RESNET_50,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict( dict(
testcase_name='efficientnet_lite0', testcase_name='efficientnet_lite0',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)), train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite1',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict( dict(
testcase_name='efficientnet_lite2', testcase_name='efficientnet_lite2',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2, model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
hparams=image_classifier.HParams( hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)), train_epochs=1, batch_size=1, shuffle=True)),
dict(
testcase_name='efficientnet_lite3',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3,
hparams=image_classifier.HParams(
train_epochs=1, batch_size=1, shuffle=True)),
dict( dict(
testcase_name='efficientnet_lite4', testcase_name='efficientnet_lite4',
model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4, model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,

View File

@ -48,34 +48,17 @@ mobilenet_v2_spec = functools.partial(
uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
name='mobilenet_v2') name='mobilenet_v2')
resnet_50_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
name='resnet_50')
efficientnet_lite0_spec = functools.partial( efficientnet_lite0_spec = functools.partial(
ModelSpec, ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
name='efficientnet_lite0') name='efficientnet_lite0')
efficientnet_lite1_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
input_image_shape=[240, 240],
name='efficientnet_lite1')
efficientnet_lite2_spec = functools.partial( efficientnet_lite2_spec = functools.partial(
ModelSpec, ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
input_image_shape=[260, 260], input_image_shape=[260, 260],
name='efficientnet_lite2') name='efficientnet_lite2')
efficientnet_lite3_spec = functools.partial(
ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
input_image_shape=[280, 280],
name='efficientnet_lite3')
efficientnet_lite4_spec = functools.partial( efficientnet_lite4_spec = functools.partial(
ModelSpec, ModelSpec,
uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',
@ -88,11 +71,8 @@ efficientnet_lite4_spec = functools.partial(
class SupportedModels(enum.Enum): class SupportedModels(enum.Enum):
"""Image classifier model supported by model maker.""" """Image classifier model supported by model maker."""
MOBILENET_V2 = mobilenet_v2_spec MOBILENET_V2 = mobilenet_v2_spec
RESNET_50 = resnet_50_spec
EFFICIENTNET_LITE0 = efficientnet_lite0_spec EFFICIENTNET_LITE0 = efficientnet_lite0_spec
EFFICIENTNET_LITE1 = efficientnet_lite1_spec
EFFICIENTNET_LITE2 = efficientnet_lite2_spec EFFICIENTNET_LITE2 = efficientnet_lite2_spec
EFFICIENTNET_LITE3 = efficientnet_lite3_spec
EFFICIENTNET_LITE4 = efficientnet_lite4_spec EFFICIENTNET_LITE4 = efficientnet_lite4_spec
@classmethod @classmethod

View File

@ -30,36 +30,18 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
expected_name='mobilenet_v2', expected_name='mobilenet_v2',
expected_input_image_shape=[224, 224]), expected_input_image_shape=[224, 224]),
dict(
testcase_name='resnet_50_spec_test',
model_spec=ms.resnet_50_spec,
expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
expected_name='resnet_50',
expected_input_image_shape=[224, 224]),
dict( dict(
testcase_name='efficientnet_lite0_spec_test', testcase_name='efficientnet_lite0_spec_test',
model_spec=ms.efficientnet_lite0_spec, model_spec=ms.efficientnet_lite0_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
expected_name='efficientnet_lite0', expected_name='efficientnet_lite0',
expected_input_image_shape=[224, 224]), expected_input_image_shape=[224, 224]),
dict(
testcase_name='efficientnet_lite1_spec_test',
model_spec=ms.efficientnet_lite1_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
expected_name='efficientnet_lite1',
expected_input_image_shape=[240, 240]),
dict( dict(
testcase_name='efficientnet_lite2_spec_test', testcase_name='efficientnet_lite2_spec_test',
model_spec=ms.efficientnet_lite2_spec, model_spec=ms.efficientnet_lite2_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
expected_name='efficientnet_lite2', expected_name='efficientnet_lite2',
expected_input_image_shape=[260, 260]), expected_input_image_shape=[260, 260]),
dict(
testcase_name='efficientnet_lite3_spec_test',
model_spec=ms.efficientnet_lite3_spec,
expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
expected_name='efficientnet_lite3',
expected_input_image_shape=[280, 280]),
dict( dict(
testcase_name='efficientnet_lite4_spec_test', testcase_name='efficientnet_lite4_spec_test',
model_spec=ms.efficientnet_lite4_spec, model_spec=ms.efficientnet_lite4_spec,

View File

@ -92,3 +92,29 @@ cc_library(
], ],
alwayslink = 1, alwayslink = 1,
) )
# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(
name = "text_preprocessing_graph",
srcs = ["text_preprocessing_graph.cc"],
hdrs = ["text_preprocessing_graph.h"],
deps = [
"//mediapipe/calculators/tensor:bert_preprocessor_calculator",
"//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator",
"//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
"//mediapipe/calculators/tensor:text_to_tensor_calculator",
"//mediapipe/framework:subgraph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)

View File

@ -17,8 +17,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
cc_library( cc_library(
name = "landmarks_detection", name = "rect",
hdrs = ["landmarks_detection.h"], hdrs = ["rect.h"],
) )
cc_library( cc_library(

View File

@ -13,26 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#include <vector>
// Structs holding landmark-related data structures for hand landmarker, pose
// detector, face mesher, etc.
namespace mediapipe::tasks::components::containers { namespace mediapipe::tasks::components::containers {
// x and y are in [0,1] range with origin in top left in input image space. // Defines a rectangle, used e.g. as part of detection results or as input
// If model provides z, z is in the same scale as x. origin is in the center // region-of-interest.
// of the face. //
struct Landmark { // The coordinates are normalized wrt the image dimensions, i.e. generally in
float x; // [0,1] but they may exceed these bounds if describing a region overlapping the
float y; // image. The origin is on the top-left corner of the image.
float z; struct Rect {
};
// [0, 1] range in input image space
struct Bound {
float left; float left;
float top; float top;
float right; float right;
@ -40,4 +32,4 @@ struct Bound {
}; };
} // namespace mediapipe::tasks::components::containers } // namespace mediapipe::tasks::components::containers
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
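For illustration, here is a minimal sketch of the renamed Rect struct under the normalized-coordinate convention described in the new header comment. The `bottom` field and the helper name are assumptions, since the hunk above is truncated before the end of the struct.

// Builds a region of interest covering the central 50% of the image, in
// normalized coordinates with the origin at the top-left corner of the image.
// Sketch only: the `bottom` field and MakeCenterRoi() are assumed here.
mediapipe::tasks::components::containers::Rect MakeCenterRoi() {
  return {/*left=*/0.25f, /*top=*/0.25f, /*right=*/0.75f, /*bottom=*/0.75f};
}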

View File

@ -73,6 +73,7 @@ cc_library(
srcs = ["model_task_graph.cc"], srcs = ["model_task_graph.cc"],
hdrs = ["model_task_graph.h"], hdrs = ["model_task_graph.h"],
deps = [ deps = [
":model_asset_bundle_resources",
":model_resources", ":model_resources",
":model_resources_cache", ":model_resources_cache",
":model_resources_calculator", ":model_resources_calculator",
@ -163,6 +164,7 @@ cc_library_with_tflite(
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
], ],
deps = [ deps = [
":model_asset_bundle_resources",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -39,12 +40,16 @@ ModelResourcesCache::ModelResourcesCache(
graph_op_resolver_packet_ = graph_op_resolver_packet_ =
api2::PacketAdopting<tflite::OpResolver>(std::move(graph_op_resolver)); api2::PacketAdopting<tflite::OpResolver>(std::move(graph_op_resolver));
} }
}; }
bool ModelResourcesCache::Exists(const std::string& tag) const { bool ModelResourcesCache::Exists(const std::string& tag) const {
return model_resources_collection_.contains(tag); return model_resources_collection_.contains(tag);
} }
bool ModelResourcesCache::ModelAssetBundleExists(const std::string& tag) const {
return model_asset_bundle_resources_collection_.contains(tag);
}
absl::Status ModelResourcesCache::AddModelResources( absl::Status ModelResourcesCache::AddModelResources(
std::unique_ptr<ModelResources> model_resources) { std::unique_ptr<ModelResources> model_resources) {
if (model_resources == nullptr) { if (model_resources == nullptr) {
@ -94,6 +99,62 @@ absl::StatusOr<const ModelResources*> ModelResourcesCache::GetModelResources(
return model_resources_collection_.at(tag).get(); return model_resources_collection_.at(tag).get();
} }
absl::Status ModelResourcesCache::AddModelAssetBundleResources(
std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources) {
if (model_asset_bundle_resources == nullptr) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ModelAssetBundleResources object is null.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
const std::string& tag = model_asset_bundle_resources->GetTag();
if (tag.empty()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ModelAssetBundleResources must have a non-empty tag.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
if (ModelAssetBundleExists(tag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute(
"ModelAssetBundleResources with tag \"$0\" already exists.", tag),
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
model_asset_bundle_resources_collection_.emplace(
tag, std::move(model_asset_bundle_resources));
return absl::OkStatus();
}
absl::Status ModelResourcesCache::AddModelAssetBundleResourcesCollection(
std::vector<std::unique_ptr<ModelAssetBundleResources>>&
model_asset_bundle_resources_collection) {
for (auto& model_bundle_resources : model_asset_bundle_resources_collection) {
MP_RETURN_IF_ERROR(
AddModelAssetBundleResources(std::move(model_bundle_resources)));
}
return absl::OkStatus();
}
absl::StatusOr<const ModelAssetBundleResources*>
ModelResourcesCache::GetModelAssetBundleResources(
const std::string& tag) const {
if (tag.empty()) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"ModelAssetBundleResources must be retrieved with a non-empty tag.",
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
if (!ModelAssetBundleExists(tag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::Substitute(
"ModelAssetBundleResources with tag \"$0\" does not exist.", tag),
MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
}
return model_asset_bundle_resources_collection_.at(tag).get();
}
absl::StatusOr<api2::Packet<tflite::OpResolver>> absl::StatusOr<api2::Packet<tflite::OpResolver>>
ModelResourcesCache::GetGraphOpResolverPacket() const { ModelResourcesCache::GetGraphOpResolverPacket() const {
if (graph_op_resolver_packet_.IsEmpty()) { if (graph_op_resolver_packet_.IsEmpty()) {

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -46,6 +47,10 @@ class ModelResourcesCache {
// Returns whether the tag exists in the model resources cache. // Returns whether the tag exists in the model resources cache.
bool Exists(const std::string& tag) const; bool Exists(const std::string& tag) const;
// Returns whether the tag of the model asset bundle exists in the model
// resources cache.
bool ModelAssetBundleExists(const std::string& tag) const;
// Adds a ModelResources object into the cache. // Adds a ModelResources object into the cache.
// The tag of the ModelResources must be unique; the ownership of the // The tag of the ModelResources must be unique; the ownership of the
// ModelResources will be transferred into the cache. // ModelResources will be transferred into the cache.
@ -62,6 +67,23 @@ class ModelResourcesCache {
absl::StatusOr<const ModelResources*> GetModelResources( absl::StatusOr<const ModelResources*> GetModelResources(
const std::string& tag) const; const std::string& tag) const;
// Adds a ModelAssetBundleResources object into the cache.
// The tag of the ModelAssetBundleResources must be unique; the ownership of
// the ModelAssetBundleResources will be transferred into the cache.
absl::Status AddModelAssetBundleResources(
std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources);
// Adds a collection of ModelAssetBundleResources objects into the cache.
// The tag of each ModelAssetBundleResources must be unique; the ownership
// of every ModelAssetBundleResources will be transferred into the cache.
absl::Status AddModelAssetBundleResourcesCollection(
std::vector<std::unique_ptr<ModelAssetBundleResources>>&
model_asset_bundle_resources_collection);
// Retrieves a const ModelAssetBundleResources pointer by the unique tag.
absl::StatusOr<const ModelAssetBundleResources*> GetModelAssetBundleResources(
const std::string& tag) const;
// Retrieves the graph op resolver packet. // Retrieves the graph op resolver packet.
absl::StatusOr<api2::Packet<tflite::OpResolver>> GetGraphOpResolverPacket() absl::StatusOr<api2::Packet<tflite::OpResolver>> GetGraphOpResolverPacket()
const; const;
@ -73,6 +95,11 @@ class ModelResourcesCache {
// A collection of ModelResources objects for the models in the graph. // A collection of ModelResources objects for the models in the graph.
absl::flat_hash_map<std::string, std::unique_ptr<ModelResources>> absl::flat_hash_map<std::string, std::unique_ptr<ModelResources>>
model_resources_collection_; model_resources_collection_;
// A collection of ModelAssetBundleResources objects for the model bundles in
// the graph.
absl::flat_hash_map<std::string, std::unique_ptr<ModelAssetBundleResources>>
model_asset_bundle_resources_collection_;
}; };
// Global service for mediapipe task model resources cache. // Global service for mediapipe task model resources cache.
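A minimal sketch of the new model asset bundle flow declared above, written as if inside the mediapipe::tasks::core namespace and assuming the ModelAssetBundleResources::Create(tag, external_file) factory shown in the corresponding .cc hunk. The function name and tag are illustrative only.

absl::Status CacheBundleExample(
    ModelResourcesCache& cache,
    std::unique_ptr<proto::ExternalFile> bundle_file) {
  // Ownership of the bundle resources moves into the cache; the tag must be
  // unique and non-empty.
  ASSIGN_OR_RETURN(auto bundle,
                   ModelAssetBundleResources::Create("example_bundle_tag",
                                                     std::move(bundle_file)));
  MP_RETURN_IF_ERROR(cache.AddModelAssetBundleResources(std::move(bundle)));
  // Any later lookup with the same tag returns a const pointer to the cached
  // resources.
  ASSIGN_OR_RETURN(const ModelAssetBundleResources* cached,
                   cache.GetModelAssetBundleResources("example_bundle_tag"));
  (void)cached;
  return absl::OkStatus();
}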

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
@ -70,6 +71,17 @@ std::string CreateModelResourcesTag(const CalculatorGraphConfig::Node& node) {
node_type); node_type);
} }
std::string CreateModelAssetBundleResourcesTag(
const CalculatorGraphConfig::Node& node) {
std::vector<std::string> names = absl::StrSplit(node.name(), "__");
std::string node_type = node.calculator();
std::replace(node_type.begin(), node_type.end(), '.', '_');
absl::AsciiStrToLower(&node_type);
return absl::StrFormat("%s_%s_model_asset_bundle_resources",
names.back().empty() ? "unnamed" : names.back(),
node_type);
}
} // namespace } // namespace
// Defines the mediapipe task inference unit as a MediaPipe subgraph that // Defines the mediapipe task inference unit as a MediaPipe subgraph that
@ -168,6 +180,38 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
return model_resources_cache_service.GetObject().GetModelResources(tag); return model_resources_cache_service.GetObject().GetModelResources(tag);
} }
absl::StatusOr<const ModelAssetBundleResources*>
ModelTaskGraph::CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
bool has_file_pointer_meta = external_file->has_file_pointer_meta();
  // If the external file is set by a file pointer, there is no need to add the
  // model asset bundle resources into the model resources service, since the
  // memory is not owned by this model asset bundle resources object.
if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
ASSIGN_OR_RETURN(
local_model_asset_bundle_resources_,
ModelAssetBundleResources::Create("", std::move(external_file)));
if (!has_file_pointer_meta) {
LOG(WARNING)
<< "A local ModelResources object is created. Please consider using "
"ModelResourcesCacheService to cache the created ModelResources "
"object in the CalculatorGraph.";
}
return local_model_asset_bundle_resources_.get();
}
const std::string tag =
CreateModelAssetBundleResourcesTag(sc->OriginalNode());
ASSIGN_OR_RETURN(
auto model_bundle_resources,
ModelAssetBundleResources::Create(tag, std::move(external_file)));
MP_RETURN_IF_ERROR(
model_resources_cache_service.GetObject().AddModelAssetBundleResources(
std::move(model_bundle_resources)));
return model_resources_cache_service.GetObject().GetModelAssetBundleResources(
tag);
}
GenericNode& ModelTaskGraph::AddInference( GenericNode& ModelTaskGraph::AddInference(
const ModelResources& model_resources, const ModelResources& model_resources,
const proto::Acceleration& acceleration, Graph& graph) const { const proto::Acceleration& acceleration, Graph& graph) const {

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/subgraph.h" #include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
@ -78,6 +79,35 @@ class ModelTaskGraph : public Subgraph {
absl::StatusOr<const ModelResources*> CreateModelResources( absl::StatusOr<const ModelResources*> CreateModelResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file); SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
// If the model resources graph service is available, creates a model asset
// bundle resources object from the subgraph context, and caches the created
// model asset bundle resources into the model resources graph service on
// success. Otherwise, creates a local model asset bundle resources object
  // that can only be used in the graph construction stage. The returned model
  // asset bundle resources pointer will provide graph authors with access to
  // the extracted model files.
template <typename Options>
absl::StatusOr<const ModelAssetBundleResources*>
CreateModelAssetBundleResources(SubgraphContext* sc) {
auto external_file = std::make_unique<proto::ExternalFile>();
external_file->Swap(sc->MutableOptions<Options>()
->mutable_base_options()
->mutable_model_asset());
return CreateModelAssetBundleResources(sc, std::move(external_file));
}
// If the model resources graph service is available, creates a model asset
// bundle resources object from the subgraph context, and caches the created
// model asset bundle resources into the model resources graph service on
// success. Otherwise, creates a local model asset bundle resources object
// that can only be used in the graph construction stage. Note that the
// external file contents will be moved into the model asset bundle resources
// object on creation. The returned model asset bundle resources pointer will
  // provide graph authors with access to the extracted model files.
absl::StatusOr<const ModelAssetBundleResources*>
CreateModelAssetBundleResources(
SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);
// Inserts a mediapipe task inference subgraph into the provided // Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the // GraphBuilder. The returned node provides the following interfaces to the
// the rest of the graph: // the rest of the graph:
@ -95,6 +125,9 @@ class ModelTaskGraph : public Subgraph {
private: private:
std::unique_ptr<ModelResources> local_model_resources_; std::unique_ptr<ModelResources> local_model_resources_;
std::unique_ptr<ModelAssetBundleResources>
local_model_asset_bundle_resources_;
}; };
} // namespace core } // namespace core
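A hypothetical subgraph using the templated helper above; MyTaskGraph and MyTaskGraphOptions stand in for a real task graph and an options proto whose base_options carries the model asset bundle.

class MyTaskGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(SubgraphContext* sc) override {
    // Moves the model asset out of the options and caches the extracted bundle
    // in the ModelResourcesCacheService when that service is available.
    ASSIGN_OR_RETURN(const ModelAssetBundleResources* bundle_resources,
                     CreateModelAssetBundleResources<MyTaskGraphOptions>(sc));
    api2::builder::Graph graph;
    // ... use bundle_resources to configure the contained sub-task graphs ...
    return graph.GetConfig();
  }
};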

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include <string>
#include "absl/cleanup/cleanup.h" #include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "absl/status/status.h" #include "absl/status/status.h"
@ -162,12 +164,16 @@ absl::Status ExtractFilesfromZipFile(
return absl::OkStatus(); return absl::OkStatus();
} }
void SetExternalFile(const std::string_view& file_content, void SetExternalFile(const absl::string_view& file_content,
core::proto::ExternalFile* model_file) { core::proto::ExternalFile* model_file, bool is_copy) {
auto pointer = reinterpret_cast<uint64_t>(file_content.data()); if (is_copy) {
std::string str_content{file_content};
model_file->mutable_file_pointer_meta()->set_pointer(pointer); model_file->set_file_content(str_content);
model_file->mutable_file_pointer_meta()->set_length(file_content.length()); } else {
auto pointer = reinterpret_cast<uint64_t>(file_content.data());
model_file->mutable_file_pointer_meta()->set_pointer(pointer);
model_file->mutable_file_pointer_meta()->set_length(file_content.length());
}
} }
} // namespace metadata } // namespace metadata

View File

@ -35,10 +35,13 @@ absl::Status ExtractFilesfromZipFile(
const char* buffer_data, const size_t buffer_size, const char* buffer_data, const size_t buffer_size,
absl::flat_hash_map<std::string, absl::string_view>* files); absl::flat_hash_map<std::string, absl::string_view>* files);
// Sets file_pointer_meta in ExternalFile, which is a pointer to the location // Sets the ExternalFile object from file_content in memory. By default,
// of a file in memory, from file_content. // `is_copy=false`, which means `file_pointer_meta` in ExternalFile is set to
void SetExternalFile(const std::string_view& file_content, // a pointer to the location of the file in memory. Otherwise, if
                     core::proto::ExternalFile* model_file); // `is_copy=true`, the memory is copied into `file_content` in ExternalFile.
void SetExternalFile(const absl::string_view& file_content,
core::proto::ExternalFile* model_file,
bool is_copy = false);
} // namespace metadata } // namespace metadata
} // namespace tasks } // namespace tasks
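A short sketch of the two modes described above, written as if inside the mediapipe::tasks::metadata namespace. LoadBundleIntoMemory() is a hypothetical helper; the lifetime comments restate the contract rather than add to it.

void SetExternalFileExample() {
  std::string buffer = LoadBundleIntoMemory();  // hypothetical helper
  core::proto::ExternalFile by_pointer;
  // Default (is_copy=false): only a pointer/length pair is stored, so `buffer`
  // must outlive every use of `by_pointer`.
  SetExternalFile(buffer, &by_pointer);
  core::proto::ExternalFile by_copy;
  // is_copy=true: the bytes are copied into file_content, so `buffer` may be
  // freed afterwards.
  SetExternalFile(buffer, &by_copy, /*is_copy=*/true);
}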

View File

@ -0,0 +1,84 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "text_classifier_graph",
srcs = ["text_classifier_graph.cc"],
deps = [
"//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/tasks/cc/components:text_preprocessing_graph",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_resources_calculator",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_library(
name = "text_classifier",
srcs = ["text_classifier.cc"],
hdrs = ["text_classifier.h"],
deps = [
":text_classifier_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:builder",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:base_task_api",
"//mediapipe/tasks/cc/core:task_api_factory",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)
cc_library(
name = "text_classifier_test_utils",
srcs = ["text_classifier_test_utils.cc"],
hdrs = ["text_classifier_test_utils.h"],
visibility = ["//visibility:private"],
deps = [
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite:mutable_op_resolver",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite:type_to_tflitetype",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
],
)

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "text_classifier_graph_options_proto",
srcs = ["text_classifier_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe.tasks.text.text_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
option java_package = "com.google.mediapipe.tasks.text.textclassifier.proto";
option java_outer_classname = "TextClassifierGraphOptionsProto";
message TextClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional TextClassifierGraphOptions ext = 462704549;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
optional components.processors.proto.ClassifierOptions classifier_options = 2;
}
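For reference, a sketch of filling these options from C++ via the generated proto API, assuming the usual ExternalFile.file_name and ClassifierOptions.max_results fields and the proto namespace used elsewhere in this change; the path is illustrative.

proto::TextClassifierGraphOptions MakeOptionsExample() {
  proto::TextClassifierGraphOptions options;
  // Points base_options at a TFLite model with metadata.
  options.mutable_base_options()->mutable_model_asset()->set_file_name(
      "/path/to/model_with_metadata.tflite");
  // Limits the classifier to the top 3 categories per head.
  options.mutable_classifier_options()->set_max_results(3);
  return options;
}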

View File

@ -0,0 +1,104 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
// Creates a MediaPipe graph config that only contains a single subgraph node of
// type "TextClassifierGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<proto::TextClassifierGraphOptions> options) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kSubgraphTypeName);
subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >>
graph.Out(kClassificationResultTag);
return graph.GetConfig();
}
// Converts the user-facing TextClassifierOptions struct to the internal
// TextClassifierGraphOptions proto.
std::unique_ptr<proto::TextClassifierGraphOptions>
ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) {
auto options_proto = std::make_unique<proto::TextClassifierGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
auto classifier_options_proto =
std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto(
&(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get());
return options_proto;
}
} // namespace
absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
std::unique_ptr<TextClassifierOptions> options) {
auto options_proto = ConvertTextClassifierOptionsToProto(options.get());
return core::TaskApiFactory::Create<TextClassifier,
proto::TextClassifierGraphOptions>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver));
}
absl::StatusOr<ClassificationResult> TextClassifier::Classify(
absl::string_view text) {
ASSIGN_OR_RETURN(
auto output_packets,
runner_->Process(
{{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
return output_packets[kClassificationResultStreamName]
.Get<ClassificationResult>();
}
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,96 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
// Base options for configuring MediaPipe Tasks, such as specifying the model
// file with metadata, accelerator options, op resolver, etc.
tasks::core::BaseOptions base_options;
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
components::processors::ClassifierOptions classifier_options;
};
// Performs classification on text.
//
// This API expects a TFLite model with (optional) TFLite Model Metadata that
// contains the mandatory (described below) input tensors, output tensor,
// and the optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for
// models with int32 input tensors because it contains the input process unit
// for the model's Tokenizer. No metadata is required for models with string
// input tensors.
//
// Input tensors:
// (kTfLiteInt32)
// - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing
// the input ids, segment ids, and mask ids
// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
// input ids
// or (kTfLiteString)
// - 1 input tensor that is shapeless or has shape [1] containing the input
// string
// At least one output tensor with:
// (kTfLiteFloat32/kBool)
//    - `[1 x N]` array with `N` representing the number of categories.
// - optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS, containing one label per line. The first such
// AssociatedFile (if any) is used to fill the `category_name` field of the
// results. The `display_name` field is filled from the AssociatedFile (if
// any) whose locale matches the `display_names_locale` field of the
// `TextClassifierOptions` used at creation time ("en" by default, i.e.
// English). If none of these are available, only the `index` field of the
// results will be filled.
class TextClassifier : core::BaseTaskApi {
public:
using BaseTaskApi::BaseTaskApi;
// Creates a TextClassifier from the provided `options`.
static absl::StatusOr<std::unique_ptr<TextClassifier>> Create(
std::unique_ptr<TextClassifierOptions> options);
// Performs classification on the input `text`.
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
absl::string_view text);
// Shuts down the TextClassifier when all the work is done.
absl::Status Close() { return runner_->Close(); }
};
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
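A minimal usage sketch of the API declared above; the model path is illustrative and error handling simply propagates the status.

absl::Status RunTextClassifierExample() {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path =
      "/path/to/bert_text_classifier.tflite";
  ASSIGN_OR_RETURN(std::unique_ptr<TextClassifier> classifier,
                   TextClassifier::Create(std::move(options)));
  // Classify returns one ClassificationResult covering all output heads.
  ASSIGN_OR_RETURN(
      auto result,
      classifier->Classify("it's a charming and often affecting journey"));
  // ... inspect result.classifications() ...
  return classifier->Close();
}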

View File

@ -0,0 +1,162 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <string>
#include <type_traits>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";
} // namespace
// A "TextClassifierGraph" performs Natural Language classification (including
// BERT-based text classification).
// - Accepts input text and outputs classification results on CPU.
//
// Inputs:
// TEXT - std::string
// Input text to perform classification on.
//
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The aggregated classification result object that has 3 dimensions:
// (classification head, classification timestamp, classification category).
//
// Example:
// node {
// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
// input_stream: "TEXT:text_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// options {
// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
// {
// base_options {
// model_asset {
// file_name: "/path/to/model.tflite"
// }
// }
// }
// }
// }
class TextClassifierGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(
const ModelResources* model_resources,
CreateModelResources<proto::TextClassifierGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
Source<ClassificationResult> classification_result_out,
BuildTextClassifierTask(
sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
graph[Input<std::string>(kTextTag)], graph));
classification_result_out >>
graph[Output<ClassificationResult>(kClassificationResultTag)];
return graph.GetConfig();
}
private:
// Adds a mediapipe TextClassifier task graph into the provided
// builder::Graph instance. The TextClassifier task takes an input
// text (std::string) and returns one classification result per output head
// specified by the model.
//
// task_options: the mediapipe tasks TextClassifierGraphOptions proto.
// model_resources: the ModelResources object initialized from a
// TextClassifier model file with model metadata.
// text_in: (std::string) stream to run text classification on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
const proto::TextClassifierGraphOptions& task_options,
const ModelResources& model_resources, Source<std::string> text_in,
Graph& graph) {
// Adds preprocessing calculators and connects them to the text input
// stream.
auto& preprocessing =
graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
model_resources,
preprocessing.GetOptions<
tasks::components::proto::TextPreprocessingGraphOptions>()));
text_in >> preprocessing.In(kTextTag);
// Adds both InferenceCalculator and ModelResourcesCalculator.
auto& inference = AddInference(
model_resources, task_options.base_options().acceleration(), graph);
// The metadata extractor side-output comes from the
// ModelResourcesCalculator.
inference.SideOut(kMetadataExtractorTag) >>
preprocessing.SideIn(kMetadataExtractorTag);
preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);
// Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, task_options.classifier_options(),
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the aggregated classification result as the subgraph output
// stream.
return postprocessing[Output<ClassificationResult>(
kClassificationResultTag)];
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::text::text_classifier::TextClassifierGraph);
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,238 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
#include <cmath>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {
using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::proto::Approximately;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;
constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
constexpr char kTestRegexModelPath[] =
"test_model_text_classifier_with_regex_tokenizer.tflite";
constexpr char kStringToBoolModelPath[] =
"test_model_text_classifier_bool_output.tflite";
std::string GetFullPath(absl::string_view file_name) {
return JoinPath("./", kTestDataDirectory, file_name);
}
class TextClassifierTest : public tflite_shims::testing::Test {};
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}
TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) {
auto options = std::make_unique<TextClassifierOptions>();
StatusOr<std::unique_ptr<TextClassifier>> classifier =
TextClassifier::Create(std::move(options));
EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
classifier.status().message(),
HasSubstr("ExternalFile must specify at least one of 'file_content', "
"'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(TextClassifierTest, CreateFailsWithMissingModel) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
StatusOr<std::unique_ptr<TextClassifier>> classifier =
TextClassifier::Create(std::move(options));
EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound);
EXPECT_THAT(classifier.status().message(),
HasSubstr("Unable to open file at"));
EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
Optional(absl::Cord(absl::StrCat(
MediaPipeTasksStatus::kRunnerInitializationError))));
}
TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}
TEST_F(TextClassifierTest, TextClassifierWithBert) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult negative_result,
classifier->Classify("unflinchingly bleak and desperate"));
ASSERT_THAT(negative_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.956 }
categories { category_name: "positive" score: 0.044 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult positive_result,
classifier->Classify("it's a charming and often affecting journey"));
ASSERT_THAT(positive_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.0 }
categories { category_name: "positive" score: 1.0 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
classifier->Classify("What a waste of my time."));
ASSERT_THAT(negative_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "Negative" score: 0.813 }
categories { category_name: "Positive" score: 0.187 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK_AND_ASSIGN(
ClassificationResult positive_result,
classifier->Classify("This is the best movie Ive seen in recent years. "
"Strongly recommend it!"));
ASSERT_THAT(positive_result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "Negative" score: 0.487 }
categories { category_name: "Positive" score: 0.513 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
options->base_options.op_resolver = CreateCustomResolver();
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
classifier->Classify("hello"));
ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
classifications {
entries {
categories { index: 1 score: 1 }
categories { index: 0 score: 1 }
categories { index: 2 score: 0 }
}
}
)pb"))));
}
TEST_F(TextClassifierTest, BertLongPositive) {
std::stringstream ss_for_positive_review;
ss_for_positive_review
<< "it's a charming and often affecting journey and this is a long";
for (int i = 0; i < kMaxSeqLen; ++i) {
ss_for_positive_review << " long";
}
ss_for_positive_review << " movie review";
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
TextClassifier::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
classifier->Classify(ss_for_positive_review.str()));
ASSERT_THAT(result,
Partially(IgnoringRepeatedFieldOrdering(Approximately(
EqualsProto(R"pb(
classifications {
entries {
categories { category_name: "negative" score: 0.014 }
categories { category_name: "positive" score: 0.986 }
}
}
)pb"),
kEpsilon))));
MP_ASSERT_OK(classifier->Close());
}
} // namespace
} // namespace text_classifier
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,131 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include <cstring>
#include <iterator>
#include <memory>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe {
namespace tasks {
namespace text {
namespace {
using ::mediapipe::tasks::CreateStatusWithPayload;
using ::tflite::GetInput;
using ::tflite::GetOutput;
using ::tflite::GetString;
using ::tflite::StringRef;
constexpr absl::string_view kInputStr = "hello";
constexpr bool kBooleanData[] = {true, true, false};
constexpr size_t kBooleanDataSize = std::size(kBooleanData);
// Checks the data type of a tensor and returns a typed pointer to its data;
// fails if the tensor type is not T.
template <typename T>
absl::StatusOr<T*> AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
if (!tensor->data.raw) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat("Tensor (%s) has no raw data.", tensor->name));
}
// Checks if the data type of the tensor is T and returns the pointer cast to
// T if applicable; otherwise returns an error status.
// See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
if (tensor->type == tflite::typeToTfLiteType<T>()) {
return reinterpret_cast<T*>(tensor->data.raw);
}
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.",
tensor->name, tflite::typeToTfLiteType<T>(),
tensor->type));
}
// Populates tensor with array of data, fails if data type doesn't match tensor
// type or they don't have the same number of elements.
template <typename T, typename = std::enable_if_t<
std::negation_v<std::is_same<T, std::string>>>>
absl::Status PopulateTensor(const T* data, int num_elements,
TfLiteTensor* tensor) {
ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor<T>(tensor));
size_t bytes = num_elements * sizeof(T);
if (tensor->bytes != bytes) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes,
bytes));
}
std::memcpy(v, data, bytes);
return absl::OkStatus();
}
TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
dims->data[0] = kBooleanDataSize;
return context->ResizeTensor(context, output, dims);
}
TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input_tensor = GetInput(context, node, 0);
TF_LITE_ENSURE(context, input_tensor != nullptr);
StringRef input_str_ref = GetString(input_tensor, 0);
std::string input_str(input_str_ref.str, input_str_ref.len);
if (input_str != kInputStr) {
return kTfLiteError;
}
TfLiteTensor* output = GetOutput(context, node, 0);
TF_LITE_ENSURE(context, output != nullptr);
TF_LITE_ENSURE(context, PopulateTensor(kBooleanData, 3, output).ok());
return kTfLiteOk;
}
// This custom op takes a string tensor as input and outputs a bool tensor with
// the value {true, true, false}. It is used to mimic a real text classification
// model that classifies a string into scores of different categories.
TfLiteRegistration* RegisterStringToBool() {
// Dummy implementation of custom OP
// This op takes string as input and outputs bool[]
static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr,
/* prepare= */ PrepareStringToBool,
/* invoke= */ InvokeStringToBool};
return &r;
}
} // namespace
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver() {
tflite::MutableOpResolver resolver;
resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool());
return std::make_unique<tflite::MutableOpResolver>(resolver);
}
} // namespace text
} // namespace tasks
} // namespace mediapipe

View File

@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
#include <memory>
#include "tensorflow/lite/mutable_op_resolver.h"
namespace mediapipe {
namespace tasks {
namespace text {
// Creates a custom MutableOpResolver that provides custom OP implementations
// to mimic classification behavior.
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver();
} // namespace text
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
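As a usage note (not part of the diff): the resolver returned by CreateCustomResolver is meant to be handed to a TextClassifier through its base options, exactly as the TextClassifierWithStringToBool test above does. A minimal sketch, assuming the text_classifier.h header path and using a placeholder model path:

#include <memory>
#include <utility>

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"

// Sketch only: wires the custom CUSTOM_OP_STRING_TO_BOOLS resolver into a
// TextClassifier. The model path below is a placeholder for the bool-output
// test model shipped under mediapipe/tasks/testdata/text/.
void ClassifyWithCustomResolver() {
  using ::mediapipe::tasks::text::text_classifier::TextClassifier;
  using ::mediapipe::tasks::text::text_classifier::TextClassifierOptions;
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path =
      "mediapipe/tasks/testdata/text/test_model_text_classifier_bool_output.tflite";
  options->base_options.op_resolver =
      mediapipe::tasks::text::CreateCustomResolver();
  auto classifier = TextClassifier::Create(std::move(options));
  if (classifier.ok()) {
    auto result = (*classifier)->Classify("hello");  // Only "hello" succeeds.
  }
}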

View File

@ -56,6 +56,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
+       "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
@ -91,6 +92,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
+       "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_task_graph",
@ -123,6 +125,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
+       "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers:gesture_recognition_result",

View File

@ -69,6 +69,7 @@ cc_library(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
+       "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
@ -86,6 +87,7 @@ cc_test(
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
+       "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "@com_google_absl//absl/strings",

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
+#include <cmath>
#include <limits>
#include <memory>
#include <string>
@ -26,6 +27,7 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
+#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"
@ -38,6 +40,7 @@ namespace {
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
+constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX";
constexpr int kFeaturesPerLandmark = 3;
@ -62,6 +65,25 @@ absl::StatusOr<LandmarkListT> NormalizeLandmarkAspectRatio(
  return normalized_landmarks;
}

+template <class LandmarkListT>
+absl::StatusOr<LandmarkListT> RotateLandmarks(const LandmarkListT& landmarks,
+                                              float rotation) {
+  float cos = std::cos(rotation);
+  // Negate because Y-axis points down and not up.
+  float sin = std::sin(-rotation);
+  LandmarkListT rotated_landmarks;
+  for (int i = 0; i < landmarks.landmark_size(); ++i) {
+    const auto& old_landmark = landmarks.landmark(i);
+    float x = old_landmark.x() - 0.5;
+    float y = old_landmark.y() - 0.5;
+    auto* new_landmark = rotated_landmarks.add_landmark();
+    new_landmark->set_x(x * cos - y * sin + 0.5);
+    new_landmark->set_y(y * cos + x * sin + 0.5);
+    new_landmark->set_z(old_landmark.z());
+  }
+  return rotated_landmarks;
+}
+
template <class LandmarkListT>
absl::StatusOr<LandmarkListT> NormalizeObject(const LandmarkListT& landmarks,
                                              int origin_offset) {
@ -134,6 +156,13 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) {
                     NormalizeLandmarkAspectRatio(landmarks, width, height));
  }

+  if (cc->Inputs().HasTag(kNormRectTag)) {
+    RET_CHECK(!cc->Inputs().Tag(kNormRectTag).IsEmpty());
+    const auto rotation =
+        cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>().rotation();
+    ASSIGN_OR_RETURN(landmarks, RotateLandmarks(landmarks, rotation));
+  }
+
  const auto& options = cc->Options<LandmarksToMatrixCalculatorOptions>();
  if (options.object_normalization()) {
    ASSIGN_OR_RETURN(
@ -163,6 +192,8 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) {
// WORLD_LANDMARKS - World 3d landmarks of one object. Use *either*
//   LANDMARKS or WORLD_LANDMARKS.
// IMAGE_SIZE - (width, height) of the image
+// NORM_RECT - Optional NormalizedRect object whose 'rotation' field is used
+//   to rotate the landmarks.
// Output:
//   LANDMARKS_MATRIX - Matrix for the landmarks.
//
@ -185,6 +216,7 @@ class LandmarksToMatrixCalculator : public CalculatorBase {
    cc->Inputs().Tag(kLandmarksTag).Set<NormalizedLandmarkList>().Optional();
    cc->Inputs().Tag(kWorldLandmarksTag).Set<LandmarkList>().Optional();
    cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>().Optional();
+   cc->Inputs().Tag(kNormRectTag).Set<NormalizedRect>().Optional();
    cc->Outputs().Tag(kLandmarksMatrixTag).Set<Matrix>();
    return absl::OkStatus();
  }
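As a quick sanity check on the rotation math added above (not part of the diff): RotateLandmarks rotates each normalized landmark about the image center (0.5, 0.5) and negates the angle because the image Y-axis points down. A standalone sketch, with the sample point chosen purely for illustration:

#include <cmath>
#include <cstdio>

// Applies the same formula as RotateLandmarks to a single point.
int main() {
  const float rotation = M_PI / 2;  // Same angle as the TestWithRotation case.
  const float cos_r = std::cos(rotation);
  const float sin_r = std::sin(-rotation);  // Negated: image Y points down.
  // A landmark on the right edge of the image, level with the center.
  const float x = 1.0f - 0.5f;
  const float y = 0.5f - 0.5f;
  const float rotated_x = x * cos_r - y * sin_r + 0.5f;  // -> 0.5
  const float rotated_y = y * cos_r + x * sin_r + 0.5f;  // -> 0.0
  std::printf("(1.0, 0.5) -> (%.1f, %.1f)\n", rotated_x, rotated_y);
  return 0;
}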

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

+#include <cmath>
#include <memory>
#include <string>
#include <utility>
@ -23,6 +24,7 @@ limitations under the License.
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
+#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
@ -35,6 +37,7 @@ constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX";
+constexpr char kNormRectTag[] = "NORM_RECT";

template <class LandmarkListT>
LandmarkListT BuildPseudoLandmarks(int num_landmarks, int offset = 0) {
@ -54,6 +57,7 @@ struct Landmarks2dToMatrixCalculatorTestCase {
  int object_normalization_origin_offset = -1;
  float expected_cell_0_2;
  float expected_cell_1_5;
+  float rotation;
};

using Landmarks2dToMatrixCalculatorTest =
@ -68,6 +72,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) {
        calculator: "LandmarksToMatrixCalculator"
        input_stream: "LANDMARKS:landmarks"
        input_stream: "IMAGE_SIZE:image_size"
+       input_stream: "NORM_RECT:norm_rect"
        output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
        options {
          [mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
@ -91,6 +96,11 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) {
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(Adopt(image_size.release()).At(Timestamp(0)));
+  auto norm_rect = std::make_unique<NormalizedRect>();
+  norm_rect->set_rotation(test_case.rotation);
+  runner.MutableInputs()
+      ->Tag(kNormRectTag)
+      .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@ -109,12 +119,20 @@ INSTANTIATE_TEST_CASE_P(
             .base_offset = 0,
             .object_normalization_origin_offset = 0,
             .expected_cell_0_2 = 0.1f,
-            .expected_cell_1_5 = 0.1875f},
+            .expected_cell_1_5 = 0.1875f,
+            .rotation = 0},
            {.test_name = "TestWithOffset21",
             .base_offset = 21,
             .object_normalization_origin_offset = 0,
             .expected_cell_0_2 = 0.1f,
-            .expected_cell_1_5 = 0.1875f}}),
+            .expected_cell_1_5 = 0.1875f,
+            .rotation = 0},
+           {.test_name = "TestWithRotation",
+            .base_offset = 0,
+            .object_normalization_origin_offset = 0,
+            .expected_cell_0_2 = 0.075f,
+            .expected_cell_1_5 = -0.25f,
+            .rotation = M_PI / 2.0}}),
    [](const testing::TestParamInfo<
        Landmarks2dToMatrixCalculatorTest::ParamType>& info) {
      return info.param.test_name;
@ -126,6 +144,7 @@ struct LandmarksWorld3dToMatrixCalculatorTestCase {
  int object_normalization_origin_offset = -1;
  float expected_cell_0_2;
  float expected_cell_1_5;
+  float rotation;
};

using LandmarksWorld3dToMatrixCalculatorTest =
@ -140,6 +159,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) {
        calculator: "LandmarksToMatrixCalculator"
        input_stream: "WORLD_LANDMARKS:landmarks"
        input_stream: "IMAGE_SIZE:image_size"
+       input_stream: "NORM_RECT:norm_rect"
        output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
        options {
          [mediapipe.LandmarksToMatrixCalculatorOptions.ext] {
@ -162,6 +182,11 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) {
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(Adopt(image_size.release()).At(Timestamp(0)));
+  auto norm_rect = std::make_unique<NormalizedRect>();
+  norm_rect->set_rotation(test_case.rotation);
+  runner.MutableInputs()
+      ->Tag(kNormRectTag)
+      .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@ -180,17 +205,26 @@ INSTANTIATE_TEST_CASE_P(
             .base_offset = 0,
             .object_normalization_origin_offset = 0,
             .expected_cell_0_2 = 0.1f,
-            .expected_cell_1_5 = 0.25},
+            .expected_cell_1_5 = 0.25,
+            .rotation = 0},
            {.test_name = "TestWithOffset21",
             .base_offset = 21,
             .object_normalization_origin_offset = 0,
             .expected_cell_0_2 = 0.1f,
-            .expected_cell_1_5 = 0.25},
+            .expected_cell_1_5 = 0.25,
+            .rotation = 0},
            {.test_name = "NoObjectNormalization",
             .base_offset = 0,
             .object_normalization_origin_offset = -1,
             .expected_cell_0_2 = 0.021f,
-            .expected_cell_1_5 = 0.052f}}),
+            .expected_cell_1_5 = 0.052f,
+            .rotation = 0},
+           {.test_name = "TestWithRotation",
+            .base_offset = 0,
+            .object_normalization_origin_offset = 0,
+            .expected_cell_0_2 = 0.1f,
+            .expected_cell_1_5 = -0.25f,
+            .rotation = M_PI / 2.0}}),
    [](const testing::TestParamInfo<
        LandmarksWorld3dToMatrixCalculatorTest::ParamType>& info) {
      return info.param.test_name;

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include <type_traits>
+#include <utility>
#include <vector>

#include "absl/memory/memory.h"
@ -27,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
@ -62,6 +64,8 @@ constexpr char kHandGestureSubgraphTypeName[] =
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
+constexpr char kNormRectTag[] = "NORM_RECT";
+constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandGesturesStreamName[] = "hand_gestures";
constexpr char kHandednessTag[] = "HANDEDNESS";
@ -72,6 +76,31 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";

constexpr int kMicroSecondsPerMilliSecond = 1000;

+// Returns a NormalizedRect filling the whole image. If input is present, its
+// rotation is set in the returned NormalizedRect and a check is performed to
+// make sure no region-of-interest was provided. Otherwise, rotation is set to
+// 0.
+absl::StatusOr<NormalizedRect> FillNormalizedRect(
+    std::optional<NormalizedRect> normalized_rect) {
+  NormalizedRect result;
+  if (normalized_rect.has_value()) {
+    result = *normalized_rect;
+  }
+  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
+                         result.has_width() || result.has_height();
+  if (has_coordinates) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        "GestureRecognizer does not support region-of-interest.",
+        MediaPipeTasksStatus::kInvalidArgumentError);
+  }
+  result.set_x_center(0.5);
+  result.set_y_center(0.5);
+  result.set_width(1);
+  result.set_height(1);
+  return result;
+}
+
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running
// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the
@ -83,6 +112,7 @@ CalculatorGraphConfig CreateGraphConfig(
  auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName);
  subgraph.GetOptions<GestureRecognizerGraphOptionsProto>().Swap(options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
+  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >>
      graph.Out(kHandGesturesTag);
  subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>
@ -93,10 +123,11 @@ CalculatorGraphConfig CreateGraphConfig(
      graph.Out(kHandWorldLandmarksTag);
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
  if (enable_flow_limiting) {
-    return tasks::core::AddFlowLimiterCalculator(graph, subgraph, {kImageTag},
-                                                 kHandGesturesTag);
+    return tasks::core::AddFlowLimiterCalculator(
+        graph, subgraph, {kImageTag, kNormRectTag}, kHandGesturesTag);
  }
  graph.In(kImageTag) >> subgraph.In(kImageTag);
+  graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
  return graph.GetConfig();
}
@ -216,16 +247,22 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
}

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
-    mediapipe::Image image) {
+    mediapipe::Image image,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
-  ASSIGN_OR_RETURN(auto output_packets,
-                   ProcessImageData({{kImageInStreamName,
-                                      MakePacket<Image>(std::move(image))}}));
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
+  ASSIGN_OR_RETURN(
+      auto output_packets,
+      ProcessImageData(
+          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
+           {kNormRectStreamName,
+            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  if (output_packets[kHandGesturesStreamName].IsEmpty()) {
    return {{{}, {}, {}, {}}};
  }
@ -245,18 +282,24 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
}

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
-    mediapipe::Image image, int64 timestamp_ms) {
+    mediapipe::Image image, int64 timestamp_ms,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
+                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+           {kNormRectStreamName,
+            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  if (output_packets[kHandGesturesStreamName].IsEmpty()) {
    return {{{}, {}, {}, {}}};
@ -276,17 +319,23 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
  };
}

-absl::Status GestureRecognizer::RecognizeAsync(mediapipe::Image image,
-                                               int64 timestamp_ms) {
+absl::Status GestureRecognizer::RecognizeAsync(
+    mediapipe::Image image, int64 timestamp_ms,
+    std::optional<mediapipe::NormalizedRect> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
+  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
+                   FillNormalizedRect(image_processing_options));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
+            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
+       {kNormRectStreamName,
+        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}

View File

@ -17,11 +17,13 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_

#include <memory>
+#include <optional>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
@ -57,7 +59,7 @@ struct GestureRecognizerOptions {
  int num_hands = 1;

  // The minimum confidence score for the hand detection to be considered
-  // successfully.
+  // successful.
  float min_hand_detection_confidence = 0.5;

  // The minimum confidence score of hand presence score in the hand landmark
@ -65,11 +67,11 @@ struct GestureRecognizerOptions {
  float min_hand_presence_confidence = 0.5;

  // The minimum confidence score for the hand tracking to be considered
-  // successfully.
+  // successful.
  float min_tracking_confidence = 0.5;

  // The minimum confidence score for the gestures to be considered
-  // successfully. If < 0, the gesture confidence thresholds in the model
+  // successful. If < 0, the gesture confidence thresholds in the model
  // metadata are used.
  // TODO Note this option is subject to change, after scoring
  // merging calculator is implemented.
@ -93,6 +95,13 @@ struct GestureRecognizerOptions {
// Inputs:
//   Image
//     - The image that gesture recognition runs on.
+//   std::optional<NormalizedRect>
+//     - If provided, can be used to specify the rotation to apply to the image
+//       before performing gesture recognition, by setting its 'rotation' field
+//       in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note
+//       that specifying a region-of-interest using the 'x_center', 'y_center',
+//       'width' and 'height' fields is NOT supported and will result in an
+//       invalid argument error being returned.
// Outputs:
//   GestureRecognitionResult
//     - The hand gesture recognition results.
@ -122,12 +131,23 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  //
  // image - mediapipe::Image
  //   Image to perform hand gesture recognition on.
+  // imageProcessingOptions - std::optional<NormalizedRect>
+  //   If provided, can be used to specify the rotation to apply to the image
+  //   before performing classification, by setting its 'rotation' field in
+  //   radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that
+  //   specifying a region-of-interest using the 'x_center', 'y_center', 'width'
+  //   and 'height' fields is NOT supported and will result in an invalid
+  //   argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describes how the input image will be preprocessed
  // after the yuv support is implemented.
+  // TODO: use an ImageProcessingOptions struct instead of
+  // NormalizedRect.
  absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
-      Image image);
+      Image image,
+      std::optional<mediapipe::NormalizedRect> image_processing_options =
+          std::nullopt);

  // Performs gesture recognition on the provided video frame.
  // Only use this method when the GestureRecognizer is created with the video
@ -137,7 +157,9 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  absl::StatusOr<components::containers::GestureRecognitionResult>
-  RecognizeForVideo(Image image, int64 timestamp_ms);
+  RecognizeForVideo(Image image, int64 timestamp_ms,
+                    std::optional<mediapipe::NormalizedRect>
+                        image_processing_options = std::nullopt);

  // Sends live image data to perform gesture recognition, and the results will
  // be available via the "result_callback" provided in the
@ -157,7 +179,9 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  //   longer be valid when the callback returns. To access the image data
  //   outside of the callback, callers need to make a copy of the image.
  //   - The input timestamp in milliseconds.
-  absl::Status RecognizeAsync(Image image, int64 timestamp_ms);
+  absl::Status RecognizeAsync(Image image, int64 timestamp_ms,
+                              std::optional<mediapipe::NormalizedRect>
+                                  image_processing_options = std::nullopt);

  // Shuts down the GestureRecognizer when all works are done.
  absl::Status Close() { return runner_->Close(); }
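For illustration (not part of the diff), a minimal sketch of the new image_processing_options parameter documented above; the template parameter stands in for the GestureRecognizer type so its exact namespace is not assumed:

#include <cmath>
#include <optional>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"

// Rotates the input 90° anti-clockwise before recognition. Only the
// 'rotation' field may be set; populating the region-of-interest fields makes
// Recognize return an invalid-argument error, as documented above.
template <typename GestureRecognizerT>
absl::Status RecognizeRotated(GestureRecognizerT& recognizer,
                              mediapipe::Image image) {
  mediapipe::NormalizedRect image_processing_options;
  image_processing_options.set_rotation(M_PI / 2);
  return recognizer.Recognize(std::move(image), image_processing_options)
      .status();
}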

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
+#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
@ -53,6 +54,7 @@ using ::mediapipe::tasks::vision::hand_landmarker::proto::
    HandLandmarkerGraphOptions;

constexpr char kImageTag[] = "IMAGE";
+constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandednessTag[] = "HANDEDNESS";
@ -76,6 +78,9 @@ struct GestureRecognizerOutputs {
// Inputs:
//   IMAGE - Image
//     Image to perform hand gesture recognition on.
+//   NORM_RECT - NormalizedRect
+//     Describes image rotation and region of image to perform landmarks
+//     detection on.
//
// Outputs:
//   HAND_GESTURES - std::vector<ClassificationList>
@ -93,13 +98,15 @@ struct GestureRecognizerOutputs {
//   IMAGE - mediapipe::Image
//     The image that gesture recognizer runs on and has the pixel data stored
//     on the target storage (CPU vs GPU).
+// All returned coordinates are in the unrotated and uncropped input image
+// coordinates system.
//
// Example:
// node {
//   calculator:
//     "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
//   input_stream: "IMAGE:image_in"
+//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   output_stream: "LANDMARKS:hand_landmarks"
//   output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
@ -132,7 +139,8 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
    ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
                     BuildGestureRecognizerGraph(
                         *sc->MutableOptions<GestureRecognizerGraphOptions>(),
-                        graph[Input<Image>(kImageTag)], graph));
+                        graph[Input<Image>(kImageTag)],
+                        graph[Input<NormalizedRect>(kNormRectTag)], graph));
    hand_gesture_recognition_output.gesture >>
        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
    hand_gesture_recognition_output.handedness >>
@ -148,7 +156,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
 private:
  absl::StatusOr<GestureRecognizerOutputs> BuildGestureRecognizerGraph(
      GestureRecognizerGraphOptions& graph_options, Source<Image> image_in,
-      Graph& graph) {
+      Source<NormalizedRect> norm_rect_in, Graph& graph) {
    auto& image_property = graph.AddNode("ImagePropertiesCalculator");
    image_in >> image_property.In("IMAGE");
    auto image_size = image_property.Out("SIZE");
@ -162,6 +170,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
        graph_options.mutable_hand_landmarker_graph_options());
    image_in >> hand_landmarker_graph.In(kImageTag);
+    norm_rect_in >> hand_landmarker_graph.In(kNormRectTag);
    auto hand_landmarks =
        hand_landmarker_graph[Output<std::vector<NormalizedLandmarkList>>(
            kLandmarksTag)];
@ -187,6 +196,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
    hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag);
    handedness >> hand_gesture_subgraph.In(kHandednessTag);
    image_size >> hand_gesture_subgraph.In(kImageSizeTag);
+    norm_rect_in >> hand_gesture_subgraph.In(kNormRectTag);
    hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag);
    auto hand_gestures =
        hand_gesture_subgraph[Output<std::vector<ClassificationList>>(

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
+#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
@ -57,6 +58,7 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
+constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX";
@ -92,6 +94,9 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
//     Detected hand landmarks in world coordinates.
//   IMAGE_SIZE - std::pair<int, int>
//     The size of image from which the landmarks detected from.
+//   NORM_RECT - NormalizedRect
+//     NormalizedRect whose 'rotation' field is used to rotate the
+//     landmarks before processing them.
//
// Outputs:
//   HAND_GESTURES - ClassificationList
@ -106,6 +111,7 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
//   input_stream: "LANDMARKS:landmarks"
//   input_stream: "WORLD_LANDMARKS:world_landmarks"
//   input_stream: "IMAGE_SIZE:image_size"
+//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   options {
//     [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext]
@ -133,7 +139,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
            graph[Input<ClassificationList>(kHandednessTag)],
            graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
            graph[Input<LandmarkList>(kWorldLandmarksTag)],
-           graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
+           graph[Input<std::pair<int, int>>(kImageSizeTag)],
+           graph[Input<NormalizedRect>(kNormRectTag)], graph));
    hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
    return graph.GetConfig();
  }
@ -145,7 +152,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
      Source<ClassificationList> handedness,
      Source<NormalizedLandmarkList> hand_landmarks,
      Source<LandmarkList> hand_world_landmarks,
-      Source<std::pair<int, int>> image_size, Graph& graph) {
+      Source<std::pair<int, int>> image_size, Source<NormalizedRect> norm_rect,
+      Graph& graph) {
    // Converts the ClassificationList to a matrix.
    auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator");
    handedness >> handedness_to_matrix.In(kHandednessTag);
@ -166,6 +174,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
        landmarks_options;
    hand_landmarks >> hand_landmarks_to_matrix.In(kLandmarksTag);
    image_size >> hand_landmarks_to_matrix.In(kImageSizeTag);
+    norm_rect >> hand_landmarks_to_matrix.In(kNormRectTag);
    auto hand_landmarks_matrix =
        hand_landmarks_to_matrix[Output<Matrix>(kLandmarksMatrixTag)];
@ -181,6 +190,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    hand_world_landmarks >>
        hand_world_landmarks_to_matrix.In(kWorldLandmarksTag);
    image_size >> hand_world_landmarks_to_matrix.In(kImageSizeTag);
+    norm_rect >> hand_world_landmarks_to_matrix.In(kNormRectTag);
    auto hand_world_landmarks_matrix =
        hand_world_landmarks_to_matrix[Output<Matrix>(kLandmarksMatrixTag)];
@ -239,6 +249,9 @@ REGISTER_MEDIAPIPE_GRAPH(
//     A vector hand landmarks in world coordinates.
//   IMAGE_SIZE - std::pair<int, int>
//     The size of image from which the landmarks detected from.
+//   NORM_RECT - NormalizedRect
+//     NormalizedRect whose 'rotation' field is used to rotate the
+//     landmarks before processing them.
//   HAND_TRACKING_IDS - std::vector<int>
//     A vector of the tracking ids of the hands. The tracking id is the vector
//     index corresponding to the same hand if the graph runs multiple times.
@ -257,6 +270,7 @@ REGISTER_MEDIAPIPE_GRAPH(
//   input_stream: "LANDMARKS:landmarks"
//   input_stream: "WORLD_LANDMARKS:world_landmarks"
//   input_stream: "IMAGE_SIZE:image_size"
+//   input_stream: "NORM_RECT:norm_rect"
//   input_stream: "HAND_TRACKING_IDS:hand_tracking_ids"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   options {
@ -283,6 +297,7 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
            graph[Input<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
            graph[Input<std::vector<LandmarkList>>(kWorldLandmarksTag)],
            graph[Input<std::pair<int, int>>(kImageSizeTag)],
+           graph[Input<NormalizedRect>(kNormRectTag)],
            graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
    multi_hand_gestures >>
        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
@ -296,18 +311,20 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
      Source<std::vector<ClassificationList>> multi_handedness,
      Source<std::vector<NormalizedLandmarkList>> multi_hand_landmarks,
      Source<std::vector<LandmarkList>> multi_hand_world_landmarks,
-      Source<std::pair<int, int>> image_size,
+      Source<std::pair<int, int>> image_size, Source<NormalizedRect> norm_rect,
      Source<std::vector<int>> multi_hand_tracking_ids, Graph& graph) {
    auto& begin_loop_int = graph.AddNode("BeginLoopIntCalculator");
    image_size >> begin_loop_int.In(kCloneTag)[0];
-    multi_handedness >> begin_loop_int.In(kCloneTag)[1];
-    multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[2];
-    multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[3];
+    norm_rect >> begin_loop_int.In(kCloneTag)[1];
+    multi_handedness >> begin_loop_int.In(kCloneTag)[2];
+    multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[3];
+    multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[4];
    multi_hand_tracking_ids >> begin_loop_int.In(kIterableTag);
    auto image_size_clone = begin_loop_int.Out(kCloneTag)[0];
-    auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[1];
-    auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[2];
-    auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[3];
+    auto norm_rect_clone = begin_loop_int.Out(kCloneTag)[1];
+    auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[2];
+    auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[3];
+    auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[4];
    auto hand_tracking_id = begin_loop_int.Out(kItemTag);
    auto batch_end = begin_loop_int.Out(kBatchEndTag);
@ -341,6 +358,7 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    hand_world_landmarks >>
        hand_gesture_recognizer_graph.In(kWorldLandmarksTag);
    image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
+    norm_rect_clone >> hand_gesture_recognizer_graph.In(kNormRectTag);
    auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);

    auto& end_loop_classification_lists =

View File

@@ -32,7 +32,7 @@ cc_library(
"//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
- "//mediapipe/calculators/util:detection_letterbox_removal_calculator",
"//mediapipe/calculators/util:detection_projection_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
"//mediapipe/calculators/util:non_max_suppression_calculator",
View File
@@ -58,6 +58,7 @@ using ::mediapipe::tasks::vision::hand_detector::proto::
HandDetectorGraphOptions;
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kHandRectsTag[] = "HAND_RECTS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
@@ -148,6 +149,9 @@ void ConfigureRectTransformationCalculator(
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect
// Describes image rotation and region of image to perform detection
// on.
//
// Outputs:
// PALM_DETECTIONS - std::vector<Detection>
@@ -159,11 +163,14 @@ void ConfigureRectTransformationCalculator(
// IMAGE - Image
// The input image that the hand detector runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
// All returned coordinates are in the unrotated and uncropped input image
// coordinates system.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph"
// input_stream: "IMAGE:image"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "PALM_DETECTIONS:palm_detections"
// output_stream: "HAND_RECTS:hand_rects_from_palm_detections"
// output_stream: "PALM_RECTS:palm_rects"
@@ -189,11 +196,11 @@ class HandDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<HandDetectorGraphOptions>(sc));
Graph graph;
- ASSIGN_OR_RETURN(
- auto hand_detection_outs,
- BuildHandDetectionSubgraph(sc->Options<HandDetectorGraphOptions>(),
- *model_resources,
- graph[Input<Image>(kImageTag)], graph));
ASSIGN_OR_RETURN(auto hand_detection_outs,
BuildHandDetectionSubgraph(
sc->Options<HandDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
hand_detection_outs.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_detection_outs.hand_rects >>
@@ -216,7 +223,7 @@
absl::StatusOr<HandDetectionOuts> BuildHandDetectionSubgraph(
const HandDetectorGraphOptions& subgraph_options,
const core::ModelResources& model_resources, Source<Image> image_in,
- Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Add image preprocessing subgraph. The model expects aspect ratio
// unchanged.
auto& preprocessing =
@@ -233,8 +240,9 @@
&preprocessing
.GetOptions<tasks::components::ImagePreprocessingOptions>()));
image_in >> preprocessing.In("IMAGE");
norm_rect_in >> preprocessing.In("NORM_RECT");
auto preprocessed_tensors = preprocessing.Out("TENSORS");
- auto letterbox_padding = preprocessing.Out("LETTERBOX_PADDING");
auto matrix = preprocessing.Out("MATRIX");
auto image_size = preprocessing.Out("IMAGE_SIZE");
// Adds SSD palm detection model.
@@ -278,17 +286,12 @@
nms_detections >> detection_label_id_to_text.In("");
auto detections_with_text = detection_label_id_to_text.Out("");
- // Adjusts detection locations (already normalized to [0.f, 1.f]) on the
- // letterboxed image (after image transformation with the FIT scale mode) to
- // the corresponding locations on the same image with the letterbox removed
- // (the input image to the graph before image transformation).
- auto& detection_letterbox_removal =
- graph.AddNode("DetectionLetterboxRemovalCalculator");
- detections_with_text >> detection_letterbox_removal.In("DETECTIONS");
- letterbox_padding >> detection_letterbox_removal.In("LETTERBOX_PADDING");
// Projects detections back into the input image coordinates system.
auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
detections_with_text >> detection_projection.In("DETECTIONS");
matrix >> detection_projection.In("PROJECTION_MATRIX");
auto palm_detections =
- detection_letterbox_removal[Output<std::vector<Detection>>(
- "DETECTIONS")];
detection_projection[Output<std::vector<Detection>>("DETECTIONS")];
// Converts each palm detection into a rectangle (normalized by image size)
// that encloses the palm and is rotated such that the line connecting
View File
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <iostream>
#include <memory>
#include <string>
@@ -75,13 +76,18 @@ using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
constexpr char kTestRightHandsImage[] = "right_hands.jpg";
constexpr char kTestRightHandsRotatedImage[] = "right_hands_rotated.jpg";
constexpr char kTestModelResourcesTag[] = "test_model_resources";
constexpr char kOneHandResultFile[] = "hand_detector_result_one_hand.pbtxt";
constexpr char kOneHandRotatedResultFile[] =
"hand_detector_result_one_hand_rotated.pbtxt";
constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmDetectionsName[] = "palm_detections";
constexpr char kHandRectsTag[] = "HAND_RECTS";
@@ -117,6 +123,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
hand_detection.In(kImageTag);
graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
hand_detection.In(kNormRectTag);
hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
@@ -142,6 +150,9 @@ struct TestParams {
std::string hand_detection_model_name;
// The filename of test image.
std::string test_image_name;
// The rotation to apply to the test image before processing, in radians
// counter-clockwise.
float rotation;
// The number of maximum detected hands.
int num_hands;
// The expected hand detector result.
@@ -154,14 +165,22 @@ TEST_P(HandDetectionTest, DetectTwoHands) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
GetParam().test_image_name)));
NormalizedRect input_norm_rect;
input_norm_rect.set_rotation(GetParam().rotation);
input_norm_rect.set_x_center(0.5);
input_norm_rect.set_y_center(0.5);
input_norm_rect.set_width(1.0);
input_norm_rect.set_height(1.0);
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(GetParam().hand_detection_model_name));
MP_ASSERT_OK_AND_ASSIGN(
auto task_runner, CreateTaskRunner(*model_resources, kPalmDetectionModel,
GetParam().num_hands));
- auto output_packets =
- task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
MP_ASSERT_OK(output_packets);
const std::vector<Detection>& palm_detections =
(*output_packets)[kPalmDetectionsName].Get<std::vector<Detection>>();
@@ -188,15 +207,24 @@ INSTANTIATE_TEST_SUITE_P(
Values(TestParams{.test_name = "DetectOneHand",
.hand_detection_model_name = kPalmDetectionModel,
.test_image_name = kTestRightHandsImage,
.rotation = 0,
.num_hands = 1,
.expected_result =
GetExpectedHandDetectorResult(kOneHandResultFile)},
TestParams{.test_name = "DetectTwoHands",
.hand_detection_model_name = kPalmDetectionModel,
.test_image_name = kTestRightHandsImage,
.rotation = 0,
.num_hands = 2,
.expected_result =
- GetExpectedHandDetectorResult(kTwoHandsResultFile)}),
GetExpectedHandDetectorResult(kTwoHandsResultFile)},
TestParams{.test_name = "DetectOneHandWithRotation",
.hand_detection_model_name = kPalmDetectionModel,
.test_image_name = kTestRightHandsRotatedImage,
.rotation = M_PI / 2.0f,
.num_hands = 1,
.expected_result = GetExpectedHandDetectorResult(
kOneHandRotatedResultFile)}),
[](const TestParamInfo<HandDetectionTest::ParamType>& info) {
return info.param.test_name;
});
View File
@@ -91,10 +91,14 @@ cc_library(
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/utils:gate",
"//mediapipe/tasks/cc/core:model_asset_bundle_resources",
"//mediapipe/tasks/cc/core:model_resources_cache",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata/utils:zip_utils",
"//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
View File
@@ -57,7 +57,7 @@ cc_library(
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
- "//mediapipe/tasks/cc/components/containers:landmarks_detection",
"//mediapipe/tasks/cc/components/containers:rect",
"//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder",
"//mediapipe/tasks/cc/vision/utils:landmarks_utils",
"@com_google_absl//absl/algorithm:container",
View File
@@ -34,7 +34,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
- #include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"
@@ -44,7 +44,7 @@ namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Source;
- using ::mediapipe::tasks::components::containers::Bound;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::utils::CalculateIOU;
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;
@@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
return distance;
}
- Bound CalculateBound(const NormalizedLandmarkList& list) {
Rect CalculateBound(const NormalizedLandmarkList& list) {
constexpr float kMinInitialValue = std::numeric_limits<float>::max();
constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();
@@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
const int num = multi_landmarks.size();
std::vector<float> baseline_distances;
baseline_distances.reserve(num);
- std::vector<Bound> bounds;
std::vector<Rect> bounds;
bounds.reserve(num);
for (const NormalizedLandmarkList& list : multi_landmarks) {
ASSIGN_OR_RETURN(const float baseline_distance,
View File
@@ -29,10 +29,14 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"
@@ -50,6 +54,8 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::DisallowIf;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::hand_detector::proto::
HandDetectorGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::
@@ -58,6 +64,7 @@ using ::mediapipe::tasks::vision::hand_landmarker::proto::
HandLandmarksDetectorGraphOptions;
constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
@@ -65,6 +72,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
constexpr char kHandDetectorTFLiteName[] = "hand_detector.tflite";
constexpr char kHandLandmarksDetectorTFLiteName[] =
"hand_landmarks_detector.tflite";
struct HandLandmarkerOutputs {
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
@@ -76,6 +86,27 @@ struct HandLandmarkerOutputs {
Source<Image> image;
};
// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
HandLandmarkerGraphOptions* options,
bool is_copy) {
ASSIGN_OR_RETURN(const auto hand_detector_file,
resources.GetModelFile(kHandDetectorTFLiteName));
SetExternalFile(hand_detector_file,
options->mutable_hand_detector_graph_options()
->mutable_base_options()
->mutable_model_asset(),
is_copy);
ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
SetExternalFile(hand_landmarks_detector_file,
options->mutable_hand_landmarks_detector_graph_options()
->mutable_base_options()
->mutable_model_asset(),
is_copy);
return absl::OkStatus();
}
} // namespace
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
@@ -92,6 +123,9 @@ struct HandLandmarkerOutputs {
// Inputs:
// IMAGE - Image
// Image to perform hand landmarks detection on.
// NORM_RECT - NormalizedRect
// Describes image rotation and region of image to perform landmarks
// detection on.
//
// Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList>
@@ -110,11 +144,14 @@ struct HandLandmarkerOutputs {
// IMAGE - Image
// The input image that the hand landmarker runs on and has the pixel data
// stored on the target storage (CPU vs GPU).
// All returned coordinates are in the unrotated and uncropped input image
// coordinates system.
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"
// input_stream: "IMAGE:image_in"
// input_stream: "NORM_RECT:norm_rect"
// output_stream: "LANDMARKS:hand_landmarks"
// output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
@@ -154,10 +191,25 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
Graph graph;
- ASSIGN_OR_RETURN(
- auto hand_landmarker_outputs,
- BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
- graph[Input<Image>(kImageTag)], graph));
if (sc->Options<HandLandmarkerGraphOptions>()
.base_options()
.has_model_asset()) {
ASSIGN_OR_RETURN(
const auto* model_asset_bundle_resources,
CreateModelAssetBundleResources<HandLandmarkerGraphOptions>(sc));
// Copies the file content instead of passing the pointer of file in
// memory if the subgraph model resource service is not available.
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
*model_asset_bundle_resources,
sc->MutableOptions<HandLandmarkerGraphOptions>(),
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_landmarker_outputs,
BuildHandLandmarkerGraph(
sc->Options<HandLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
hand_landmarker_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_outputs.world_landmark_lists >>
@@ -196,7 +248,7 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
// graph: the mediapipe graph instance to be updated.
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerGraph(
const HandLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
- Graph& graph) {
Source<NormalizedRect> norm_rect_in, Graph& graph) {
const int max_num_hands =
tasks_options.hand_detector_graph_options().num_hands();
@@ -214,12 +266,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
auto image_for_hand_detector =
DisallowIf(image_in, has_enough_hands, graph);
auto norm_rect_in_for_hand_detector =
DisallowIf(norm_rect_in, has_enough_hands, graph);
auto& hand_detector =
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
tasks_options.hand_detector_graph_options());
image_for_hand_detector >> hand_detector.In("IMAGE");
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT");
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
auto& hand_association = graph.AddNode("HandAssociationCalculator");
View File
@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
@@ -65,12 +67,14 @@ using ::testing::proto::Approximately;
using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
- constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
- constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task";
constexpr char kLeftHandsImage[] = "left_hands.jpg";
constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kLandmarksName[] = "landmarks";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
@@ -85,6 +89,11 @@ constexpr char kExpectedLeftUpHandLandmarksFilename[] =
"expected_left_up_hand_landmarks.prototxt";
constexpr char kExpectedLeftDownHandLandmarksFilename[] =
"expected_left_down_hand_landmarks.prototxt";
// Same but for the rotated image.
constexpr char kExpectedLeftUpHandRotatedLandmarksFilename[] =
"expected_left_up_hand_rotated_landmarks.prototxt";
constexpr char kExpectedLeftDownHandRotatedLandmarksFilename[] =
"expected_left_down_hand_rotated_landmarks.prototxt";
constexpr float kFullModelFractionDiff = 0.03; // percentage
constexpr float kAbsMargin = 0.03;
@@ -105,21 +114,15 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
"mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
auto& options =
hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
- options.mutable_hand_detector_graph_options()
- ->mutable_base_options()
- ->mutable_model_asset()
- ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel));
- options.mutable_hand_detector_graph_options()->mutable_base_options();
options.mutable_base_options()->mutable_model_asset()->set_file_name(
JoinPath("./", kTestDataDirectory, kHandLandmarkerModelBundle));
options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
- options.mutable_hand_landmarks_detector_graph_options()
- ->mutable_base_options()
- ->mutable_model_asset()
- ->set_file_name(
- JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel));
options.set_min_tracking_confidence(kMinTrackingConfidence);
graph[Input<Image>(kImageTag)].SetName(kImageName) >>
hand_landmarker_graph.In(kImageTag);
graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
hand_landmarker_graph.In(kNormRectTag);
hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >>
@@ -139,9 +142,16 @@ TEST_F(HandLandmarkerTest, Succeeds) {
MP_ASSERT_OK_AND_ASSIGN(
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage)));
NormalizedRect input_norm_rect;
input_norm_rect.set_x_center(0.5);
input_norm_rect.set_y_center(0.5);
input_norm_rect.set_width(1.0);
input_norm_rect.set_height(1.0);
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
- auto output_packets =
- task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
const auto& landmarks = (*output_packets)[kLandmarksName]
.Get<std::vector<NormalizedLandmarkList>>();
ASSERT_EQ(landmarks.size(), kMaxNumHands);
@@ -159,6 +169,38 @@ TEST_F(HandLandmarkerTest, Succeeds) {
/*fraction=*/kFullModelFractionDiff));
}
TEST_F(HandLandmarkerTest, SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
kLeftHandsRotatedImage)));
NormalizedRect input_norm_rect;
input_norm_rect.set_x_center(0.5);
input_norm_rect.set_y_center(0.5);
input_norm_rect.set_width(1.0);
input_norm_rect.set_height(1.0);
input_norm_rect.set_rotation(M_PI / 2.0);
MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
auto output_packets = task_runner->Process(
{{kImageName, MakePacket<Image>(std::move(image))},
{kNormRectName,
MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
const auto& landmarks = (*output_packets)[kLandmarksName]
.Get<std::vector<NormalizedLandmarkList>>();
ASSERT_EQ(landmarks.size(), kMaxNumHands);
std::vector<NormalizedLandmarkList> expected_landmarks = {
GetExpectedLandmarkList(kExpectedLeftUpHandRotatedLandmarksFilename),
GetExpectedLandmarkList(kExpectedLeftDownHandRotatedLandmarksFilename)};
EXPECT_THAT(landmarks[0],
Approximately(Partially(EqualsProto(expected_landmarks[0])),
/*margin=*/kAbsMargin,
/*fraction=*/kFullModelFractionDiff));
EXPECT_THAT(landmarks[1],
Approximately(Partially(EqualsProto(expected_landmarks[1])),
/*margin=*/kAbsMargin,
/*fraction=*/kFullModelFractionDiff));
}
} // namespace
} // namespace hand_landmarker
View File
@@ -29,8 +29,8 @@ message HandLandmarkerGraphOptions {
extend mediapipe.CalculatorOptions {
optional HandLandmarkerGraphOptions ext = 462713202;
}
- // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
- // model file with metadata, accelerator options, etc.
// Base options for configuring MediaPipe Tasks, such as specifying the model
// asset bundle file with metadata, accelerator options, etc.
optional core.proto.BaseOptions base_options = 1;
// Options for hand detector graph.
View File
@@ -94,7 +94,7 @@ cc_library(
name = "landmarks_utils",
srcs = ["landmarks_utils.cc"],
hdrs = ["landmarks_utils.h"],
- deps = ["//mediapipe/tasks/cc/components/containers:landmarks_detection"],
deps = ["//mediapipe/tasks/cc/components/containers:rect"],
)
cc_test(
@@ -103,6 +103,6 @@ cc_test(
deps = [
":landmarks_utils",
"//mediapipe/framework/port:gtest_main",
- "//mediapipe/tasks/cc/components/containers:landmarks_detection",
"//mediapipe/tasks/cc/components/containers:rect",
],
)
View File
@@ -18,15 +18,17 @@ limitations under the License.
#include <algorithm>
#include <vector>
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::vision::utils {
- using ::mediapipe::tasks::components::containers::Bound;
using ::mediapipe::tasks::components::containers::Rect;
- float CalculateArea(const Bound& bound) {
- return (bound.right - bound.left) * (bound.bottom - bound.top);
float CalculateArea(const Rect& rect) {
return (rect.right - rect.left) * (rect.bottom - rect.top);
}
- float CalculateIntersectionArea(const Bound& a, const Bound& b) {
float CalculateIntersectionArea(const Rect& a, const Rect& b) {
const float intersection_left = std::max<float>(a.left, b.left);
const float intersection_top = std::max<float>(a.top, b.top);
const float intersection_right = std::min<float>(a.right, b.right);
@@ -36,7 +38,7 @@ float CalculateIntersectionArea(const Bound& a, const Bound& b) {
std::max<float>(intersection_right - intersection_left, 0.0);
}
- float CalculateIOU(const Bound& a, const Bound& b) {
float CalculateIOU(const Rect& a, const Rect& b) {
const float area_a = CalculateArea(a);
const float area_b = CalculateArea(b);
if (area_a <= 0 || area_b <= 0) return 0.0;
View File
@@ -22,20 +22,20 @@ limitations under the License.
#include <limits>
#include <vector>
- #include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
namespace mediapipe::tasks::vision::utils {
// Calculates intersection over union for two bounds.
- float CalculateIOU(const components::containers::Bound& a,
- const components::containers::Bound& b);
float CalculateIOU(const components::containers::Rect& a,
const components::containers::Rect& b);
// Calculates area for face bound
- float CalculateArea(const components::containers::Bound& bound);
float CalculateArea(const components::containers::Rect& rect);
// Calucates intersection area of two face bounds
- float CalculateIntersectionArea(const components::containers::Bound& a,
- const components::containers::Bound& b);
float CalculateIntersectionArea(const components::containers::Rect& a,
const components::containers::Rect& b);
} // namespace mediapipe::tasks::vision::utils
#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
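For context, these helpers now operate on the shared Rect container, whose left/top/right/bottom fields are used directly by CalculateArea above; a minimal usage sketch with made-up values:

// Hypothetical values in normalized image coordinates.
mediapipe::tasks::components::containers::Rect a;
a.left = 0.1f; a.top = 0.1f; a.right = 0.5f; a.bottom = 0.5f;
mediapipe::tasks::components::containers::Rect b;
b.left = 0.3f; b.top = 0.3f; b.right = 0.7f; b.bottom = 0.7f;
// Each rect has area 0.16 and they overlap in a 0.2 x 0.2 square (0.04),
// so the expected IOU is 0.04 / (0.16 + 0.16 - 0.04), roughly 0.14.
float iou = mediapipe::tasks::vision::utils::CalculateIOU(a, b);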
View File
@@ -30,13 +30,13 @@ public abstract class Landmark {
return new AutoValue_Landmark(x, y, z, normalized);
}
- // The x coordniates of the landmark.
// The x coordinates of the landmark.
public abstract float x();
- // The y coordniates of the landmark.
// The y coordinates of the landmark.
public abstract float y();
- // The z coordniates of the landmark.
// The z coordinates of the landmark.
public abstract float z();
// Whether this landmark is normalized with respect to the image size.
View File
@@ -117,7 +117,7 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
if (errorListener != null) {
errorListener.onError(e);
} else {
- Log.e(TAG, "Error occurs when getting MediaPipe vision task result. " + e);
Log.e(TAG, "Error occurs when getting MediaPipe task result. " + e);
}
} finally {
for (Packet packet : packets) {
View File
@@ -0,0 +1,63 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
# The native library of all MediaPipe text tasks.
cc_binary(
name = "libmediapipe_tasks_text_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_text_jni_lib",
srcs = [":libmediapipe_tasks_text_jni.so"],
alwayslink = 1,
)
android_library(
name = "textclassifier",
srcs = [
"textclassifier/TextClassificationResult.java",
"textclassifier/TextClassifier.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "textclassifier/AndroidManifest.xml",
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
"//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.text.textclassifier">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>
View File
@@ -0,0 +1,103 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the classification results generated by {@link TextClassifier}. */
@AutoValue
public abstract class TextClassificationResult implements TaskResult {
/**
* Creates an {@link TextClassificationResult} instance from a {@link
* ClassificationsProto.ClassificationResult} protobuf message.
*
* @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
* message.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static TextClassificationResult create(
ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
List<Classifications> classifications = new ArrayList<>();
for (ClassificationsProto.Classifications classificationsProto :
classificationResult.getClassificationsList()) {
classifications.add(classificationsFromProto(classificationsProto));
}
return new AutoValue_TextClassificationResult(
timestampMs, Collections.unmodifiableList(classifications));
}
@Override
public abstract long timestampMs();
/** Contains one set of results per classifier head. */
@SuppressWarnings("AutoValueImmutableFields")
public abstract List<Classifications> classifications();
/**
* Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
*
* @param category the {@link CategoryProto.Category} protobuf message to convert.
*/
static Category categoryFromProto(CategoryProto.Category category) {
return Category.create(
category.getScore(),
category.getIndex(),
category.getCategoryName(),
category.getDisplayName());
}
/**
* Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
* ClassificationEntry} object.
*
* @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
*/
static ClassificationEntry classificationEntryFromProto(
ClassificationsProto.ClassificationEntry entry) {
List<Category> categories = new ArrayList<>();
for (CategoryProto.Category category : entry.getCategoriesList()) {
categories.add(categoryFromProto(category));
}
return ClassificationEntry.create(categories, entry.getTimestampMs());
}
/**
* Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
* Classifications} object.
*
* @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
* convert.
*/
static Classifications classificationsFromProto(
ClassificationsProto.Classifications classifications) {
List<ClassificationEntry> entries = new ArrayList<>();
for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
entries.add(classificationEntryFromProto(entry));
}
return Classifications.create(
entries, classifications.getHeadIndex(), classifications.getHeadName());
}
}
View File
@@ -0,0 +1,253 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.text.textclassifier.proto.TextClassifierGraphOptionsProto;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
* Performs classification on text.
*
* <p>This API expects a TFLite model with (optional) <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a> that contains
* the mandatory (described below) input tensors, output tensor, and the optional (but recommended)
* label items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
*
* <p>Metadata is required for models with int32 input tensors because it contains the input process
* unit for the model's Tokenizer. No metadata is required for models with string input tensors.
*
* <ul>
* <li>Input tensors
* <ul>
* <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
* bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
* signature requires a Bert Tokenizer process unit in the model metadata.
* <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
* max_seq_len]} representing the input ids. This input signature requires a Regex
* Tokenizer process unit in the model metadata.
* <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
* [1]} containing the input string.
* </ul>
* <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kBool}) with:
* <ul>
* <li>{@code N} classes and shape {@code [1 x N]}
* <li>optional (but recommended) label map(s) as AssociatedFile-s with type
* TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
* any) is used to fill the {@code class_name} field of the results. The {@code
* display_name} field is filled from the AssociatedFile (if any) whose locale matches
* the {@code display_names_locale} field of the {@code TextClassifierOptions} used at
* creation time ("en" by default, i.e. English). If none of these are available, only
* the {@code index} field of the results will be filled.
* </ul>
* </ul>
*/
public final class TextClassifier implements AutoCloseable {
private static final String TAG = TextClassifier.class.getSimpleName();
private static final String TEXT_IN_STREAM_NAME = "text_in";
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));
@SuppressWarnings("ConstantCaseForConstants")
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));
private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.text.text_classifier.TextClassifierGraph";
private final TaskRunner runner;
static {
System.loadLibrary("mediapipe_tasks_text_jni");
ProtoUtil.registerTypeName(
ClassificationsProto.ClassificationResult.class,
"mediapipe.tasks.components.containers.proto.ClassificationResult");
}
/**
* Creates a {@link TextClassifier} instance from a model file and the default {@link
* TextClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the text model with metadata in the assets.
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
*/
public static TextClassifier createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates a {@link TextClassifier} instance from a model file and the default {@link
* TextClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the text model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
*/
public static TextClassifier createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates a {@link TextClassifier} instance from {@link TextClassifierOptions}.
*
* @param context an Android {@link Context}.
* @param options a {@link TextClassifierOptions} instance.
* @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
*/
public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() {
@Override
public TextClassificationResult convertToTaskResult(List<Packet> packets) {
try {
return TextClassificationResult.create(
PacketGetter.getProto(
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
ClassificationsProto.ClassificationResult.getDefaultInstance()),
packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
} catch (IOException e) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
}
}
@Override
public Void convertToTaskInput(List<Packet> packets) {
return null;
}
});
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<TextClassifierOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(options)
.setEnableFlowLimiting(false)
.build(),
handler);
return new TextClassifier(runner);
}
/**
* Constructor to initialize a {@link TextClassifier} from a {@link TaskRunner}.
*
* @param runner a {@link TaskRunner}.
*/
private TextClassifier(TaskRunner runner) {
this.runner = runner;
}
/**
* Performs classification on the input text.
*
* @param inputText a {@link String} for processing.
*/
public TextClassificationResult classify(String inputText) {
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
return (TextClassificationResult) runner.process(inputPackets);
}
/** Closes and cleans up the {@link TextClassifier}. */
@Override
public void close() {
runner.close();
}
/** Options for setting up a {@link TextClassifier}. */
@AutoValue
public abstract static class TextClassifierOptions extends TaskOptions {
/** Builder for {@link TextClassifierOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the text classifier task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
* score threshold, number of results, etc.
*/
public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);
public abstract TextClassifierOptions build();
}
abstract BaseOptions baseOptions();
abstract Optional<ClassifierOptions> classifierOptions();
public static Builder builder() {
return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
}
/** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
if (classifierOptions().isPresent()) {
taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
}
return CalculatorOptions.newBuilder()
.setExtension(
TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
}
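A minimal usage sketch of the TextClassifier API above, assuming an Android Context named `context` and a bundled model asset named "my_text_model.tflite" (both hypothetical; any compatible text classification model works):

// Sketch only: "my_text_model.tflite" is an assumed asset name, not part of this change.
BaseOptions baseOptions =
    BaseOptions.builder().setModelAssetPath("my_text_model.tflite").build();
TextClassifier.TextClassifierOptions options =
    TextClassifier.TextClassifierOptions.builder().setBaseOptions(baseOptions).build();
TextClassifier textClassifier = TextClassifier.createFromOptions(context, options);
// classify() runs the graph synchronously and returns the classification result.
TextClassificationResult result = textClassifier.classify("this is a wonderful movie");
textClassifier.close();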

View File

@@ -15,6 +15,7 @@
 package com.google.mediapipe.tasks.vision.gesturerecognizer;

 import android.content.Context;
+import android.graphics.RectF;
 import android.os.ParcelFileDescriptor;
 import com.google.auto.value.AutoValue;
 import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;
@@ -71,8 +72,10 @@ import java.util.Optional;
 public final class GestureRecognizer extends BaseVisionTaskApi {
   private static final String TAG = GestureRecognizer.class.getSimpleName();
   private static final String IMAGE_IN_STREAM_NAME = "image_in";
+  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
   private static final List<String> INPUT_STREAMS =
-      Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
+      Collections.unmodifiableList(
+          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
   private static final List<String> OUTPUT_STREAMS =
       Collections.unmodifiableList(
           Arrays.asList(
@@ -205,7 +208,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    * @param runningMode a mediapipe vision task {@link RunningMode}.
    */
   private GestureRecognizer(TaskRunner taskRunner, RunningMode runningMode) {
-    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
+    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
   }

   /**
@@ -223,7 +226,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    * @throws MediaPipeException if there is an internal error.
    */
   public GestureRecognitionResult recognize(Image inputImage) {
-    return (GestureRecognitionResult) processImageData(inputImage);
+    // TODO: add proper support for rotations.
+    return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF());
   }

   /**
@@ -244,7 +248,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    * @throws MediaPipeException if there is an internal error.
    */
   public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) {
-    return (GestureRecognitionResult) processVideoData(inputImage, inputTimestampMs);
+    // TODO: add proper support for rotations.
+    return (GestureRecognitionResult)
+        processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
   }

   /**
@@ -266,7 +272,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    * @throws MediaPipeException if there is an internal error.
    */
   public void recognizeAsync(Image inputImage, long inputTimestampMs) {
-    sendLiveStreamData(inputImage, inputTimestampMs);
+    // TODO: add proper support for rotations.
+    sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
   }

   /** Options for setting up an {@link GestureRecognizer}. */
@@ -303,18 +310,18 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
       /** Sets the maximum number of hands can be detected by the GestureRecognizer. */
       public abstract Builder setNumHands(Integer value);

-      /** Sets minimum confidence score for the hand detection to be considered successfully */
+      /** Sets minimum confidence score for the hand detection to be considered successful */
       public abstract Builder setMinHandDetectionConfidence(Float value);

       /** Sets minimum confidence score of hand presence score in the hand landmark detection. */
       public abstract Builder setMinHandPresenceConfidence(Float value);

-      /** Sets the minimum confidence score for the hand tracking to be considered successfully. */
+      /** Sets the minimum confidence score for the hand tracking to be considered successful. */
       public abstract Builder setMinTrackingConfidence(Float value);

       /**
-       * Sets the minimum confidence score for the gestures to be considered successfully. If < 0,
-       * the gesture confidence threshold=0.5 for the model is used.
+       * Sets the minimum confidence score for the gestures to be considered successful. If < 0, the
+       * gesture confidence threshold=0.5 for the model is used.
        *
        * <p>TODO Note this option is subject to change, after scoring merging
        * calculator is implemented.
@@ -433,8 +440,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
           HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
               .setBaseOptions(
                   BaseOptionsProto.BaseOptions.newBuilder()
-                      .setUseStreamMode(runningMode() != RunningMode.IMAGE)
-                      .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
+                      .setUseStreamMode(runningMode() != RunningMode.IMAGE));
       minTrackingConfidence()
           .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
       handLandmarkerGraphOptionsBuilder
@@ -465,4 +471,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
           .build();
     }
   }
+
+  /** Creates a RectF covering the full image. */
+  private static RectF buildFullImageRectF() {
+    return new RectF(0, 0, 1, 1);
+  }
 }
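For context on the NORM_RECT plumbing above: the region of interest travels in normalized image coordinates, which is why the interim default from buildFullImageRectF() is the unit rectangle. A small sketch of that convention (the center-crop value is a hypothetical illustration, not part of this change):

// RectF is interpreted as (left, top, right, bottom) in normalized [0, 1] coordinates.
RectF fullImage = new RectF(0f, 0f, 1f, 1f);               // whole image, as buildFullImageRectF() returns
RectF centerCrop = new RectF(0.25f, 0.25f, 0.75f, 0.75f);  // hypothetical: middle 50% of the image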

View File

@@ -39,7 +39,6 @@ import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
 import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
 import com.google.mediapipe.tasks.vision.core.RunningMode;
 import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
-import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -176,7 +175,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
                         packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
                         ClassificationsProto.ClassificationResult.getDefaultInstance()),
                     packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
-              } catch (InvalidProtocolBufferException e) {
+              } catch (IOException e) {
                 throw new MediaPipeException(
                     MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
               }

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.text.textclassifiertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="textclassifiertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.text.textclassifiertest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,154 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.text.textclassifier;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.text.textclassifier.TextClassifier.TextClassifierOptions;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
/** Test for {@link TextClassifier}. */
@RunWith(AndroidJUnit4.class)
public class TextClassifierTest {
private static final String BERT_MODEL_FILE = "bert_text_classifier.tflite";
private static final String REGEX_MODEL_FILE =
"test_model_text_classifier_with_regex_tokenizer.tflite";
private static final String STRING_TO_BOOL_MODEL_FILE =
"test_model_text_classifier_bool_output.tflite";
private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";
@Test
public void create_failsWithMissingModel() throws Exception {
String nonExistentFile = "/path/to/non/existent/file";
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), nonExistentFile));
assertThat(exception).hasMessageThat().contains(nonExistentFile);
}
@Test
public void create_failsWithMissingOpResolver() throws Exception {
TextClassifierOptions options =
TextClassifierOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(STRING_TO_BOOL_MODEL_FILE).build())
.build();
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
TextClassifier.createFromOptions(
ApplicationProvider.getApplicationContext(), options));
    // TODO: Make MediaPipe InferenceCalculator report the detailed
    // interpreter errors (e.g., "Encountered unresolved custom op").
assertThat(exception)
.hasMessageThat()
.contains("interpreter_builder(&interpreter) == kTfLiteOk");
}
@Test
public void classify_succeedsWithBert() throws Exception {
TextClassifier textClassifier =
TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults);
assertCategoriesAre(
negativeResults,
Arrays.asList(
Category.create(0.95630914f, 0, "negative", ""),
Category.create(0.04369091f, 1, "positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults);
assertCategoriesAre(
positiveResults,
Arrays.asList(
Category.create(0.99997187f, 1, "positive", ""),
Category.create(2.8132641E-5f, 0, "negative", "")));
}
@Test
public void classify_succeedsWithFileObject() throws Exception {
TextClassifier textClassifier =
TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults);
assertCategoriesAre(
negativeResults,
Arrays.asList(
Category.create(0.95630914f, 0, "negative", ""),
Category.create(0.04369091f, 1, "positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults);
assertCategoriesAre(
positiveResults,
Arrays.asList(
Category.create(0.99997187f, 1, "positive", ""),
Category.create(2.8132641E-5f, 0, "negative", "")));
}
@Test
public void classify_succeedsWithRegex() throws Exception {
TextClassifier textClassifier =
TextClassifier.createFromFile(
ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
assertHasOneHead(negativeResults);
assertCategoriesAre(
negativeResults,
Arrays.asList(
Category.create(0.6647746f, 0, "Negative", ""),
Category.create(0.33522537f, 1, "Positive", "")));
TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
assertHasOneHead(positiveResults);
assertCategoriesAre(
positiveResults,
Arrays.asList(
Category.create(0.5120041f, 0, "Negative", ""),
Category.create(0.48799595f, 1, "Positive", "")));
}
private static void assertHasOneHead(TextClassificationResult results) {
assertThat(results.classifications()).hasSize(1);
assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
assertThat(results.classifications().get(0).entries()).hasSize(1);
}
private static void assertCategoriesAre(
TextClassificationResult results, List<Category> categories) {
assertThat(results.classifications().get(0).entries().get(0).categories())
.isEqualTo(categories);
}
}

View File

@@ -23,5 +23,9 @@ py_library(
     testonly = 1,
     srcs = ["test_utils.py"],
     srcs_version = "PY3",
+    visibility = [
+        "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
+        "//mediapipe/tasks:internal",
+    ],
     deps = ["//mediapipe/python:_framework_bindings"],
 )

View File

@@ -35,9 +35,11 @@ mediapipe_files(srcs = [
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
     "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
     "deeplabv3.tflite",
+    "hand_landmark.task",
     "hand_landmark_full.tflite",
     "hand_landmark_lite.tflite",
     "left_hands.jpg",
+    "left_hands_rotated.jpg",
     "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
     "mobilenet_v1_0.25_224_1_default_1.tflite",
     "mobilenet_v1_0.25_224_1_metadata_1.tflite",
@@ -51,7 +53,9 @@ mediapipe_files(srcs = [
     "multi_objects_rotated.jpg",
     "palm_detection_full.tflite",
     "pointing_up.jpg",
+    "pointing_up_rotated.jpg",
     "right_hands.jpg",
+    "right_hands_rotated.jpg",
     "segmentation_golden_rotation0.png",
     "segmentation_input_rotation0.jpg",
     "selfie_segm_128_128_3.tflite",
@@ -64,7 +68,9 @@ mediapipe_files(srcs = [
 exports_files(
     srcs = [
         "expected_left_down_hand_landmarks.prototxt",
+        "expected_left_down_hand_rotated_landmarks.prototxt",
         "expected_left_up_hand_landmarks.prototxt",
+        "expected_left_up_hand_rotated_landmarks.prototxt",
         "expected_right_down_hand_landmarks.prototxt",
         "expected_right_up_hand_landmarks.prototxt",
     ],
@@ -84,11 +90,14 @@ filegroup(
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "left_hands.jpg",
+        "left_hands_rotated.jpg",
         "mozart_square.jpg",
         "multi_objects.jpg",
         "multi_objects_rotated.jpg",
         "pointing_up.jpg",
+        "pointing_up_rotated.jpg",
         "right_hands.jpg",
+        "right_hands_rotated.jpg",
         "segmentation_golden_rotation0.png",
         "segmentation_input_rotation0.jpg",
         "selfie_segm_128_128_3_expected_mask.jpg",
@@ -109,6 +118,7 @@ filegroup(
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
         "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
         "deeplabv3.tflite",
+        "hand_landmark.task",
         "hand_landmark_full.tflite",
         "hand_landmark_lite.tflite",
         "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
@@ -129,12 +139,17 @@ filegroup(
     name = "test_protos",
     srcs = [
         "expected_left_down_hand_landmarks.prototxt",
+        "expected_left_down_hand_rotated_landmarks.prototxt",
         "expected_left_up_hand_landmarks.prototxt",
+        "expected_left_up_hand_rotated_landmarks.prototxt",
         "expected_right_down_hand_landmarks.prototxt",
         "expected_right_up_hand_landmarks.prototxt",
         "hand_detector_result_one_hand.pbtxt",
+        "hand_detector_result_one_hand_rotated.pbtxt",
         "hand_detector_result_two_hands.pbtxt",
         "pointing_up_landmarks.pbtxt",
+        "pointing_up_rotated_landmarks.pbtxt",
         "thumb_up_landmarks.pbtxt",
+        "thumb_up_rotated_landmarks.pbtxt",
     ],
 )

View File

@ -0,0 +1,84 @@
landmark {
x: 0.9259716
y: 0.18969846
}
landmark {
x: 0.88135517
y: 0.28856543
}
landmark {
x: 0.7600651
y: 0.3578236
}
landmark {
x: 0.62631166
y: 0.40490413
}
landmark {
x: 0.5374573
y: 0.45170194
}
landmark {
x: 0.57372385
y: 0.29924914
}
landmark {
x: 0.36731184
y: 0.33081773
}
landmark {
x: 0.24132833
y: 0.34759054
}
landmark {
x: 0.13690609
y: 0.35727677
}
landmark {
x: 0.5535803
y: 0.2398035
}
landmark {
x: 0.31834763
y: 0.24999242
}
landmark {
x: 0.16748133
y: 0.25625145
}
landmark {
x: 0.050747424
y: 0.25991398
}
landmark {
x: 0.56593156
y: 0.1867483
}
landmark {
x: 0.3543046
y: 0.17923892
}
landmark {
x: 0.21360746
y: 0.17454882
}
landmark {
x: 0.11110917
y: 0.17232567
}
landmark {
x: 0.5948908
y: 0.14024714
}
landmark {
x: 0.42692152
y: 0.11949824
}
landmark {
x: 0.32239118
y: 0.106370345
}
landmark {
x: 0.23672739
y: 0.09432885
}

View File

@ -0,0 +1,84 @@
landmark {
x: 0.06676084
y: 0.8095678
}
landmark {
x: 0.11359626
y: 0.71148247
}
landmark {
x: 0.23572624
y: 0.6414506
}
landmark {
x: 0.37323278
y: 0.5959156
}
landmark {
x: 0.46243322
y: 0.55125874
}
landmark {
x: 0.4205411
y: 0.69531494
}
landmark {
x: 0.62798893
y: 0.66715276
}
landmark {
x: 0.7568023
y: 0.65208924
}
landmark {
x: 0.86370826
y: 0.6437276
}
landmark {
x: 0.445136
y: 0.75394773
}
landmark {
x: 0.6787485
y: 0.745853
}
landmark {
x: 0.8290694
y: 0.7412988
}
landmark {
x: 0.94454145
y: 0.7384017
}
landmark {
x: 0.43516788
y: 0.8082166
}
landmark {
x: 0.6459554
y: 0.81768996
}
landmark {
x: 0.7875173
y: 0.825062
}
landmark {
x: 0.89249825
y: 0.82850707
}
landmark {
x: 0.40665048
y: 0.8567925
}
landmark {
x: 0.57228816
y: 0.8802181
}
landmark {
x: 0.6762071
y: 0.8941581
}
landmark {
x: 0.76453924
y: 0.90583205
}

View File

@ -0,0 +1,33 @@
detections {
label: "Palm"
score: 0.97115
location_data {
format: RELATIVE_BOUNDING_BOX
relative_bounding_box {
xmin: 0.5198178
ymin: 0.6467485
width: 0.42467535
height: 0.22546273
}
}
}
detections {
label: "Palm"
score: 0.96701413
location_data {
format: RELATIVE_BOUNDING_BOX
relative_bounding_box {
xmin: 0.024490356
ymin: 0.12620124
width: 0.43832153
height: 0.23269764
}
}
}
hand_rects {
x_center: 0.5760683
y_center: 0.6829921
height: 0.5862031
width: 1.1048855
rotation: -0.8250832
}

Binary file not shown.

View File

@ -0,0 +1,223 @@
classifications {
classification {
score: 1.0
label: "Left"
display_name: "Left"
}
}
landmarks {
landmark {
x: 0.25546086
y: 0.47584262
z: 1.835341e-07
}
landmark {
x: 0.3363011
y: 0.54135
z: -0.041144375
}
landmark {
x: 0.4375146
y: 0.57881975
z: -0.06807727
}
landmark {
x: 0.49603376
y: 0.5263966
z: -0.09387612
}
landmark {
x: 0.5022822
y: 0.4413827
z: -0.1189948
}
landmark {
x: 0.5569452
y: 0.4724485
z: -0.05138246
}
landmark {
x: 0.6687125
y: 0.47918057
z: -0.09121969
}
landmark {
x: 0.73666537
y: 0.48318353
z: -0.11703273
}
landmark {
x: 0.7998315
y: 0.4741413
z: -0.1386424
}
landmark {
x: 0.5244063
y: 0.39292705
z: -0.061040796
}
landmark {
x: 0.57215345
y: 0.41514704
z: -0.11967233
}
landmark {
x: 0.4724468
y: 0.45553637
z: -0.13287684
}
landmark {
x: 0.43794966
y: 0.45210314
z: -0.13210714
}
landmark {
x: 0.47838163
y: 0.33329
z: -0.07421263
}
landmark {
x: 0.51081127
y: 0.35479474
z: -0.13596693
}
landmark {
x: 0.42433846
y: 0.40486792
z: -0.121291734
}
landmark {
x: 0.40280548
y: 0.39977497
z: -0.09928809
}
landmark {
x: 0.42269367
y: 0.2798249
z: -0.09064263
}
landmark {
x: 0.45849988
y: 0.3069861
z: -0.12894689
}
landmark {
x: 0.40754712
y: 0.35153976
z: -0.109160855
}
landmark {
x: 0.38855004
y: 0.3467068
z: -0.08820164
}
}
world_landmarks {
landmark {
x: -0.08568013
y: 0.016593203
z: 0.036527164
}
landmark {
x: -0.0565372
y: 0.041761592
z: 0.019493781
}
landmark {
x: -0.031365488
y: 0.05031186
z: 0.0025481891
}
landmark {
x: -0.008534161
y: 0.04286737
z: -0.024755282
}
landmark {
x: -0.0047254
y: 0.015748458
z: -0.035581928
}
landmark {
x: 0.013083893
y: 0.024668094
z: 0.0035934823
}
landmark {
x: 0.04149521
y: 0.024621274
z: -0.0030611698
}
landmark {
x: 0.06257473
y: 0.025388625
z: -0.010340984
}
landmark {
x: 0.08009179
y: 0.023082614
z: -0.03162942
}
landmark {
x: 0.006135068
y: 0.000696786
z: 0.0048212176
}
landmark {
x: 0.01678449
y: 0.0067061195
z: -0.029920919
}
landmark {
x: -0.008948593
y: 0.016808286
z: -0.03755109
}
landmark {
x: -0.01789449
y: 0.0153161455
z: -0.012059977
}
landmark {
x: -0.0061980113
y: -0.017872887
z: -0.002366997
}
landmark {
x: -0.004643807
y: -0.0108282855
z: -0.034515083
}
landmark {
x: -0.027603384
y: 0.003529715
z: -0.033665676
}
landmark {
x: -0.035679806
y: 0.0038255951
z: -0.008094264
}
landmark {
x: -0.02957782
y: -0.031701155
z: -0.008180461
}
landmark {
x: -0.020741666
y: -0.02506058
z: -0.026839724
}
landmark {
x: -0.0310834
y: -0.009496164
z: -0.032422185
}
landmark {
x: -0.037420202
y: -0.012883307
z: -0.017971724
}
}

View File

@ -0,0 +1,223 @@
classifications {
classification {
score: 1.0
label: "Left"
display_name: "Left"
}
}
landmarks {
landmark {
x: 0.3283601
y: 0.63773525
z: -3.2280354e-07
}
landmark {
x: 0.46280807
y: 0.6339767
z: -0.06408348
}
landmark {
x: 0.5831279
y: 0.57430106
z: -0.08583106
}
landmark {
x: 0.6689471
y: 0.49959752
z: -0.09886064
}
landmark {
x: 0.74378216
y: 0.47357544
z: -0.09680563
}
landmark {
x: 0.5233122
y: 0.41020474
z: -0.038088404
}
landmark {
x: 0.5296913
y: 0.3372598
z: -0.08874837
}
landmark {
x: 0.49039274
y: 0.43994758
z: -0.102315836
}
landmark {
x: 0.4824569
y: 0.47969607
z: -0.1030014
}
landmark {
x: 0.4451338
y: 0.39520803
z: -0.02177739
}
landmark {
x: 0.4410001
y: 0.34107083
z: -0.07294245
}
landmark {
x: 0.4162798
y: 0.46102384
z: -0.07746907
}
landmark {
x: 0.43492994
y: 0.47154287
z: -0.07404131
}
landmark {
x: 0.37671578
y: 0.39535576
z: -0.016277775
}
landmark {
x: 0.36978847
y: 0.34265152
z: -0.07346253
}
landmark {
x: 0.3559884
y: 0.44905427
z: -0.057693005
}
landmark {
x: 0.37711847
y: 0.46414754
z: -0.03662908
}
landmark {
x: 0.3142985
y: 0.3942253
z: -0.0152847925
}
landmark {
x: 0.30000874
y: 0.35543376
z: -0.046002634
}
landmark {
x: 0.30002704
y: 0.42357764
z: -0.032671776
}
landmark {
x: 0.31079838
y: 0.44218025
z: -0.016200554
}
}
world_landmarks {
landmark {
x: -0.030687196
y: 0.0678545
z: 0.051061403
}
landmark {
x: 0.0047719833
y: 0.06330968
z: 0.018945374
}
landmark {
x: 0.039799504
y: 0.054109577
z: 0.007930638
}
landmark {
x: 0.069374144
y: 0.035063196
z: 2.2522348e-05
}
landmark {
x: 0.087818466
y: 0.018390425
z: 0.004055788
}
landmark {
x: 0.02810654
y: 0.0043561812
z: -0.0038672548
}
landmark {
x: 0.025270049
y: -0.0039896416
z: -0.032991238
}
landmark {
x: 0.020414166
y: 0.006768506
z: -0.032724563
}
landmark {
x: 0.016415983
y: 0.024563588
z: -0.0058115427
}
landmark {
x: 0.0038743173
y: -0.0044466974
z: 0.0024876352
}
landmark {
x: 0.0041790796
y: -0.0115309935
z: -0.03532454
}
landmark {
x: -0.0016900161
y: 0.015519895
z: -0.03596156
}
landmark {
x: 0.004309217
y: 0.01917039
z: 0.003907912
}
landmark {
x: -0.016969737
y: -0.005584497
z: 0.0034258277
}
landmark {
x: -0.016737012
y: -0.01159037
z: -0.02876696
}
landmark {
x: -0.018165365
y: 0.01376111
z: -0.026835402
}
landmark {
x: -0.012430167
y: 0.02064222
z: -0.00087265146
}
landmark {
x: -0.043247573
y: 0.0011161827
z: 0.0056269006
}
landmark {
x: -0.038128495
y: -0.011477032
z: -0.016374081
}
landmark {
x: -0.034920715
y: 0.005510211
z: -0.029714659
}
landmark {
x: -0.03815982
y: 0.011989757
z: -0.014853194
}
}

View File

@@ -151,7 +151,7 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_dummy_gesture_recognizer_task",
         sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665524417056146"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665707319890725"],
     )

     http_file(
@@ -166,12 +166,24 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_landmarks.prototxt?generation=1661875720230540"],
     )

+    http_file(
+        name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
+        sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
+    )
+
     http_file(
         name = "com_google_mediapipe_expected_left_up_hand_landmarks_prototxt",
         sha256 = "1353ba617c4f048083618587cd23a8a22115f634521c153d4e1bd1ebd4f49dd7",
         urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_landmarks.prototxt?generation=1661875726008879"],
     )

+    http_file(
+        name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
+        sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
+    )
+
     http_file(
         name = "com_google_mediapipe_expected_right_down_hand_landmarks_prototxt",
         sha256 = "f281b745175aaa7f458def6cf4c89521fb56302dd61a05642b3b4a4f237ffaa3",
@@ -250,6 +262,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand.pbtxt?generation=1662745351291628"],
     )

+    http_file(
+        name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
+        sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
+    )
+
     http_file(
         name = "com_google_mediapipe_hand_detector_result_two_hands_pbtxt",
         sha256 = "2589cb08b0ee027dc24649fe597adcfa2156a21d12ea2480f83832714ebdf95f",
@@ -268,6 +286,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark_lite.tflite?generation=1661875766398729"],
     )

+    http_file(
+        name = "com_google_mediapipe_hand_landmark_task",
+        sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"],
+    )
+
     http_file(
         name = "com_google_mediapipe_hand_recrop_tflite",
         sha256 = "67d996ce96f9d36fe17d2693022c6da93168026ab2f028f9e2365398d8ac7d5d",
@@ -346,6 +370,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands.jpg?generation=1661875796949017"],
     )

+    http_file(
+        name = "com_google_mediapipe_left_hands_rotated_jpg",
+        sha256 = "8609c6202bca43a99bbf23fa8e687e49fa525e89481152e4c0987f46d60d7931",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"],
+    )
+
     http_file(
         name = "com_google_mediapipe_mobilebert_embedding_with_metadata_tflite",
         sha256 = "fa47142dcc6f446168bc672f2df9605b6da5d0c0d6264e9be62870282365b95c",
@@ -538,6 +568,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1665174976408451"],
     )

+    http_file(
+        name = "com_google_mediapipe_pointing_up_rotated_jpg",
+        sha256 = "50ff66f50281207072a038e5bb6648c43f4aacbfb8204a4d2591868756aaeff1",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated.jpg?generation=1666037072219697"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
+        sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
+    )
+
     http_file(
         name = "com_google_mediapipe_pose_detection_tflite",
         sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
@@ -574,6 +616,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands.jpg?generation=1661875908672404"],
     )

+    http_file(
+        name = "com_google_mediapipe_right_hands_rotated_jpg",
+        sha256 = "b3bdf692f0d54b86c8b67e6d1286dd0078fbe6e9dfcd507b187e3bd8b398c0f9",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands_rotated.jpg?generation=1666037076873345"],
+    )
+
     http_file(
         name = "com_google_mediapipe_score_calibration_file_meta_json",
         sha256 = "6a3c305620371f662419a496f75be5a10caebca7803b1e99d8d5d22ba51cda94",
@@ -718,6 +766,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1665174979747784"],
     )

+    http_file(
+        name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
+        sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
+    )
+
     http_file(
         name = "com_google_mediapipe_two_heads_16000_hz_mono_wav",
         sha256 = "a291a9c22c39bba30138a26915e154a96286ba6ca3b413053123c504a58cce3b",