diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 26ada44bf..8fbc99829 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1410,3 +1410,45 @@ cc_library( ], alwayslink = 1, ) + +mediapipe_proto_library( + name = "bypass_calculator_proto", + srcs = ["bypass_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "bypass_calculator", + srcs = ["bypass_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":bypass_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "bypass_calculator_test", + srcs = ["bypass_calculator_test.cc"], + deps = [ + ":bypass_calculator", + ":pass_through_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc new file mode 100644 index 000000000..86dcfc0e1 --- /dev/null +++ b/mediapipe/calculators/core/bypass_calculator.cc @@ -0,0 +1,161 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "mediapipe/calculators/core/bypass_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" +#include "mediapipe/framework/port/status.h" + +namespace mediapipe { +namespace api2 { + +using mediapipe::BypassCalculatorOptions; + +// Defines a "bypass" channel to use in place of a disabled feature subgraph. +// By default, all inputs are discarded and all outputs are ignored. +// Certain input streams can be passed to corresponding output streams +// by specifying them in "pass_input_stream" and "pass_output_stream" options. +// All output streams are updated with timestamp bounds indicating completed +// output. +// +// Note that this calculator is designed for use as a contained_node in a +// SwitchContainer. For this reason, any input and output tags are accepted, +// and stream semantics are specified through BypassCalculatorOptions. 
+// +// Example config: +// node { +// calculator: "BypassCalculator" +// input_stream: "APPEARANCES:appearances_post_facenet" +// input_stream: "VIDEO:video_frame" +// input_stream: "FEATURE_CONFIG:feature_config" +// input_stream: "ENABLE:gaze_enabled" +// output_stream: "APPEARANCES:analyzed_appearances" +// output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" +// node_options: { +// [type.googleapis.com/mediapipe.BypassCalculatorOptions] { +// pass_input_stream: "APPEARANCES" +// pass_output_stream: "APPEARANCES" +// } +// } +// } +// +class BypassCalculator : public Node { + public: + static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"}; + MEDIAPIPE_NODE_CONTRACT(kNotNeeded); + using IdMap = std::map<CollectionItemId, CollectionItemId>; + + // Returns the map of passthrough input and output stream ids. + static absl::StatusOr<IdMap> GetPassMap( + const BypassCalculatorOptions& options, const tool::TagMap& input_map, + const tool::TagMap& output_map) { + IdMap result; + auto& input_streams = options.pass_input_stream(); + auto& output_streams = options.pass_output_stream(); + int size = std::min(input_streams.size(), output_streams.size()); + for (int i = 0; i < size; ++i) { + std::pair<std::string, int> in_tag, out_tag; + MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i), + &in_tag.first, &in_tag.second)); + MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i), + &out_tag.first, &out_tag.second)); + auto input_id = input_map.GetId(in_tag.first, in_tag.second); + auto output_id = output_map.GetId(out_tag.first, out_tag.second); + result[input_id] = output_id; + } + return result; + } + + // Identifies all specified streams as "Any" packet type. + // Identifies passthrough streams as "Same" packet type. + static absl::Status UpdateContract(CalculatorContract* cc) { + auto options = cc->Options<BypassCalculatorOptions>(); + RET_CHECK_EQ(options.pass_input_stream().size(), + options.pass_output_stream().size()); + ASSIGN_OR_RETURN( + auto pass_streams, + GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap())); + std::set<CollectionItemId> pass_out; + for (auto entry : pass_streams) { + pass_out.insert(entry.second); + cc->Inputs().Get(entry.first).SetAny(); + cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first)); + } + for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) { + if (pass_streams.count(id) == 0) { + cc->Inputs().Get(id).SetAny(); + } + } + for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) { + if (pass_out.count(id) == 0) { + cc->Outputs().Get(id).SetAny(); + } + } + return absl::OkStatus(); + } + + // Saves the map of passthrough input and output stream ids. + absl::Status Open(CalculatorContext* cc) override { + auto options = cc->Options<BypassCalculatorOptions>(); + ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(), + *cc->Outputs().TagMap())); + return absl::OkStatus(); + } + + // Copies packets between passthrough input and output streams. + // Updates timestamp bounds on all output streams.
+ absl::Status Process(CalculatorContext* cc) override { + std::set<CollectionItemId> pass_out; + for (auto entry : pass_streams_) { + pass_out.insert(entry.second); + auto& packet = cc->Inputs().Get(entry.first).Value(); + if (packet.Timestamp() == cc->InputTimestamp()) { + cc->Outputs().Get(entry.first).AddPacket(packet); + } + } + Timestamp bound = cc->InputTimestamp().NextAllowedInStream(); + for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) { + if (pass_out.count(id) == 0) { + cc->Outputs().Get(id).SetNextTimestampBound( + std::max(cc->Outputs().Get(id).NextTimestampBound(), bound)); + } + } + return absl::OkStatus(); + } + + // Closes all output streams. + absl::Status Close(CalculatorContext* cc) override { + for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) { + cc->Outputs().Get(id).Close(); + } + return absl::OkStatus(); + } + + private: + IdMap pass_streams_; +}; + +MEDIAPIPE_REGISTER_NODE(BypassCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/core/bypass_calculator.proto b/mediapipe/calculators/core/bypass_calculator.proto new file mode 100644 index 000000000..1a273edb2 --- /dev/null +++ b/mediapipe/calculators/core/bypass_calculator.proto @@ -0,0 +1,31 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message BypassCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional BypassCalculatorOptions ext = 481259677; + } + + // Names an input stream or streams to pass through, by "TAG:index". + repeated string pass_input_stream = 1; + + // Names an output stream or streams to pass through, by "TAG:index". + repeated string pass_output_stream = 2; +} diff --git a/mediapipe/calculators/core/bypass_calculator_test.cc b/mediapipe/calculators/core/bypass_calculator_test.cc new file mode 100644 index 000000000..4d1cd8f79 --- /dev/null +++ b/mediapipe/calculators/core/bypass_calculator_test.cc @@ -0,0 +1,302 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +// A graph with using a BypassCalculator to pass through and ignore +// most of its inputs and outputs. +constexpr char kTestGraphConfig1[] = R"pb( + type: "AppearancesPassThroughSubgraph" + input_stream: "APPEARANCES:appearances" + input_stream: "VIDEO:video_frame" + input_stream: "FEATURE_CONFIG:feature_config" + output_stream: "APPEARANCES:passthrough_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output" + + node { + calculator: "BypassCalculator" + input_stream: "PASS:appearances" + input_stream: "TRUNCATE:0:video_frame" + input_stream: "TRUNCATE:1:feature_config" + output_stream: "PASS:passthrough_appearances" + output_stream: "TRUNCATE:passthrough_federated_gaze_output" + node_options: { + [type.googleapis.com/mediapipe.BypassCalculatorOptions] { + pass_input_stream: "PASS" + pass_output_stream: "PASS" + } + } + } +)pb"; + +// A graph with using AppearancesPassThroughSubgraph as a do-nothing channel +// for input frames and appearances. +constexpr char kTestGraphConfig2[] = R"pb( + input_stream: "VIDEO_FULL_RES:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "GAZE_ENABLED:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + + node { + calculator: "SwitchContainer" + input_stream: "VIDEO:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "ENABLE:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { calculator: "AppearancesPassThroughSubgraph" } + } + } + } +)pb"; + +// A graph with using BypassCalculator as a do-nothing channel +// for input frames and appearances. +constexpr char kTestGraphConfig3[] = R"pb( + input_stream: "VIDEO_FULL_RES:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "GAZE_ENABLED:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + + node { + calculator: "SwitchContainer" + input_stream: "VIDEO:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "ENABLE:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "BypassCalculator" + node_options: { + [type.googleapis.com/mediapipe.BypassCalculatorOptions] { + pass_input_stream: "APPEARANCES" + pass_output_stream: "APPEARANCES" + } + } + } + } + } + } +)pb"; + +// A graph with using BypassCalculator as a disabled-gate +// for input frames and appearances. 
+constexpr char kTestGraphConfig4[] = R"pb( + input_stream: "VIDEO_FULL_RES:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "GAZE_ENABLED:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + + node { + calculator: "SwitchContainer" + input_stream: "ENABLE:gaze_enabled" + input_stream: "VIDEO:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + output_stream: "VIDEO:video_frame_out" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEATURE_CONFIG:feature_config_out" + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { calculator: "BypassCalculator" } + contained_node: { calculator: "PassThroughCalculator" } + } + } + } +)pb"; + +// Reports packet timestamp and string contents, or "" if empty. +std::string DebugString(Packet p) { + return absl::StrCat(p.Timestamp().DebugString(), ":", + p.IsEmpty() ? "" : p.Get<std::string>()); +} + +// Shows a bypass subgraph that passes through one stream. +TEST(BypassCalculatorTest, SubgraphChannel) { + CalculatorGraphConfig config_1 = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1); + CalculatorGraphConfig config_2 = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {})); + + std::vector<std::string> analyzed_appearances; + MP_ASSERT_OK(graph.ObserveOutputStream( + "analyzed_appearances", + [&](const Packet& p) { + analyzed_appearances.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + std::vector<std::string> federated_gaze_output; + MP_ASSERT_OK(graph.ObserveOutputStream( + "federated_gaze_output", + [&](const Packet& p) { + federated_gaze_output.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket<std::string>("v1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket<std::string>("f1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1")); + EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:")); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Shows a BypassCalculator that passes through one stream.
+TEST(BypassCalculatorTest, CalculatorChannel) { + CalculatorGraphConfig config_3 = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize({config_3}, {})); + + std::vector<std::string> analyzed_appearances; + MP_ASSERT_OK(graph.ObserveOutputStream( + "analyzed_appearances", + [&](const Packet& p) { + analyzed_appearances.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + std::vector<std::string> federated_gaze_output; + MP_ASSERT_OK(graph.ObserveOutputStream( + "federated_gaze_output", + [&](const Packet& p) { + federated_gaze_output.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket<std::string>("v1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket<std::string>("f1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1")); + EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:")); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Shows a BypassCalculator that discards all inputs when ENABLED is false. +TEST(BypassCalculatorTest, GatedChannel) { + CalculatorGraphConfig config_3 = + mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize({config_3}, {})); + + std::vector<std::string> analyzed_appearances; + MP_ASSERT_OK(graph.ObserveOutputStream( + "analyzed_appearances", + [&](const Packet& p) { + analyzed_appearances.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + std::vector<std::string> video_frame; + MP_ASSERT_OK(graph.ObserveOutputStream( + "video_frame_out", + [&](const Packet& p) { + video_frame.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + MP_ASSERT_OK(graph.StartRun({})); + + // Close the gate. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "gaze_enabled", MakePacket<bool>(false).At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send packets at timestamp 200. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket<std::string>("v1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket<std::string>("f1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Only timestamps arrive from the BypassCalculator. + EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:")); + EXPECT_THAT(video_frame, testing::ElementsAre("200:")); + + // Open the gate. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "gaze_enabled", MakePacket<bool>(true).At(Timestamp(300)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send packets at timestamp 300. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket<std::string>("a2").At(Timestamp(300)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket<std::string>("v2").At(Timestamp(300)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket<std::string>("f2").At(Timestamp(300)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Packets arrive from the PassThroughCalculator.
+ EXPECT_THAT(analyzed_appearances, + testing::ElementsAre("200:", "300:a2")); + EXPECT_THAT(video_frame, testing::ElementsAre("200:", "300:v2")); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +} // namespace + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/rotation_mode.proto b/mediapipe/calculators/image/rotation_mode.proto index d4859aa4c..7fa4a8eda 100644 --- a/mediapipe/calculators/image/rotation_mode.proto +++ b/mediapipe/calculators/image/rotation_mode.proto @@ -16,6 +16,9 @@ syntax = "proto2"; package mediapipe; +option java_package = "com.google.mediapipe.calculator.proto"; +option java_outer_classname = "RotationModeProto"; + // Counterclockwise rotation. message RotationMode { enum Mode { diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index c52e2e283..ae8a0cbf0 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -253,6 +253,60 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "text_to_tensor_calculator", + srcs = ["text_to_tensor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "text_to_tensor_calculator_test", + srcs = ["text_to_tensor_calculator_test.cc"], + deps = [ + ":text_to_tensor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:options_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "universal_sentence_encoder_preprocessor_calculator", + srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index 5671d7b4d..14de410ff 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter { Tensor& output_tensor) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && - input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 && + input.format() != mediapipe::GpuBufferFormat::kRGB24) { return 
InvalidArgumentError(absl::StrCat( - "Only 4-channel texture input formats are supported, passed format: ", - static_cast(input.format()))); + "Unsupported format: ", static_cast(input.format()))); } const auto& output_shape = output_tensor.shape(); MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); @@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter { MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext( [this, &output_tensor, &input, &roi, &output_shape, range_min, range_max, tensor_buffer_offset]() -> absl::Status { - constexpr int kRgbaNumChannels = 4; + const int input_num_channels = input.channels(); auto source_texture = gl_helper_.CreateSourceTexture(input); tflite::gpu::gl::GlTexture input_texture( - GL_TEXTURE_2D, source_texture.name(), GL_RGBA, + GL_TEXTURE_2D, source_texture.name(), + input_num_channels == 4 ? GL_RGB : GL_RGBA, source_texture.width() * source_texture.height() * - kRgbaNumChannels * sizeof(uint8_t), + input_num_channels * sizeof(uint8_t), /*layer=*/0, /*owned=*/false); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 06dfd578e..5efd34041 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter { Tensor& output_tensor) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && - input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 && + input.format() != mediapipe::GpuBufferFormat::kRGB24) { return InvalidArgumentError(absl::StrCat( - "Only 4-channel texture input formats are supported, passed format: ", - static_cast(input.format()))); + "Unsupported format: ", static_cast(input.format()))); } // TODO: support tensor_buffer_offset > 0 scenario. RET_CHECK_EQ(tensor_buffer_offset, 0) diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/regex_preprocessor_calculator_test.cc new file mode 100644 index 000000000..ef14ef035 --- /dev/null +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator_test.cc @@ -0,0 +1,130 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <memory> +#include <sstream> +#include <string> +#include <utility> +#include <vector> + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/sink.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::testing::ElementsAreArray; + +constexpr int kMaxSeqLen = 256; +constexpr char kTestModelPath[] = + "mediapipe/tasks/testdata/text/" + "test_model_text_classifier_with_regex_tokenizer.tflite"; + +absl::StatusOr<std::vector<int>> RunRegexPreprocessorCalculator( + absl::string_view text) { + auto graph_config = + ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute( + R"pb( + input_stream: "text" + output_stream: "tensors" + node { + calculator: "RegexPreprocessorCalculator" + input_stream: "TEXT:text" + input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" + output_stream: "TENSORS:tensors" + options { + [mediapipe.RegexPreprocessorCalculatorOptions.ext] { + max_seq_len: $0 + } + } + } + )pb", + kMaxSeqLen)); + std::vector<Packet> output_packets; + tool::AddVectorSink("tensors", &graph_config, &output_packets); + + std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath); + ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor, + ModelMetadataExtractor::CreateFromModelBuffer( + model_buffer.data(), model_buffer.size())); + // Run the graph. + CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize( + graph_config, + {{"metadata_extractor", + MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}})); + MP_RETURN_IF_ERROR(graph.StartRun({})); + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( + "text", MakePacket<std::string>(text).At(Timestamp(0)))); + MP_RETURN_IF_ERROR(graph.WaitUntilIdle()); + + if (output_packets.size() != 1) { + return absl::InvalidArgumentError(absl::Substitute( + "output_packets has size $0, expected 1", output_packets.size())); + } + const std::vector<Tensor>& tensor_vec = + output_packets[0].Get<std::vector<Tensor>>(); + if (tensor_vec.size() != 1) { + return absl::InvalidArgumentError(absl::Substitute( + "tensor_vec has size $0, expected $1", tensor_vec.size(), 1)); + } + if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) { + return absl::InvalidArgumentError("Expected tensor element type kInt32"); + } + auto* buffer = tensor_vec[0].GetCpuReadView().buffer<int32_t>(); + std::vector<int> result(buffer, buffer + kMaxSeqLen); + MP_RETURN_IF_ERROR(graph.CloseAllPacketSources()); + MP_RETURN_IF_ERROR(graph.WaitUntilDone()); + return result; +} + +TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) { + MP_ASSERT_OK_AND_ASSIGN( + std::vector<int> processed_tensor_values, + RunRegexPreprocessorCalculator("This is the best movie I’ve seen in " + "recent years. 
Strongly recommend it!")); + static const int expected_result[kMaxSeqLen] = { + 1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12}; + EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result)); +} + +TEST(RegexPreprocessorCalculatorTest, LongInput) { + std::stringstream long_input; + long_input << "This is the best"; + for (int i = 0; i < kMaxSeqLen; ++i) { + long_input << " best"; + } + long_input << "movie I’ve seen in recent years. Strongly recommend it!"; + MP_ASSERT_OK_AND_ASSIGN(std::vector processed_tensor_values, + RunRegexPreprocessorCalculator(long_input.str())); + std::vector expected_result = {1, 2, 9, 4, 118}; + // "best" id + expected_result.resize(kMaxSeqLen, 118); + EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc index 51c3a9a09..5c1c70aa4 100644 --- a/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc @@ -67,9 +67,7 @@ absl::StatusOr RunTextToTensorCalculator(absl::string_view text) { "tensor_vec has size $0, expected 1", tensor_vec.size())); } if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { - return absl::InvalidArgumentError(absl::Substitute( - "tensor has element type $0, expected $1", tensor_vec[0].element_type(), - Tensor::ElementType::kChar)); + return absl::InvalidArgumentError("Expected tensor element type kChar"); } const char* buffer = tensor_vec[0].GetCpuReadView().buffer(); return std::string(buffer, text.length()); diff --git a/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc index d5f252b57..0f4744c90 100644 --- a/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc +++ b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc @@ -88,9 +88,7 @@ RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) { kNumInputTensorsForUniversalSentenceEncoder)); } if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { - return absl::InvalidArgumentError(absl::Substitute( - "tensor has element type $0, expected $1", tensor_vec[0].element_type(), - Tensor::ElementType::kChar)); + return absl::InvalidArgumentError("Expected tensor element type kChar"); } std::vector results; for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) { diff --git a/mediapipe/calculators/tflite/tflite_model_calculator.cc b/mediapipe/calculators/tflite/tflite_model_calculator.cc index c347d07f6..891a9f731 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator.cc @@ -91,9 +91,9 @@ class TfLiteModelCalculator : public CalculatorBase { tflite::DefaultErrorReporter()); model = tflite::FlatBufferModel::BuildFromAllocation( std::move(model_allocation), tflite::DefaultErrorReporter()); -#elif +#else return absl::FailedPreconditionError( - "Loading by file descriptor is not supported on this platform.") + "Loading by file descriptor is not supported on this platform."); #endif } diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index e3d36611c..aec2445b9 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -378,8 +378,11 @@ cc_library( ], }), deps = [ + 
":gl_texture_buffer", ":gpu_buffer_format", ":gpu_buffer_storage", + ":image_frame_view", + "//mediapipe/framework/formats:image_frame", "@com_google_absl//absl/strings:str_format", ], ) diff --git a/mediapipe/gpu/gl_context_nsgl.cc b/mediapipe/gpu/gl_context_nsgl.cc index dda74f0ce..561474ad8 100644 --- a/mediapipe/gpu/gl_context_nsgl.cc +++ b/mediapipe/gpu/gl_context_nsgl.cc @@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) { 16, 0}; - pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs]; + pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1]; } if (!pixel_format_) { // On several Forge machines, the default config fails. For now let's do diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 1f0a23f31..8b47d620b 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -65,6 +65,7 @@ class GlTextureView { friend class GpuBuffer; friend class GlTextureBuffer; friend class GpuBufferStorageCvPixelBuffer; + friend class GpuBufferStorageAhwb; GlTextureView(GlContext* context, GLenum target, GLuint name, int width, int height, std::shared_ptr gpu_buffer, int plane, DetachFn detach, DoneWritingFn done_writing) diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index c207acf60..3fd519b21 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" #include "stb_image.h" diff --git a/mediapipe/model_maker/python/core/data/classification_dataset.py b/mediapipe/model_maker/python/core/data/classification_dataset.py index 9075e46eb..af761d9ea 100644 --- a/mediapipe/model_maker/python/core/data/classification_dataset.py +++ b/mediapipe/model_maker/python/core/data/classification_dataset.py @@ -23,13 +23,17 @@ from mediapipe.model_maker.python.core.data import dataset as ds class ClassificationDataset(ds.Dataset): """DataLoader for classification models.""" - def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any): + def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any): super().__init__(dataset, size) - self.index_to_label = index_to_label + self._index_by_label = index_by_label @property def num_classes(self: ds._DatasetT) -> int: - return len(self.index_to_label) + return len(self._index_by_label) + + @property + def index_by_label(self: ds._DatasetT) -> Any: + return self._index_by_label def split(self: ds._DatasetT, fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]: @@ -44,4 +48,4 @@ class ClassificationDataset(ds.Dataset): Returns: The splitted two sub datasets. """ - return self._split(fraction, self.index_to_label) + return self._split(fraction, self._index_by_label) diff --git a/mediapipe/model_maker/python/core/data/classification_dataset_test.py b/mediapipe/model_maker/python/core/data/classification_dataset_test.py index f8688ab14..0fd8575f4 100644 --- a/mediapipe/model_maker/python/core/data/classification_dataset_test.py +++ b/mediapipe/model_maker/python/core/data/classification_dataset_test.py @@ -12,45 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any, Tuple, TypeVar + # Dependency imports import tensorflow as tf from mediapipe.model_maker.python.core.data import classification_dataset +_DatasetT = TypeVar( + '_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset') -class ClassificationDataLoaderTest(tf.test.TestCase): + +class ClassificationDatasetTest(tf.test.TestCase): def test_split(self): - class MagicClassificationDataLoader( + class MagicClassificationDataset( classification_dataset.ClassificationDataset): + """A mock classification dataset class for testing purpose. - def __init__(self, dataset, size, index_to_label, value): - super(MagicClassificationDataLoader, - self).__init__(dataset, size, index_to_label) + Attributes: + value: A value variable stored by the mock dataset class for testing. + """ + + def __init__(self, dataset: tf.data.Dataset, size: int, + index_by_label: Any, value: Any): + super().__init__( + dataset=dataset, size=size, index_by_label=index_by_label) self.value = value - def split(self, fraction): - return self._split(fraction, self.index_to_label, self.value) + def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]: + return self._split(fraction, self.index_by_label, self.value) # Some dummy inputs. magic_value = 42 num_classes = 2 - index_to_label = (False, True) + index_by_label = (False, True) # Create data loader from sample data. ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) - data = MagicClassificationDataLoader(ds, len(ds), index_to_label, - magic_value) + data = MagicClassificationDataset( + dataset=ds, + size=len(ds), + index_by_label=index_by_label, + value=magic_value) # Train/Test data split. fraction = .25 - train_data, test_data = data.split(fraction) + train_data, test_data = data.split(fraction=fraction) # `split` should return instances of child DataLoader. - self.assertIsInstance(train_data, MagicClassificationDataLoader) - self.assertIsInstance(test_data, MagicClassificationDataLoader) + self.assertIsInstance(train_data, MagicClassificationDataset) + self.assertIsInstance(test_data, MagicClassificationDataset) # Make sure number of entries are right. self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data)) @@ -59,7 +73,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase): # Make sure attributes propagated correctly. self.assertEqual(train_data.num_classes, num_classes) - self.assertEqual(test_data.index_to_label, index_to_label) + self.assertEqual(test_data.index_by_label, index_by_label) self.assertEqual(train_data.value, magic_value) self.assertEqual(test_data.value, magic_value) diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index 6b366f6dc..c327b7ea9 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model class Classifier(custom_model.CustomModel): """An abstract base class that represents a TensorFlow classifier.""" - def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool, + def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool, full_train: bool): """Initilizes a classifier with its specifications. Args: model_spec: Specification for the model. - index_to_label: A list that map from index to label class name. + index_by_label: A list that map from index to label class name. 
shuffle: Whether the dataset should be shuffled. full_train: If true, train the model end-to-end including the backbone and the classification layers on top. Otherwise, only train the top classification layers. """ super(Classifier, self).__init__(model_spec, shuffle) - self._index_to_label = index_to_label + self._index_by_label = index_by_label self._full_train = full_train - self._num_classes = len(index_to_label) + self._num_classes = len(index_by_label) def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: """Evaluates the classifier with the provided evaluation dataset. @@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel): label_filepath = os.path.join(export_dir, label_filename) tf.compat.v1.logging.info('Saving labels in %s', label_filepath) with tf.io.gfile.GFile(label_filepath, 'w') as f: - f.write('\n'.join(self._index_to_label)) + f.write('\n'.join(self._index_by_label)) diff --git a/mediapipe/model_maker/python/core/tasks/classifier_test.py b/mediapipe/model_maker/python/core/tasks/classifier_test.py index fbf231d8b..1484e8e86 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier_test.py +++ b/mediapipe/model_maker/python/core/tasks/classifier_test.py @@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase): def setUp(self): super(ClassifierTest, self).setUp() - index_to_label = ['cat', 'dog'] + index_by_label = ['cat', 'dog'] self.model = MockClassifier( model_spec=None, - index_to_label=index_to_label, + index_by_label=index_by_label, shuffle=False, full_train=False) self.model.model = test_util.build_model(input_shape=[4], num_classes=2) diff --git a/mediapipe/model_maker/python/core/tasks/custom_model.py b/mediapipe/model_maker/python/core/tasks/custom_model.py index 2cea4e0a1..66d1494db 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model.py @@ -21,8 +21,6 @@ import abc import os from typing import Any, Callable, Optional -# Dependency imports - import tensorflow as tf from mediapipe.model_maker.python.core.data import dataset @@ -77,9 +75,9 @@ class CustomModel(abc.ABC): tflite_filepath = os.path.join(export_dir, tflite_filename) # TODO: Populate metadata to the exported TFLite model. 
model_util.export_tflite( - self._model, - tflite_filepath, - quantization_config, + model=self._model, + tflite_filepath=tflite_filepath, + quantization_config=quantization_config, preprocess=preprocess) tf.compat.v1.logging.info( 'TensorFlow Lite model exported successfully: %s' % tflite_filepath) diff --git a/mediapipe/model_maker/python/core/tasks/custom_model_test.py b/mediapipe/model_maker/python/core/tasks/custom_model_test.py index e693e1275..ad77d4ecd 100644 --- a/mediapipe/model_maker/python/core/tasks/custom_model_test.py +++ b/mediapipe/model_maker/python/core/tasks/custom_model_test.py @@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase): def setUp(self): super(CustomModelTest, self).setUp() - self.model = MockCustomModel(model_spec=None, shuffle=False) - self.model._model = test_util.build_model(input_shape=[4], num_classes=2) + self._model = MockCustomModel(model_spec=None, shuffle=False) + self._model._model = test_util.build_model(input_shape=[4], num_classes=2) def _check_nonempty_file(self, filepath): self.assertTrue(os.path.isfile(filepath)) @@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase): def test_export_tflite(self): export_path = os.path.join(self.get_temp_dir(), 'export/') - self.model.export_tflite(export_dir=export_path) + self._model.export_tflite(export_dir=export_path) self._check_nonempty_file(os.path.join(export_path, 'model.tflite')) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index e4b18b395..2538ec8fa 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -31,20 +31,6 @@ py_library( ], ) -py_library( - name = "image_preprocessing", - srcs = ["image_preprocessing.py"], - srcs_version = "PY3", -) - -py_test( - name = "image_preprocessing_test", - srcs = ["image_preprocessing_test.py"], - python_version = "PY3", - srcs_version = "PY3", - deps = [":image_preprocessing"], -) - py_library( name = "model_util", srcs = ["model_util.py"], diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py index 17c738a14..5b0aa32bf 100644 --- a/mediapipe/model_maker/python/core/utils/loss_functions.py +++ b/mediapipe/model_maker/python/core/utils/loss_functions.py @@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss): class_weight: A weight to apply to the loss, one for each class. The weight is applied for each input where the ground truth label matches. """ - super(tf.keras.losses.Loss, self).__init__() + super().__init__() # Used for clipping min/max values of probability values in y_pred to avoid # NaNs and Infs in computation. self._epsilon = 1e-7 diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 0899a9b1a..e1228eb6e 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -104,8 +104,8 @@ def export_tflite( quantization_config: Configuration for post-training quantization. supported_ops: A list of supported ops in the converted TFLite file. preprocess: A callable to preprocess the representative dataset for - quantization. The callable takes three arguments in order: feature, - label, and is_training. + quantization. The callable takes three arguments in order: feature, label, + and is_training. 
""" if tflite_filepath is None: raise ValueError( diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py index ce31c1877..35b52eb75 100644 --- a/mediapipe/model_maker/python/core/utils/model_util_test.py +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): model = test_util.build_model(input_shape=[input_dim], num_classes=2) tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') model_util.export_tflite(model, tflite_file) - self._test_tflite(model, tflite_file, input_dim) + test_util.test_tflite( + keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) @parameterized.named_parameters( dict( @@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): input_dim = 16 num_classes = 2 max_input_value = 5 - model = test_util.build_model([input_dim], num_classes) + model = test_util.build_model( + input_shape=[input_dim], num_classes=num_classes) tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite') - model_util.export_tflite(model, tflite_file, config) - self._test_tflite( - model, tflite_file, input_dim, max_input_value, atol=1e-00) - self.assertNear(os.path.getsize(tflite_file), model_size, 300) - - def _test_tflite(self, - keras_model: tf.keras.Model, - tflite_model_file: str, - input_dim: int, - max_input_value: int = 1000, - atol: float = 1e-04): - random_input = test_util.create_random_sample( - size=[1, input_dim], high=max_input_value) - random_input = tf.convert_to_tensor(random_input) - + model_util.export_tflite( + model=model, tflite_filepath=tflite_file, quantization_config=config) self.assertTrue( - test_util.is_same_output( - tflite_model_file, keras_model, random_input, atol=atol)) + test_util.test_tflite( + keras_model=model, + tflite_file=tflite_file, + size=[1, input_dim], + high=max_input_value, + atol=1e-00)) + self.assertNear(os.path.getsize(tflite_file), model_size, 300) if __name__ == '__main__': diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py index cac2a0e1f..b402d3793 100644 --- a/mediapipe/model_maker/python/core/utils/test_util.py +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -92,3 +92,32 @@ def is_same_output(tflite_file: str, keras_output = keras_model.predict_on_batch(input_tensors) return np.allclose(lite_output, keras_output, atol=atol) + + +def test_tflite(keras_model: tf.keras.Model, + tflite_file: str, + size: Union[int, List[int]], + high: float = 1, + atol: float = 1e-04) -> bool: + """Verifies if the output of TFLite model and TF Keras model are identical. + + Args: + keras_model: Input TensorFlow Keras model. + tflite_file: Input TFLite model file. + size: Size of the input tesnor. + high: Higher boundary of the values in input tensors. + atol: Absolute tolerance of the difference between the outputs of Keras + model and TFLite model. + + Returns: + True if the output of TFLite model and TF Keras model are identical. + Otherwise, False. 
+ """ + random_input = create_random_sample(size=size, high=high) + random_input = tf.convert_to_tensor(random_input) + + return is_same_output( + tflite_file=tflite_file, + keras_model=keras_model, + input_tensors=random_input, + atol=atol) diff --git a/mediapipe/model_maker/python/vision/core/BUILD b/mediapipe/model_maker/python/vision/core/BUILD new file mode 100644 index 000000000..2658841ae --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/BUILD @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library compatibility macro. +# Placeholder for internal Python strict test compatibility macro. + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "image_preprocessing", + srcs = ["image_preprocessing.py"], +) + +py_test( + name = "image_preprocessing_test", + srcs = ["image_preprocessing_test.py"], + deps = [":image_preprocessing"], +) diff --git a/mediapipe/model_maker/python/vision/core/__init__.py b/mediapipe/model_maker/python/vision/core/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/utils/image_preprocessing.py b/mediapipe/model_maker/python/vision/core/image_preprocessing.py similarity index 98% rename from mediapipe/model_maker/python/core/utils/image_preprocessing.py rename to mediapipe/model_maker/python/vision/core/image_preprocessing.py index 62b34fb27..104ccd9ca 100644 --- a/mediapipe/model_maker/python/core/utils/image_preprocessing.py +++ b/mediapipe/model_maker/python/vision/core/image_preprocessing.py @@ -13,11 +13,7 @@ # limitations under the License. 
# ============================================================================== """ImageNet preprocessing.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -# Dependency imports import tensorflow as tf IMAGE_SIZE = 224 diff --git a/mediapipe/model_maker/python/core/utils/image_preprocessing_test.py b/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py similarity index 94% rename from mediapipe/model_maker/python/core/utils/image_preprocessing_test.py rename to mediapipe/model_maker/python/vision/core/image_preprocessing_test.py index bc4b44569..0594b4376 100644 --- a/mediapipe/model_maker/python/core/utils/image_preprocessing_test.py +++ b/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py @@ -12,15 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports import numpy as np import tensorflow as tf -from mediapipe.model_maker.python.core.utils import image_preprocessing +from mediapipe.model_maker.python.vision.core import image_preprocessing def _get_preprocessed_image(preprocessor, is_training=False): diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD index a9386d56e..5b4ec2bd1 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/BUILD +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -78,9 +78,9 @@ py_library( ":train_image_classifier_lib", "//mediapipe/model_maker/python/core/data:classification_dataset", "//mediapipe/model_maker/python/core/tasks:classifier", - "//mediapipe/model_maker/python/core/utils:image_preprocessing", "//mediapipe/model_maker/python/core/utils:model_util", "//mediapipe/model_maker/python/core/utils:quantization", + "//mediapipe/model_maker/python/vision/core:image_preprocessing", ], ) diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset.py b/mediapipe/model_maker/python/vision/image_classifier/dataset.py index e57bae3dd..4ae8dcfdd 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset.py @@ -16,7 +16,7 @@ import os import random -from typing import List, Optional, Tuple +from typing import List, Optional import tensorflow as tf import tensorflow_datasets as tfds @@ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset def _load_image(path: str) -> tf.Tensor: - """Loads image.""" + """Loads a jpeg/png image and returns an image tensor.""" image_raw = tf.io.read_file(path) image_tensor = tf.cond( - tf.image.is_jpeg(image_raw), - lambda: tf.image.decode_jpeg(image_raw, channels=3), - lambda: tf.image.decode_png(image_raw, channels=3)) + tf.io.is_jpeg(image_raw), + lambda: tf.io.decode_jpeg(image_raw, channels=3), + lambda: tf.io.decode_png(image_raw, channels=3)) return image_tensor @@ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset): Args: dirname: Name of the directory containing the data files. - shuffle: boolean, if shuffle, random shuffle data. + shuffle: boolean, if true, random shuffle data. Returns: Dataset containing images and labels and other related info. - Raises: ValueError: if the input data directory is empty. 
""" @@ -85,55 +84,26 @@ class Dataset(classification_dataset.ClassificationDataset): name for name in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, name))) all_label_size = len(label_names) - label_to_index = dict( + index_by_label = dict( (name, index) for index, name in enumerate(label_names)) all_image_labels = [ - label_to_index[os.path.basename(os.path.dirname(path))] + index_by_label[os.path.basename(os.path.dirname(path))] for path in all_image_paths ] path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) - autotune = tf.data.AUTOTUNE - image_ds = path_ds.map(_load_image, num_parallel_calls=autotune) + image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) - # Loads label. + # Load label label_ds = tf.data.Dataset.from_tensor_slices( tf.cast(all_image_labels, tf.int64)) - # Creates a dataset if (image, label) pairs. + # Create a dataset if (image, label) pairs image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) tf.compat.v1.logging.info( 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, all_label_size, ', '.join(label_names)) - return Dataset(image_label_ds, all_image_size, label_names) - - @classmethod - def load_tf_dataset( - cls, name: str - ) -> Tuple[Optional[classification_dataset.ClassificationDataset], - Optional[classification_dataset.ClassificationDataset], - Optional[classification_dataset.ClassificationDataset]]: - """Loads data from tensorflow_datasets. - - Args: - name: the registered name of the tfds.core.DatasetBuilder. Refer to the - documentation of tfds.load for more details. - - Returns: - A tuple of Datasets for the train/validation/test. - - Raises: - ValueError: if the input tf dataset does not have train/validation/test - labels. - """ - data, info = tfds.load(name, with_info=True) - if 'label' not in info.features: - raise ValueError('info.features need to contain \'label\' key.') - label_names = info.features['label'].names - - train_data = _create_data('train', data, info, label_names) - validation_data = _create_data('validation', data, info, label_names) - test_data = _create_data('test', data, info, label_names) - return train_data, validation_data, test_data + return Dataset( + dataset=image_label_ds, size=all_image_size, index_by_label=label_names) diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py index 3a5d198b4..6a0b696f9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py @@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase): def test_split(self): ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) - data = dataset.Dataset(ds, 4, ['pos', 'neg']) - train_data, test_data = data.split(0.5) + data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg']) + train_data, test_data = data.split(fraction=0.5) self.assertLen(train_data, 2) for i, elem in enumerate(train_data._dataset): self.assertTrue((elem.numpy() == np.array([i, 1])).all()) self.assertEqual(train_data.num_classes, 2) - self.assertEqual(train_data.index_to_label, ['pos', 'neg']) + self.assertEqual(train_data.index_by_label, ['pos', 'neg']) self.assertLen(test_data, 2) for i, elem in enumerate(test_data._dataset): self.assertTrue((elem.numpy() == np.array([i, 0])).all()) self.assertEqual(test_data.num_classes, 2) - self.assertEqual(test_data.index_to_label, ['pos', 'neg']) + 
self.assertEqual(test_data.index_by_label, ['pos', 'neg']) def test_from_folder(self): - data = dataset.Dataset.from_folder(self.image_path) + data = dataset.Dataset.from_folder(dirname=self.image_path) self.assertLen(data, 2) self.assertEqual(data.num_classes, 2) - self.assertEqual(data.index_to_label, ['daisy', 'tulips']) + self.assertEqual(data.index_by_label, ['daisy', 'tulips']) for image, label in data.gen_tf_dataset(): self.assertTrue(label.numpy() == 1 or label.numpy() == 0) if label.numpy() == 0: @@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase): self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset) self.assertLen(train_data, 1034) self.assertEqual(train_data.num_classes, 3) - self.assertEqual(train_data.index_to_label, + self.assertEqual(train_data.index_by_label, ['angular_leaf_spot', 'bean_rust', 'healthy']) self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset) self.assertLen(validation_data, 133) self.assertEqual(validation_data.num_classes, 3) - self.assertEqual(validation_data.index_to_label, + self.assertEqual(validation_data.index_by_label, ['angular_leaf_spot', 'bean_rust', 'healthy']) self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset) self.assertLen(test_data, 128) self.assertEqual(test_data.num_classes, 3) - self.assertEqual(test_data.index_to_label, + self.assertEqual(test_data.index_by_label, ['angular_leaf_spot', 'bean_rust', 'healthy']) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py index 7a99f9ae0..dd8929a71 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -20,9 +20,9 @@ import tensorflow_hub as hub from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds from mediapipe.model_maker.python.core.tasks import classifier -from mediapipe.model_maker.python.core.utils import image_preprocessing from mediapipe.model_maker.python.core.utils import model_util from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision.core import image_preprocessing from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib @@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla class ImageClassifier(classifier.Classifier): """ImageClassifier for building image classification model.""" - def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any], + def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any], hparams: hp.HParams): """Initializes ImageClassifier class. Args: model_spec: Specification for the model. - index_to_label: A list that maps from index to label class name. + index_by_label: A list that maps from index to label class name. hparams: The hyperparameters for training image classifier. 
""" - super(ImageClassifier, self).__init__( + super().__init__( model_spec=model_spec, - index_to_label=index_to_label, + index_by_label=index_by_label, shuffle=hparams.shuffle, full_train=hparams.do_fine_tuning) self._hparams = hparams @@ -81,7 +81,7 @@ class ImageClassifier(classifier.Classifier): spec = ms.SupportedModels.get(model_spec) image_classifier = cls( model_spec=spec, - index_to_label=train_data.index_to_label, + index_by_label=train_data.index_by_label, hparams=hparams) image_classifier._create_model() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py index a7faab5b6..8ed6de7ad 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -60,31 +60,16 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): model_spec=image_classifier.SupportedModels.MOBILENET_V2, hparams=image_classifier.HParams( train_epochs=1, batch_size=1, shuffle=True)), - dict( - testcase_name='resnet_50', - model_spec=image_classifier.SupportedModels.RESNET_50, - hparams=image_classifier.HParams( - train_epochs=1, batch_size=1, shuffle=True)), dict( testcase_name='efficientnet_lite0', model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, hparams=image_classifier.HParams( train_epochs=1, batch_size=1, shuffle=True)), - dict( - testcase_name='efficientnet_lite1', - model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1, - hparams=image_classifier.HParams( - train_epochs=1, batch_size=1, shuffle=True)), dict( testcase_name='efficientnet_lite2', model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2, hparams=image_classifier.HParams( train_epochs=1, batch_size=1, shuffle=True)), - dict( - testcase_name='efficientnet_lite3', - model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3, - hparams=image_classifier.HParams( - train_epochs=1, batch_size=1, shuffle=True)), dict( testcase_name='efficientnet_lite4', model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4, diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py index 4e9565274..ef44f86e6 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py @@ -48,34 +48,17 @@ mobilenet_v2_spec = functools.partial( uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', name='mobilenet_v2') -resnet_50_spec = functools.partial( - ModelSpec, - uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4', - name='resnet_50') - efficientnet_lite0_spec = functools.partial( ModelSpec, uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', name='efficientnet_lite0') -efficientnet_lite1_spec = functools.partial( - ModelSpec, - uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2', - input_image_shape=[240, 240], - name='efficientnet_lite1') - efficientnet_lite2_spec = functools.partial( ModelSpec, uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', input_image_shape=[260, 260], name='efficientnet_lite2') -efficientnet_lite3_spec = functools.partial( - ModelSpec, - uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2', - input_image_shape=[280, 280], - name='efficientnet_lite3') - efficientnet_lite4_spec 
= functools.partial( ModelSpec, uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', @@ -88,11 +71,8 @@ efficientnet_lite4_spec = functools.partial( class SupportedModels(enum.Enum): """Image classifier model supported by model maker.""" MOBILENET_V2 = mobilenet_v2_spec - RESNET_50 = resnet_50_spec EFFICIENTNET_LITE0 = efficientnet_lite0_spec - EFFICIENTNET_LITE1 = efficientnet_lite1_spec EFFICIENTNET_LITE2 = efficientnet_lite2_spec - EFFICIENTNET_LITE3 = efficientnet_lite3_spec EFFICIENTNET_LITE4 = efficientnet_lite4_spec @classmethod diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py index bacab016e..63f360ab9 100644 --- a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py @@ -30,36 +30,18 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase): expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', expected_name='mobilenet_v2', expected_input_image_shape=[224, 224]), - dict( - testcase_name='resnet_50_spec_test', - model_spec=ms.resnet_50_spec, - expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4', - expected_name='resnet_50', - expected_input_image_shape=[224, 224]), dict( testcase_name='efficientnet_lite0_spec_test', model_spec=ms.efficientnet_lite0_spec, expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', expected_name='efficientnet_lite0', expected_input_image_shape=[224, 224]), - dict( - testcase_name='efficientnet_lite1_spec_test', - model_spec=ms.efficientnet_lite1_spec, - expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2', - expected_name='efficientnet_lite1', - expected_input_image_shape=[240, 240]), dict( testcase_name='efficientnet_lite2_spec_test', model_spec=ms.efficientnet_lite2_spec, expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', expected_name='efficientnet_lite2', expected_input_image_shape=[260, 260]), - dict( - testcase_name='efficientnet_lite3_spec_test', - model_spec=ms.efficientnet_lite3_spec, - expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2', - expected_name='efficientnet_lite3', - expected_input_image_shape=[280, 280]), dict( testcase_name='efficientnet_lite4_spec_test', model_spec=ms.efficientnet_lite4_spec, diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 7939e4e39..e4905546a 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -92,3 +92,29 @@ cc_library( ], alwayslink = 1, ) + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. 
+cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 33d3e4457..af51d0c37 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -17,8 +17,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) cc_library( - name = "landmarks_detection", - hdrs = ["landmarks_detection.h"], + name = "rect", + hdrs = ["rect.h"], ) cc_library( diff --git a/mediapipe/tasks/cc/components/containers/landmarks_detection.h b/mediapipe/tasks/cc/components/containers/rect.h similarity index 57% rename from mediapipe/tasks/cc/components/containers/landmarks_detection.h rename to mediapipe/tasks/cc/components/containers/rect.h index 7339954d8..3f5432cf2 100644 --- a/mediapipe/tasks/cc/components/containers/landmarks_detection.h +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -13,26 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ -#include - -// Sturcts holding landmarks related data structure for hand landmarker, pose -// detector, face mesher, etc. namespace mediapipe::tasks::components::containers { -// x and y are in [0,1] range with origin in top left in input image space. -// If model provides z, z is in the same scale as x. origin is in the center -// of the face. -struct Landmark { - float x; - float y; - float z; -}; - -// [0, 1] range in input image space -struct Bound { +// Defines a rectangle, used e.g. as part of detection results or as input +// region-of-interest. +// +// The coordinates are normalized wrt the image dimensions, i.e. generally in +// [0,1] but they may exceed these bounds if describing a region overlapping the +// image. The origin is on the top-left corner of the image. 
+struct Rect { float left; float top; float right; @@ -40,4 +32,4 @@ struct Bound { }; } // namespace mediapipe::tasks::components::containers -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index f3258a606..291dd29fe 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -73,6 +73,7 @@ cc_library( srcs = ["model_task_graph.cc"], hdrs = ["model_task_graph.h"], deps = [ + ":model_asset_bundle_resources", ":model_resources", ":model_resources_cache", ":model_resources_calculator", @@ -163,6 +164,7 @@ cc_library_with_tflite( "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", ], deps = [ + ":model_asset_bundle_resources", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:packet", "//mediapipe/tasks/cc:common", diff --git a/mediapipe/tasks/cc/core/model_resources_cache.cc b/mediapipe/tasks/cc/core/model_resources_cache.cc index 216962bcf..affcb6dea 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.cc +++ b/mediapipe/tasks/cc/core/model_resources_cache.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -39,12 +40,16 @@ ModelResourcesCache::ModelResourcesCache( graph_op_resolver_packet_ = api2::PacketAdopting(std::move(graph_op_resolver)); } -}; +} bool ModelResourcesCache::Exists(const std::string& tag) const { return model_resources_collection_.contains(tag); } +bool ModelResourcesCache::ModelAssetBundleExists(const std::string& tag) const { + return model_asset_bundle_resources_collection_.contains(tag); +} + absl::Status ModelResourcesCache::AddModelResources( std::unique_ptr model_resources) { if (model_resources == nullptr) { @@ -94,6 +99,62 @@ absl::StatusOr ModelResourcesCache::GetModelResources( return model_resources_collection_.at(tag).get(); } +absl::Status ModelResourcesCache::AddModelAssetBundleResources( + std::unique_ptr model_asset_bundle_resources) { + if (model_asset_bundle_resources == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ModelAssetBundleResources object is null.", + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + const std::string& tag = model_asset_bundle_resources->GetTag(); + if (tag.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ModelAssetBundleResources must have a non-empty tag.", + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + if (ModelAssetBundleExists(tag)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute( + "ModelAssetBundleResources with tag \"$0\" already exists.", tag), + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + model_asset_bundle_resources_collection_.emplace( + tag, std::move(model_asset_bundle_resources)); + return absl::OkStatus(); +} + +absl::Status ModelResourcesCache::AddModelAssetBundleResourcesCollection( + std::vector>& + model_asset_bundle_resources_collection) { + for (auto& model_bundle_resources : model_asset_bundle_resources_collection) { + MP_RETURN_IF_ERROR( + AddModelAssetBundleResources(std::move(model_bundle_resources))); + } + 
return absl::OkStatus(); +} + +absl::StatusOr +ModelResourcesCache::GetModelAssetBundleResources( + const std::string& tag) const { + if (tag.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ModelAssetBundleResources must be retrieved with a non-empty tag.", + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + if (!ModelAssetBundleExists(tag)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute( + "ModelAssetBundleResources with tag \"$0\" does not exist.", tag), + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + return model_asset_bundle_resources_collection_.at(tag).get(); +} + absl::StatusOr> ModelResourcesCache::GetGraphOpResolverPacket() const { if (graph_op_resolver_packet_.IsEmpty()) { diff --git a/mediapipe/tasks/cc/core/model_resources_cache.h b/mediapipe/tasks/cc/core/model_resources_cache.h index 044ef36b7..32909f93d 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.h +++ b/mediapipe/tasks/cc/core/model_resources_cache.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -46,6 +47,10 @@ class ModelResourcesCache { // Returns whether the tag exists in the model resources cache. bool Exists(const std::string& tag) const; + // Returns whether the tag of the model asset bundle exists in the model + // resources cache. + bool ModelAssetBundleExists(const std::string& tag) const; + // Adds a ModelResources object into the cache. // The tag of the ModelResources must be unique; the ownership of the // ModelResources will be transferred into the cache. @@ -62,6 +67,23 @@ class ModelResourcesCache { absl::StatusOr GetModelResources( const std::string& tag) const; + // Adds a ModelAssetBundleResources object into the cache. + // The tag of the ModelAssetBundleResources must be unique; the ownership of + // the ModelAssetBundleResources will be transferred into the cache. + absl::Status AddModelAssetBundleResources( + std::unique_ptr model_asset_bundle_resources); + + // Adds a collection of the ModelAssetBundleResources objects into the cache. + // The tag of the each ModelAssetBundleResources must be unique; the ownership + // of every ModelAssetBundleResources will be transferred into the cache. + absl::Status AddModelAssetBundleResourcesCollection( + std::vector>& + model_asset_bundle_resources_collection); + + // Retrieves a const ModelAssetBundleResources pointer by the unique tag. + absl::StatusOr GetModelAssetBundleResources( + const std::string& tag) const; + // Retrieves the graph op resolver packet. absl::StatusOr> GetGraphOpResolverPacket() const; @@ -73,6 +95,11 @@ class ModelResourcesCache { // A collection of ModelResources objects for the models in the graph. absl::flat_hash_map> model_resources_collection_; + + // A collection of ModelAssetBundleResources objects for the model bundles in + // the graph. + absl::flat_hash_map> + model_asset_bundle_resources_collection_; }; // Global service for mediapipe task model resources cache. 
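For reference, a minimal sketch of how a caller might exercise the new model-asset-bundle methods on ModelResourcesCache introduced above. This is illustrative only: the cache reference is assumed to come from the existing model resources graph service, the tag and bundle path are placeholders, and the external_file.pb.h include path is assumed rather than shown in this change.

#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"  // assumed header path

// Sketch: registering and retrieving a model asset bundle through the cache.
absl::Status CacheModelBundle(mediapipe::tasks::core::ModelResourcesCache& cache) {
  auto external_file = std::make_unique<mediapipe::tasks::core::proto::ExternalFile>();
  // Hypothetical bundle path; real callers supply their own asset bundle.
  external_file->set_file_name("/path/to/model_bundle.task");
  ASSIGN_OR_RETURN(auto bundle_resources,
                   mediapipe::tasks::core::ModelAssetBundleResources::Create(
                       "example_bundle_tag", std::move(external_file)));
  // Ownership of the bundle resources is transferred into the cache.
  MP_RETURN_IF_ERROR(cache.AddModelAssetBundleResources(std::move(bundle_resources)));
  if (cache.ModelAssetBundleExists("example_bundle_tag")) {
    ASSIGN_OR_RETURN(const mediapipe::tasks::core::ModelAssetBundleResources* resources,
                     cache.GetModelAssetBundleResources("example_bundle_tag"));
    (void)resources;  // Provides graph authors access to the extracted model files.
  }
  return absl::OkStatus();
}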
diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 90a38747c..47334b673 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" @@ -70,6 +71,17 @@ std::string CreateModelResourcesTag(const CalculatorGraphConfig::Node& node) { node_type); } +std::string CreateModelAssetBundleResourcesTag( + const CalculatorGraphConfig::Node& node) { + std::vector names = absl::StrSplit(node.name(), "__"); + std::string node_type = node.calculator(); + std::replace(node_type.begin(), node_type.end(), '.', '_'); + absl::AsciiStrToLower(&node_type); + return absl::StrFormat("%s_%s_model_asset_bundle_resources", + names.back().empty() ? "unnamed" : names.back(), + node_type); +} + } // namespace // Defines the mediapipe task inference unit as a MediaPipe subgraph that @@ -168,6 +180,38 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( return model_resources_cache_service.GetObject().GetModelResources(tag); } +absl::StatusOr +ModelTaskGraph::CreateModelAssetBundleResources( + SubgraphContext* sc, std::unique_ptr external_file) { + auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); + bool has_file_pointer_meta = external_file->has_file_pointer_meta(); + // if external file is set by file pointer, no need to add the model asset + // bundle resources into the model resources service since the memory is + // not owned by this model asset bundle resources. + if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) { + ASSIGN_OR_RETURN( + local_model_asset_bundle_resources_, + ModelAssetBundleResources::Create("", std::move(external_file))); + if (!has_file_pointer_meta) { + LOG(WARNING) + << "A local ModelResources object is created. Please consider using " + "ModelResourcesCacheService to cache the created ModelResources " + "object in the CalculatorGraph."; + } + return local_model_asset_bundle_resources_.get(); + } + const std::string tag = + CreateModelAssetBundleResourcesTag(sc->OriginalNode()); + ASSIGN_OR_RETURN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(tag, std::move(external_file))); + MP_RETURN_IF_ERROR( + model_resources_cache_service.GetObject().AddModelAssetBundleResources( + std::move(model_bundle_resources))); + return model_resources_cache_service.GetObject().GetModelAssetBundleResources( + tag); +} + GenericNode& ModelTaskGraph::AddInference( const ModelResources& model_resources, const proto::Acceleration& acceleration, Graph& graph) const { diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 36016cb89..5ee70e8f3 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" @@ -78,6 +79,35 @@ class ModelTaskGraph : public Subgraph { absl::StatusOr CreateModelResources( SubgraphContext* sc, std::unique_ptr external_file); + // If the model resources graph service is available, creates a model asset + // bundle resources object from the subgraph context, and caches the created + // model asset bundle resources into the model resources graph service on + // success. Otherwise, creates a local model asset bundle resources object + // that can only be used in the graph construction stage. The returned model + // resources pointer will provide graph authors with the access to extracted + // model files. + template + absl::StatusOr + CreateModelAssetBundleResources(SubgraphContext* sc) { + auto external_file = std::make_unique(); + external_file->Swap(sc->MutableOptions() + ->mutable_base_options() + ->mutable_model_asset()); + return CreateModelAssetBundleResources(sc, std::move(external_file)); + } + + // If the model resources graph service is available, creates a model asset + // bundle resources object from the subgraph context, and caches the created + // model asset bundle resources into the model resources graph service on + // success. Otherwise, creates a local model asset bundle resources object + // that can only be used in the graph construction stage. Note that the + // external file contents will be moved into the model asset bundle resources + // object on creation. The returned model asset bundle resources pointer will + // provide graph authors with the access to extracted model files. + absl::StatusOr + CreateModelAssetBundleResources( + SubgraphContext* sc, std::unique_ptr external_file); + // Inserts a mediapipe task inference subgraph into the provided // GraphBuilder. The returned node provides the following interfaces to the // the rest of the graph: @@ -95,6 +125,9 @@ class ModelTaskGraph : public Subgraph { private: std::unique_ptr local_model_resources_; + + std::unique_ptr + local_model_asset_bundle_resources_; }; } // namespace core diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.cc b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc index 41d710e14..2c09e1961 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_utils.cc +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include + #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" @@ -162,12 +164,16 @@ absl::Status ExtractFilesfromZipFile( return absl::OkStatus(); } -void SetExternalFile(const std::string_view& file_content, - core::proto::ExternalFile* model_file) { - auto pointer = reinterpret_cast(file_content.data()); - - model_file->mutable_file_pointer_meta()->set_pointer(pointer); - model_file->mutable_file_pointer_meta()->set_length(file_content.length()); +void SetExternalFile(const absl::string_view& file_content, + core::proto::ExternalFile* model_file, bool is_copy) { + if (is_copy) { + std::string str_content{file_content}; + model_file->set_file_content(str_content); + } else { + auto pointer = reinterpret_cast(file_content.data()); + model_file->mutable_file_pointer_meta()->set_pointer(pointer); + model_file->mutable_file_pointer_meta()->set_length(file_content.length()); + } } } // namespace metadata diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.h b/mediapipe/tasks/cc/metadata/utils/zip_utils.h index 28708ba6a..10ad0a5a9 100644 --- a/mediapipe/tasks/cc/metadata/utils/zip_utils.h +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.h @@ -35,10 +35,13 @@ absl::Status ExtractFilesfromZipFile( const char* buffer_data, const size_t buffer_size, absl::flat_hash_map* files); -// Set file_pointer_meta in ExternalFile which is the pointer points to location -// of a file in memory by file_content. -void SetExternalFile(const std::string_view& file_content, - core::proto::ExternalFile* model_file); +// Set the ExternalFile object by file_content in memory. By default, +// `is_copy=false` which means to set `file_pointer_meta` in ExternalFile which +// is the pointer points to location of a file in memory. Otherwise, if +// `is_copy=true`, copy the memory into `file_content` in ExternalFile. +void SetExternalFile(const absl::string_view& file_content, + core::proto::ExternalFile* model_file, + bool is_copy = false); } // namespace metadata } // namespace tasks diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD new file mode 100644 index 000000000..a85538631 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -0,0 +1,84 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_classifier_graph", + srcs = ["text_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components:text_preprocessing_graph", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_calculator", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + deps = [ + ":text_classifier_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + +cc_library( + name = "text_classifier_test_utils", + srcs = ["text_classifier_test_utils.cc"], + hdrs = ["text_classifier_test_utils.h"], + visibility = ["//visibility:private"], + deps = [ + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite:mutable_op_resolver", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +) diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/BUILD b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD new file mode 100644 index 000000000..f2b544d87 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "text_classifier_graph_options_proto", + srcs = ["text_classifier_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto new file mode 100644 index 000000000..8f4d7eea6 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -0,0 +1,38 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.text.text_classifier.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.text.textclassifier.proto"; +option java_outer_classname = "TextClassifierGraphOptionsProto"; + +message TextClassifierGraphOptions { + extend mediapipe.CalculatorOptions { + optional TextClassifierGraphOptions ext = 462704549; + } + // Base options for configuring MediaPipe Tasks, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + optional components.processors.proto.ClassifierOptions classifier_options = 2; +} diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc new file mode 100644 index 000000000..699f15bc0 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc @@ -0,0 +1,104 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" +#include "tensorflow/lite/core/api/op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +namespace { + +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + +constexpr char kTextStreamName[] = "text_in"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kClassificationResultStreamName[] = "classification_result_out"; +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// type "TextClassifierGraph". +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kSubgraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); + subgraph.Out(kClassificationResultTag) + .SetName(kClassificationResultStreamName) >> + graph.Out(kClassificationResultTag); + return graph.GetConfig(); +} + +// Converts the user-facing TextClassifierOptions struct to the internal +// TextClassifierGraphOptions proto. 
+std::unique_ptr +ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + auto classifier_options_proto = + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( + &(options->classifier_options))); + options_proto->mutable_classifier_options()->Swap( + classifier_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> TextClassifier::Create( + std::unique_ptr options) { + auto options_proto = ConvertTextClassifierOptionsToProto(options.get()); + return core::TaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver)); +} + +absl::StatusOr TextClassifier::Classify( + absl::string_view text) { + ASSIGN_OR_RETURN( + auto output_packets, + runner_->Process( + {{kTextStreamName, MakePacket(std::string(text))}})); + return output_packets[kClassificationResultStreamName] + .Get(); +} + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h new file mode 100644 index 000000000..b027a9787 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h @@ -0,0 +1,96 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +// The options for configuring a MediaPipe text classifier task. +struct TextClassifierOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + components::processors::ClassifierOptions classifier_options; +}; + +// Performs classification on text. 
+// +// This API expects a TFLite model with (optional) TFLite Model Metadata that +// contains the mandatory (described below) input tensors, output tensor, +// and the optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for +// models with int32 input tensors because it contains the input process unit +// for the model's Tokenizer. No metadata is required for models with string +// input tensors. +// +// Input tensors: +// (kTfLiteInt32) +// - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing +// the input ids, segment ids, and mask ids +// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the +// input ids +// or (kTfLiteString) +// - 1 input tensor that is shapeless or has shape [1] containing the input +// string +// At least one output tensor with: +// (kTfLiteFloat32/kBool) +// - `[1 x N]` array with `N` represents the number of categories. +// - optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS, containing one label per line. The first such +// AssociatedFile (if any) is used to fill the `category_name` field of the +// results. The `display_name` field is filled from the AssociatedFile (if +// any) whose locale matches the `display_names_locale` field of the +// `TextClassifierOptions` used at creation time ("en" by default, i.e. +// English). If none of these are available, only the `index` field of the +// results will be filled. +class TextClassifier : core::BaseTaskApi { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a TextClassifier from the provided `options`. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs classification on the input `text`. + absl::StatusOr Classify( + absl::string_view text); + + // Shuts down the TextClassifier when all the work is done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc new file mode 100644 index 000000000..9706db4d8 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -0,0 +1,162 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::ModelResources; + +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; +constexpr char kTensorsTag[] = "TENSORS"; + +} // namespace + +// A "TextClassifierGraph" performs Natural Language classification (including +// BERT-based text classification). +// - Accepts input text and outputs classification results on CPU. +// +// Inputs: +// TEXT - std::string +// Input text to perform classification on. +// +// Outputs: +// CLASSIFICATION_RESULT - ClassificationResult +// The aggregated classification result object that has 3 dimensions: +// (classification head, classification timestamp, classification category). +// +// Example: +// node { +// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph" +// input_stream: "TEXT:text_in" +// output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// options { +// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// } +// } +// } +class TextClassifierGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN( + const ModelResources* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + Source classification_result_out, + BuildTextClassifierTask( + sc->Options(), *model_resources, + graph[Input(kTextTag)], graph)); + classification_result_out >> + graph[Output(kClassificationResultTag)]; + return graph.GetConfig(); + } + + private: + // Adds a mediapipe TextClassifier task graph into the provided + // builder::Graph instance. The TextClassifier task takes an input + // text (std::string) and returns one classification result per output head + // specified by the model. + // + // task_options: the mediapipe tasks TextClassifierGraphOptions proto. + // model_resources: the ModelResources object initialized from a + // TextClassifier model file with model metadata. 
+ // text_in: (std::string) stream to run text classification on. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr> BuildTextClassifierTask( + const proto::TextClassifierGraphOptions& task_options, + const ModelResources& model_resources, Source text_in, + Graph& graph) { + // Adds preprocessing calculators and connects them to the text input + // stream. + auto& preprocessing = + graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); + MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + model_resources, + preprocessing.GetOptions< + tasks::components::proto::TextPreprocessingGraphOptions>())); + text_in >> preprocessing.In(kTextTag); + + // Adds both InferenceCalculator and ModelResourcesCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + // The metadata extractor side-output comes from the + // ModelResourcesCalculator. + inference.SideOut(kMetadataExtractorTag) >> + preprocessing.SideIn(kMetadataExtractorTag); + preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); + + // Adds postprocessing calculators and connects them to the graph output. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); + inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); + + // Outputs the aggregated classification result as the subgraph output + // stream. + return postprocessing[Output( + kClassificationResultTag)]; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::text::text_classifier::TextClassifierGraph); + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc new file mode 100644 index 000000000..5b33f6606 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { +namespace { + +using ::mediapipe::EqualsProto; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::kMediaPipeTasksPayload; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::testing::HasSubstr; +using ::testing::Optional; +using ::testing::proto::Approximately; +using ::testing::proto::IgnoringRepeatedFieldOrdering; +using ::testing::proto::Partially; + +constexpr float kEpsilon = 0.001; +constexpr int kMaxSeqLen = 128; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; +constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; +constexpr char kTestRegexModelPath[] = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +constexpr char kStringToBoolModelPath[] = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +class TextClassifierTest : public tflite_shims::testing::Test {}; + +TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK(TextClassifier::Create(std::move(options))); +} + +TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) { + auto options = std::make_unique(); + StatusOr> classifier = + TextClassifier::Create(std::move(options)); + + EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + classifier.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(TextClassifierTest, CreateFailsWithMissingModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kInvalidModelPath); + StatusOr> classifier = + TextClassifier::Create(std::move(options)); + + EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound); + EXPECT_THAT(classifier.status().message(), + HasSubstr("Unable to open file at")); + EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { + auto options = std::make_unique(); + 
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); + MP_ASSERT_OK(TextClassifier::Create(std::move(options))); +} + +TEST_F(TextClassifierTest, TextClassifierWithBert) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + ClassificationResult negative_result, + classifier->Classify("unflinchingly bleak and desperate")); + ASSERT_THAT(negative_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "negative" score: 0.956 } + categories { category_name: "positive" score: 0.044 } + } + } + )pb"), + kEpsilon)))); + + MP_ASSERT_OK_AND_ASSIGN( + ClassificationResult positive_result, + classifier->Classify("it's a charming and often affecting journey")); + ASSERT_THAT(positive_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "negative" score: 0.0 } + categories { category_name: "positive" score: 1.0 } + } + } + )pb"), + kEpsilon)))); + MP_ASSERT_OK(classifier->Close()); +} + +TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result, + classifier->Classify("What a waste of my time.")); + ASSERT_THAT(negative_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "Negative" score: 0.813 } + categories { category_name: "Positive" score: 0.187 } + } + } + )pb"), + kEpsilon)))); + + MP_ASSERT_OK_AND_ASSIGN( + ClassificationResult positive_result, + classifier->Classify("This is the best movie I’ve seen in recent years. 
" + "Strongly recommend it!")); + ASSERT_THAT(positive_result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "Negative" score: 0.487 } + categories { category_name: "Positive" score: 0.513 } + } + } + )pb"), + kEpsilon)))); + MP_ASSERT_OK(classifier->Close()); +} + +TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); + options->base_options.op_resolver = CreateCustomResolver(); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, + classifier->Classify("hello")); + ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb( + classifications { + entries { + categories { index: 1 score: 1 } + categories { index: 0 score: 1 } + categories { index: 2 score: 0 } + } + } + )pb")))); +} + +TEST_F(TextClassifierTest, BertLongPositive) { + std::stringstream ss_for_positive_review; + ss_for_positive_review + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kMaxSeqLen; ++i) { + ss_for_positive_review << " long"; + } + ss_for_positive_review << " movie review"; + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result, + classifier->Classify(ss_for_positive_review.str())); + ASSERT_THAT(result, + Partially(IgnoringRepeatedFieldOrdering(Approximately( + EqualsProto(R"pb( + classifications { + entries { + categories { category_name: "negative" score: 0.014 } + categories { category_name: "positive" score: 0.986 } + } + } + )pb"), + kEpsilon)))); + MP_ASSERT_OK(classifier->Close()); +} + +} // namespace +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc new file mode 100644 index 000000000..d12370372 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc @@ -0,0 +1,131 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" +#include "tensorflow/lite/string_util.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace { + +using ::mediapipe::tasks::CreateStatusWithPayload; +using ::tflite::GetInput; +using ::tflite::GetOutput; +using ::tflite::GetString; +using ::tflite::StringRef; + +constexpr absl::string_view kInputStr = "hello"; +constexpr bool kBooleanData[] = {true, true, false}; +constexpr size_t kBooleanDataSize = std::size(kBooleanData); + +// Checks and returns type of a tensor, fails if tensor type is not T. +template +absl::StatusOr AssertAndReturnTypedTensor(const TfLiteTensor* tensor) { + if (!tensor->data.raw) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Tensor (%s) has no raw data.", tensor->name)); + } + + // Checks if data type of tensor is T and returns the pointer casted to T if + // applicable, returns nullptr if tensor type is not T. + // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. + if (tensor->type == tflite::typeToTfLiteType()) { + return reinterpret_cast(tensor->data.raw); + } + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.", + tensor->name, tflite::typeToTfLiteType(), + tensor->bytes)); +} + +// Populates tensor with array of data, fails if data type doesn't match tensor +// type or they don't have the same number of elements. 
+template <typename T> +absl::Status PopulateTensor(const T* data, int num_elements, + TfLiteTensor* tensor) { + ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor<T>(tensor)); + size_t bytes = num_elements * sizeof(T); + if (tensor->bytes != bytes) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes, + bytes)); + } + std::memcpy(v, data, bytes); + return absl::OkStatus(); +} + +TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteIntArray* dims = TfLiteIntArrayCreate(1); + dims->data[0] = kBooleanDataSize; + return context->ResizeTensor(context, output, dims); +} + +TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE(context, PopulateTensor(kBooleanData, 3, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a bool tensor with +// value {true, true, false}. It is used to mimic a real text classification +// model that classifies a string into scores of different categories. +TfLiteRegistration* RegisterStringToBool() { + // Dummy implementation of the custom op. + // This op takes a string as input and outputs bool[]. + static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr, + /* prepare= */ PrepareStringToBool, + /* invoke= */ InvokeStringToBool}; + return &r; +} +} // namespace + +std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver() { + tflite::MutableOpResolver resolver; + resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool()); + return std::make_unique<tflite::MutableOpResolver>(resolver); +} + +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h new file mode 100644 index 000000000..a427b561c --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ + +#include <memory> + +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace text { + +// Creates a custom MutableOpResolver to provide custom OP implementations to +// mimic classification behavior.
+std::unique_ptr CreateCustomResolver(); + +} // namespace text +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD index 985c25cfb..e5b1f0479 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD @@ -56,6 +56,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", @@ -91,6 +92,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:model_task_graph", @@ -123,6 +125,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components/containers:gesture_recognition_result", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index a6de4f950..08f7f45d0 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -69,6 +69,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -86,6 +87,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index b70689eaf..277bb170a 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. 
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" @@ -38,6 +40,7 @@ namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; constexpr int kFeaturesPerLandmark = 3; @@ -62,6 +65,25 @@ absl::StatusOr NormalizeLandmarkAspectRatio( return normalized_landmarks; } +template +absl::StatusOr RotateLandmarks(const LandmarkListT& landmarks, + float rotation) { + float cos = std::cos(rotation); + // Negate because Y-axis points down and not up. + float sin = std::sin(-rotation); + LandmarkListT rotated_landmarks; + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const auto& old_landmark = landmarks.landmark(i); + float x = old_landmark.x() - 0.5; + float y = old_landmark.y() - 0.5; + auto* new_landmark = rotated_landmarks.add_landmark(); + new_landmark->set_x(x * cos - y * sin + 0.5); + new_landmark->set_y(y * cos + x * sin + 0.5); + new_landmark->set_z(old_landmark.z()); + } + return rotated_landmarks; +} + template absl::StatusOr NormalizeObject(const LandmarkListT& landmarks, int origin_offset) { @@ -134,6 +156,13 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { NormalizeLandmarkAspectRatio(landmarks, width, height)); } + if (cc->Inputs().HasTag(kNormRectTag)) { + RET_CHECK(!cc->Inputs().Tag(kNormRectTag).IsEmpty()); + const auto rotation = + cc->Inputs().Tag(kNormRectTag).Get().rotation(); + ASSIGN_OR_RETURN(landmarks, RotateLandmarks(landmarks, rotation)); + } + const auto& options = cc->Options(); if (options.object_normalization()) { ASSIGN_OR_RETURN( @@ -163,6 +192,8 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { // WORLD_LANDMARKS - World 3d landmarks of one object. Use *either* // LANDMARKS or WORLD_LANDMARKS. // IMAGE_SIZE - (width, height) of the image +// NORM_RECT - Optional NormalizedRect object whose 'rotation' field is used +// to rotate the landmarks. // Output: // LANDMARKS_MATRIX - Matrix for the landmarks. // @@ -185,6 +216,7 @@ class LandmarksToMatrixCalculator : public CalculatorBase { cc->Inputs().Tag(kLandmarksTag).Set().Optional(); cc->Inputs().Tag(kWorldLandmarksTag).Set().Optional(); cc->Inputs().Tag(kImageSizeTag).Set>().Optional(); + cc->Inputs().Tag(kNormRectTag).Set().Optional(); cc->Outputs().Tag(kLandmarksMatrixTag).Set(); return absl::OkStatus(); } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index 8a68d8dae..fe6f1162b 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. 
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" @@ -35,6 +37,7 @@ constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; +constexpr char kNormRectTag[] = "NORM_RECT"; template LandmarkListT BuildPseudoLandmarks(int num_landmarks, int offset = 0) { @@ -54,6 +57,7 @@ struct Landmarks2dToMatrixCalculatorTestCase { int object_normalization_origin_offset = -1; float expected_cell_0_2; float expected_cell_1_5; + float rotation; }; using Landmarks2dToMatrixCalculatorTest = @@ -68,6 +72,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { calculator: "LandmarksToMatrixCalculator" input_stream: "LANDMARKS:landmarks" input_stream: "IMAGE_SIZE:image_size" + input_stream: "NORM_RECT:norm_rect" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { @@ -91,6 +96,11 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { runner.MutableInputs() ->Tag(kImageSizeTag) .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); + auto norm_rect = std::make_unique(); + norm_rect->set_rotation(test_case.rotation); + runner.MutableInputs() + ->Tag(kNormRectTag) + .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -109,12 +119,20 @@ INSTANTIATE_TEST_CASE_P( .base_offset = 0, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.1875f}, + .expected_cell_1_5 = 0.1875f, + .rotation = 0}, {.test_name = "TestWithOffset21", .base_offset = 21, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.1875f}}), + .expected_cell_1_5 = 0.1875f, + .rotation = 0}, + {.test_name = "TestWithRotation", + .base_offset = 0, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.075f, + .expected_cell_1_5 = -0.25f, + .rotation = M_PI / 2.0}}), [](const testing::TestParamInfo< Landmarks2dToMatrixCalculatorTest::ParamType>& info) { return info.param.test_name; @@ -126,6 +144,7 @@ struct LandmarksWorld3dToMatrixCalculatorTestCase { int object_normalization_origin_offset = -1; float expected_cell_0_2; float expected_cell_1_5; + float rotation; }; using LandmarksWorld3dToMatrixCalculatorTest = @@ -140,6 +159,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { calculator: "LandmarksToMatrixCalculator" input_stream: "WORLD_LANDMARKS:landmarks" input_stream: "IMAGE_SIZE:image_size" + input_stream: "NORM_RECT:norm_rect" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { @@ -162,6 +182,11 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { runner.MutableInputs() ->Tag(kImageSizeTag) .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); + auto norm_rect = std::make_unique(); + norm_rect->set_rotation(test_case.rotation); + runner.MutableInputs() + ->Tag(kNormRectTag) + .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -180,17 +205,26 
@@ INSTANTIATE_TEST_CASE_P( .base_offset = 0, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.25}, + .expected_cell_1_5 = 0.25, + .rotation = 0}, {.test_name = "TestWithOffset21", .base_offset = 21, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.25}, + .expected_cell_1_5 = 0.25, + .rotation = 0}, {.test_name = "NoObjectNormalization", .base_offset = 0, .object_normalization_origin_offset = -1, .expected_cell_0_2 = 0.021f, - .expected_cell_1_5 = 0.052f}}), + .expected_cell_1_5 = 0.052f, + .rotation = 0}, + {.test_name = "TestWithRotation", + .base_offset = 0, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.1f, + .expected_cell_1_5 = -0.25f, + .rotation = M_PI / 2.0}}), [](const testing::TestParamInfo< LandmarksWorld3dToMatrixCalculatorTest::ParamType>& info) { return info.param.test_name; diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc index e0d1473c2..333edb6fb 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include "absl/memory/memory.h" @@ -27,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" @@ -62,6 +64,8 @@ constexpr char kHandGestureSubgraphTypeName[] = constexpr char kImageTag[] = "IMAGE"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kHandGesturesStreamName[] = "hand_gestures"; constexpr char kHandednessTag[] = "HANDEDNESS"; @@ -72,6 +76,31 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks"; constexpr int kMicroSecondsPerMilliSecond = 1000; +// Returns a NormalizedRect filling the whole image. If input is present, its +// rotation is set in the returned NormalizedRect and a check is performed to +// make sure no region-of-interest was provided. Otherwise, rotation is set to +// 0. +absl::StatusOr FillNormalizedRect( + std::optional normalized_rect) { + NormalizedRect result; + if (normalized_rect.has_value()) { + result = *normalized_rect; + } + bool has_coordinates = result.has_x_center() || result.has_y_center() || + result.has_width() || result.has_height(); + if (has_coordinates) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GestureRecognizer does not support region-of-interest.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + result.set_x_center(0.5); + result.set_y_center(0.5); + result.set_width(1); + result.set_height(1); + return result; +} + // Creates a MediaPipe graph config that contains a subgraph node of // "mediapipe.tasks.vision.GestureRecognizerGraph". 
If the task is running // in the live stream mode, a "FlowLimiterCalculator" will be added to limit the @@ -83,6 +112,7 @@ CalculatorGraphConfig CreateGraphConfig( auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName); subgraph.GetOptions().Swap(options.get()); graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >> graph.Out(kHandGesturesTag); subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >> @@ -93,10 +123,11 @@ CalculatorGraphConfig CreateGraphConfig( graph.Out(kHandWorldLandmarksTag); subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator(graph, subgraph, {kImageTag}, - kHandGesturesTag); + return tasks::core::AddFlowLimiterCalculator( + graph, subgraph, {kImageTag, kNormRectTag}, kHandGesturesTag); } graph.In(kImageTag) >> subgraph.In(kImageTag); + graph.In(kNormRectTag) >> subgraph.In(kNormRectTag); return graph.GetConfig(); } @@ -216,16 +247,22 @@ absl::StatusOr> GestureRecognizer::Create( } absl::StatusOr GestureRecognizer::Recognize( - mediapipe::Image image) { + mediapipe::Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - ASSIGN_OR_RETURN(auto output_packets, - ProcessImageData({{kImageInStreamName, - MakePacket(std::move(image))}})); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + FillNormalizedRect(image_processing_options)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); if (output_packets[kHandGesturesStreamName].IsEmpty()) { return {{{}, {}, {}, {}}}; } @@ -245,18 +282,24 @@ absl::StatusOr GestureRecognizer::Recognize( } absl::StatusOr GestureRecognizer::RecognizeForVideo( - mediapipe::Image image, int64 timestamp_ms) { + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + FillNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); if (output_packets[kHandGesturesStreamName].IsEmpty()) { return {{{}, {}, {}, {}}}; @@ -276,17 +319,23 @@ absl::StatusOr GestureRecognizer::RecognizeForVideo( }; } -absl::Status GestureRecognizer::RecognizeAsync(mediapipe::Image image, - int64 timestamp_ms) { +absl::Status GestureRecognizer::RecognizeAsync( + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + FillNormalizedRect(image_processing_options)); return SendLiveStreamData( 
{{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h index 53b824e25..750a99797 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -17,11 +17,13 @@ limitations under the License. #define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ #include +#include #include "absl/status/statusor.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" @@ -57,7 +59,7 @@ struct GestureRecognizerOptions { int num_hands = 1; // The minimum confidence score for the hand detection to be considered - // successfully. + // successful. float min_hand_detection_confidence = 0.5; // The minimum confidence score of hand presence score in the hand landmark @@ -65,11 +67,11 @@ struct GestureRecognizerOptions { float min_hand_presence_confidence = 0.5; // The minimum confidence score for the hand tracking to be considered - // successfully. + // successful. float min_tracking_confidence = 0.5; // The minimum confidence score for the gestures to be considered - // successfully. If < 0, the gesture confidence thresholds in the model + // successful. If < 0, the gesture confidence thresholds in the model // metadata are used. // TODO Note this option is subject to change, after scoring // merging calculator is implemented. @@ -93,6 +95,13 @@ struct GestureRecognizerOptions { // Inputs: // Image // - The image that gesture recognition runs on. +// std::optional +// - If provided, can be used to specify the rotation to apply to the image +// before performing gesture recognition, by setting its 'rotation' field +// in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note +// that specifying a region-of-interest using the 'x_center', 'y_center', +// 'width' and 'height' fields is NOT supported and will result in an +// invalid argument error being returned. // Outputs: // GestureRecognitionResult // - The hand gesture recognition results. @@ -122,12 +131,23 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // // image - mediapipe::Image // Image to perform hand gesture recognition on. + // imageProcessingOptions - std::optional + // If provided, can be used to specify the rotation to apply to the image + // before performing classification, by setting its 'rotation' field in + // radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that + // specifying a region-of-interest using the 'x_center', 'y_center', 'width' + // and 'height' fields is NOT supported and will result in an invalid + // argument error being returned. // // The image can be of any size with format RGB or RGBA. // TODO: Describes how the input image will be preprocessed // after the yuv support is implemented. + // TODO: use an ImageProcessingOptions struct instead of + // NormalizedRect. 
absl::StatusOr Recognize( - Image image); + Image image, + std::optional image_processing_options = + std::nullopt); // Performs gesture recognition on the provided video frame. // Only use this method when the GestureRecognizer is created with the video @@ -137,7 +157,9 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. absl::StatusOr - RecognizeForVideo(Image image, int64 timestamp_ms); + RecognizeForVideo(Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Sends live image data to perform gesture recognition, and the results will // be available via the "result_callback" provided in the @@ -157,7 +179,9 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status RecognizeAsync(Image image, int64 timestamp_ms); + absl::Status RecognizeAsync(Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the GestureRecognizer when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index b4f2af4d8..e02eadde8 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/utils.h" @@ -53,6 +54,7 @@ using ::mediapipe::tasks::vision::hand_landmarker::proto:: HandLandmarkerGraphOptions; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kHandednessTag[] = "HANDEDNESS"; @@ -76,6 +78,9 @@ struct GestureRecognizerOutputs { // Inputs: // IMAGE - Image // Image to perform hand gesture recognition on. +// NORM_RECT - NormalizedRect +// Describes image rotation and region of image to perform landmarks +// detection on. // // Outputs: // HAND_GESTURES - std::vector @@ -93,13 +98,15 @@ struct GestureRecognizerOutputs { // IMAGE - mediapipe::Image // The image that gesture recognizer runs on and has the pixel data stored // on the target storage (CPU vs GPU). -// +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. 
// // Example: // node { // calculator: // "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" // input_stream: "IMAGE:image_in" +// input_stream: "NORM_RECT:norm_rect" // output_stream: "HAND_GESTURES:hand_gestures" // output_stream: "LANDMARKS:hand_landmarks" // output_stream: "WORLD_LANDMARKS:world_hand_landmarks" @@ -132,7 +139,8 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, BuildGestureRecognizerGraph( *sc->MutableOptions(), - graph[Input(kImageTag)], graph)); + graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); hand_gesture_recognition_output.gesture >> graph[Output>(kHandGesturesTag)]; hand_gesture_recognition_output.handedness >> @@ -148,7 +156,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { private: absl::StatusOr BuildGestureRecognizerGraph( GestureRecognizerGraphOptions& graph_options, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { auto& image_property = graph.AddNode("ImagePropertiesCalculator"); image_in >> image_property.In("IMAGE"); auto image_size = image_property.Out("SIZE"); @@ -162,6 +170,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { graph_options.mutable_hand_landmarker_graph_options()); image_in >> hand_landmarker_graph.In(kImageTag); + norm_rect_in >> hand_landmarker_graph.In(kNormRectTag); auto hand_landmarks = hand_landmarker_graph[Output>( kLandmarksTag)]; @@ -187,6 +196,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph { hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag); handedness >> hand_gesture_subgraph.In(kHandednessTag); image_size >> hand_gesture_subgraph.In(kImageSizeTag); + norm_rect_in >> hand_gesture_subgraph.In(kNormRectTag); hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag); auto hand_gestures = hand_gesture_subgraph[Output>( diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 8d7e0bc07..4bbe94974 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" @@ -57,6 +58,7 @@ constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; constexpr char kHandGesturesTag[] = "HAND_GESTURES"; constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; @@ -92,6 +94,9 @@ Source> ConvertMatrixToTensor(Source matrix, // Detected hand landmarks in world coordinates. // IMAGE_SIZE - std::pair // The size of image from which the landmarks detected from. +// NORM_RECT - NormalizedRect +// NormalizedRect whose 'rotation' field is used to rotate the +// landmarks before processing them. 
// // Outputs: // HAND_GESTURES - ClassificationList @@ -106,6 +111,7 @@ Source> ConvertMatrixToTensor(Source matrix, // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" // input_stream: "IMAGE_SIZE:image_size" +// input_stream: "NORM_RECT:norm_rect" // output_stream: "HAND_GESTURES:hand_gestures" // options { // [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext] @@ -133,7 +139,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { graph[Input(kHandednessTag)], graph[Input(kLandmarksTag)], graph[Input(kWorldLandmarksTag)], - graph[Input>(kImageSizeTag)], graph)); + graph[Input>(kImageSizeTag)], + graph[Input(kNormRectTag)], graph)); hand_gestures >> graph[Output(kHandGesturesTag)]; return graph.GetConfig(); } @@ -145,7 +152,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { Source handedness, Source hand_landmarks, Source hand_world_landmarks, - Source> image_size, Graph& graph) { + Source> image_size, Source norm_rect, + Graph& graph) { // Converts the ClassificationList to a matrix. auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator"); handedness >> handedness_to_matrix.In(kHandednessTag); @@ -166,6 +174,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { landmarks_options; hand_landmarks >> hand_landmarks_to_matrix.In(kLandmarksTag); image_size >> hand_landmarks_to_matrix.In(kImageSizeTag); + norm_rect >> hand_landmarks_to_matrix.In(kNormRectTag); auto hand_landmarks_matrix = hand_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; @@ -181,6 +190,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { hand_world_landmarks >> hand_world_landmarks_to_matrix.In(kWorldLandmarksTag); image_size >> hand_world_landmarks_to_matrix.In(kImageSizeTag); + norm_rect >> hand_world_landmarks_to_matrix.In(kNormRectTag); auto hand_world_landmarks_matrix = hand_world_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; @@ -239,6 +249,9 @@ REGISTER_MEDIAPIPE_GRAPH( // A vector hand landmarks in world coordinates. // IMAGE_SIZE - std::pair // The size of image from which the landmarks detected from. +// NORM_RECT - NormalizedRect +// NormalizedRect whose 'rotation' field is used to rotate the +// landmarks before processing them. // HAND_TRACKING_IDS - std::vector // A vector of the tracking ids of the hands. The tracking id is the vector // index corresponding to the same hand if the graph runs multiple times. 
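The NORM_RECT stream threaded through these gesture recognizer graphs contributes only its 'rotation' field at this stage: the landmarks are rotated about the image center by RotateLandmarks() in landmarks_to_matrix_calculator.cc earlier in this patch. A minimal standalone sketch of that same math, with a hypothetical Point2D struct standing in for the NormalizedLandmarkList proto:

#include <cmath>
#include <vector>

struct Point2D { float x; float y; };

// Rotates normalized points about the image center (0.5, 0.5). The sine is
// negated because the Y axis of normalized image coordinates points down.
std::vector<Point2D> RotateAboutImageCenter(const std::vector<Point2D>& points,
                                            float rotation_radians) {
  const float cos_a = std::cos(rotation_radians);
  const float sin_a = std::sin(-rotation_radians);
  std::vector<Point2D> rotated;
  rotated.reserve(points.size());
  for (const Point2D& p : points) {
    const float x = p.x - 0.5f;
    const float y = p.y - 0.5f;
    rotated.push_back({x * cos_a - y * sin_a + 0.5f,
                       y * cos_a + x * sin_a + 0.5f});
  }
  return rotated;
}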
@@ -257,6 +270,7 @@ REGISTER_MEDIAPIPE_GRAPH( // input_stream: "LANDMARKS:landmarks" // input_stream: "WORLD_LANDMARKS:world_landmarks" // input_stream: "IMAGE_SIZE:image_size" +// input_stream: "NORM_RECT:norm_rect" // input_stream: "HAND_TRACKING_IDS:hand_tracking_ids" // output_stream: "HAND_GESTURES:hand_gestures" // options { @@ -283,6 +297,7 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { graph[Input>(kLandmarksTag)], graph[Input>(kWorldLandmarksTag)], graph[Input>(kImageSizeTag)], + graph[Input(kNormRectTag)], graph[Input>(kHandTrackingIdsTag)], graph)); multi_hand_gestures >> graph[Output>(kHandGesturesTag)]; @@ -296,18 +311,20 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { Source> multi_handedness, Source> multi_hand_landmarks, Source> multi_hand_world_landmarks, - Source> image_size, + Source> image_size, Source norm_rect, Source> multi_hand_tracking_ids, Graph& graph) { auto& begin_loop_int = graph.AddNode("BeginLoopIntCalculator"); image_size >> begin_loop_int.In(kCloneTag)[0]; - multi_handedness >> begin_loop_int.In(kCloneTag)[1]; - multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[2]; - multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[3]; + norm_rect >> begin_loop_int.In(kCloneTag)[1]; + multi_handedness >> begin_loop_int.In(kCloneTag)[2]; + multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[3]; + multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[4]; multi_hand_tracking_ids >> begin_loop_int.In(kIterableTag); auto image_size_clone = begin_loop_int.Out(kCloneTag)[0]; - auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[1]; - auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[2]; - auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[3]; + auto norm_rect_clone = begin_loop_int.Out(kCloneTag)[1]; + auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[2]; + auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[3]; + auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[4]; auto hand_tracking_id = begin_loop_int.Out(kItemTag); auto batch_end = begin_loop_int.Out(kBatchEndTag); @@ -341,6 +358,7 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { hand_world_landmarks >> hand_gesture_recognizer_graph.In(kWorldLandmarksTag); image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag); + norm_rect_clone >> hand_gesture_recognizer_graph.In(kNormRectTag); auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag); auto& end_loop_classification_lists = diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD index 433a30471..71cef6270 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD @@ -32,7 +32,7 @@ cc_library( "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", "//mediapipe/calculators/util:detection_label_id_to_text_calculator", "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto", - "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", "//mediapipe/calculators/util:non_max_suppression_calculator", diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc 
b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc index 8573d718f..e876d7d09 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc @@ -58,6 +58,7 @@ using ::mediapipe::tasks::vision::hand_detector::proto:: HandDetectorGraphOptions; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kHandRectsTag[] = "HAND_RECTS"; constexpr char kPalmRectsTag[] = "PALM_RECTS"; @@ -148,6 +149,9 @@ void ConfigureRectTransformationCalculator( // Inputs: // IMAGE - Image // Image to perform detection on. +// NORM_RECT - NormalizedRect +// Describes image rotation and region of image to perform detection +// on. // // Outputs: // PALM_DETECTIONS - std::vector @@ -159,11 +163,14 @@ void ConfigureRectTransformationCalculator( // IMAGE - Image // The input image that the hand detector runs on and has the pixel data // stored on the target storage (CPU vs GPU). +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. // // Example: // node { // calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" // input_stream: "IMAGE:image" +// input_stream: "NORM_RECT:norm_rect" // output_stream: "PALM_DETECTIONS:palm_detections" // output_stream: "HAND_RECTS:hand_rects_from_palm_detections" // output_stream: "PALM_RECTS:palm_rects" @@ -189,11 +196,11 @@ class HandDetectorGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN( - auto hand_detection_outs, - BuildHandDetectionSubgraph(sc->Options(), - *model_resources, - graph[Input(kImageTag)], graph)); + ASSIGN_OR_RETURN(auto hand_detection_outs, + BuildHandDetectionSubgraph( + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); hand_detection_outs.palm_detections >> graph[Output>(kPalmDetectionsTag)]; hand_detection_outs.hand_rects >> @@ -216,7 +223,7 @@ class HandDetectorGraph : public core::ModelTaskGraph { absl::StatusOr BuildHandDetectionSubgraph( const HandDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. auto& preprocessing = @@ -233,8 +240,9 @@ class HandDetectorGraph : public core::ModelTaskGraph { &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); + norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); - auto letterbox_padding = preprocessing.Out("LETTERBOX_PADDING"); + auto matrix = preprocessing.Out("MATRIX"); auto image_size = preprocessing.Out("IMAGE_SIZE"); // Adds SSD palm detection model. @@ -278,17 +286,12 @@ class HandDetectorGraph : public core::ModelTaskGraph { nms_detections >> detection_label_id_to_text.In(""); auto detections_with_text = detection_label_id_to_text.Out(""); - // Adjusts detection locations (already normalized to [0.f, 1.f]) on the - // letterboxed image (after image transformation with the FIT scale mode) to - // the corresponding locations on the same image with the letterbox removed - // (the input image to the graph before image transformation). 
- auto& detection_letterbox_removal = - graph.AddNode("DetectionLetterboxRemovalCalculator"); - detections_with_text >> detection_letterbox_removal.In("DETECTIONS"); - letterbox_padding >> detection_letterbox_removal.In("LETTERBOX_PADDING"); + // Projects detections back into the input image coordinates system. + auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); + detections_with_text >> detection_projection.In("DETECTIONS"); + matrix >> detection_projection.In("PROJECTION_MATRIX"); auto palm_detections = - detection_letterbox_removal[Output>( - "DETECTIONS")]; + detection_projection[Output>("DETECTIONS")]; // Converts each palm detection into a rectangle (normalized by image size) // that encloses the palm and is rotated such that the line connecting diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index 11cfc3026..cbbc0e193 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -75,13 +76,18 @@ using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite"; constexpr char kTestRightHandsImage[] = "right_hands.jpg"; +constexpr char kTestRightHandsRotatedImage[] = "right_hands_rotated.jpg"; constexpr char kTestModelResourcesTag[] = "test_model_resources"; constexpr char kOneHandResultFile[] = "hand_detector_result_one_hand.pbtxt"; +constexpr char kOneHandRotatedResultFile[] = + "hand_detector_result_one_hand_rotated.pbtxt"; constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageName[] = "image"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kPalmDetectionsName[] = "palm_detections"; constexpr char kHandRectsTag[] = "HAND_RECTS"; @@ -117,6 +123,8 @@ absl::StatusOr> CreateTaskRunner( graph[Input(kImageTag)].SetName(kImageName) >> hand_detection.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + hand_detection.In(kNormRectTag); hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >> graph[Output>(kPalmDetectionsTag)]; @@ -142,6 +150,9 @@ struct TestParams { std::string hand_detection_model_name; // The filename of test image. std::string test_image_name; + // The rotation to apply to the test image before processing, in radians + // counter-clockwise. + float rotation; // The number of maximum detected hands. int num_hands; // The expected hand detector result. 
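The 'rotation' field added to TestParams above is given in radians (counter-clockwise) and reaches the graph through a NormalizedRect that spans the whole image, as the test body in the next hunk shows. A small helper capturing that setup could look like the sketch below; the helper name is illustrative, not part of the test:

#include <cmath>

#include "mediapipe/framework/formats/rect.pb.h"

// Builds a full-image NormalizedRect that only carries a rotation,
// e.g. MakeFullImageNormRect(M_PI / 2.0f) for a 90° counter-clockwise rotation.
mediapipe::NormalizedRect MakeFullImageNormRect(float rotation_radians) {
  mediapipe::NormalizedRect rect;
  rect.set_x_center(0.5f);
  rect.set_y_center(0.5f);
  rect.set_width(1.0f);
  rect.set_height(1.0f);
  rect.set_rotation(rotation_radians);
  return rect;
}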
@@ -154,14 +165,22 @@ TEST_P(HandDetectionTest, DetectTwoHands) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, GetParam().test_image_name))); + NormalizedRect input_norm_rect; + input_norm_rect.set_rotation(GetParam().rotation); + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(GetParam().hand_detection_model_name)); MP_ASSERT_OK_AND_ASSIGN( auto task_runner, CreateTaskRunner(*model_resources, kPalmDetectionModel, GetParam().num_hands)); - auto output_packets = - task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); MP_ASSERT_OK(output_packets); const std::vector& palm_detections = (*output_packets)[kPalmDetectionsName].Get>(); @@ -188,15 +207,24 @@ INSTANTIATE_TEST_SUITE_P( Values(TestParams{.test_name = "DetectOneHand", .hand_detection_model_name = kPalmDetectionModel, .test_image_name = kTestRightHandsImage, + .rotation = 0, .num_hands = 1, .expected_result = GetExpectedHandDetectorResult(kOneHandResultFile)}, TestParams{.test_name = "DetectTwoHands", .hand_detection_model_name = kPalmDetectionModel, .test_image_name = kTestRightHandsImage, + .rotation = 0, .num_hands = 2, .expected_result = - GetExpectedHandDetectorResult(kTwoHandsResultFile)}), + GetExpectedHandDetectorResult(kTwoHandsResultFile)}, + TestParams{.test_name = "DetectOneHandWithRotation", + .hand_detection_model_name = kPalmDetectionModel, + .test_image_name = kTestRightHandsRotatedImage, + .rotation = M_PI / 2.0f, + .num_hands = 1, + .expected_result = GetExpectedHandDetectorResult( + kOneHandRotatedResultFile)}), [](const TestParamInfo& info) { return info.param.test_name; }); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index e8a832bbc..9090fc7b3 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -91,10 +91,14 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD index 3b82153eb..f45681fb3 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -57,7 +57,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components/containers:landmarks_detection", + 
"//mediapipe/tasks/cc/components/containers:rect", "//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder", "//mediapipe/tasks/cc/vision/utils:landmarks_utils", "@com_google_absl//absl/algorithm:container", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc index 8920ea0cb..5a5baa50e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -34,7 +34,7 @@ limitations under the License. #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h" #include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" @@ -44,7 +44,7 @@ namespace { using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::containers::Bound; +using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::vision::utils::CalculateIOU; using ::mediapipe::tasks::vision::utils::DuplicatesFinder; @@ -126,7 +126,7 @@ absl::StatusOr HandBaselineDistance( return distance; } -Bound CalculateBound(const NormalizedLandmarkList& list) { +Rect CalculateBound(const NormalizedLandmarkList& list) { constexpr float kMinInitialValue = std::numeric_limits::max(); constexpr float kMaxInitialValue = std::numeric_limits::lowest(); @@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder { const int num = multi_landmarks.size(); std::vector baseline_distances; baseline_distances.reserve(num); - std::vector bounds; + std::vector bounds; bounds.reserve(num); for (const NormalizedLandmarkList& list : multi_landmarks) { ASSIGN_OR_RETURN(const float baseline_distance, diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index ab5a453c5..3fbe38c1c 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -29,10 +29,14 @@ limitations under the License. 
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" @@ -50,6 +54,8 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::utils::DisallowIf; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; using ::mediapipe::tasks::vision::hand_detector::proto:: HandDetectorGraphOptions; using ::mediapipe::tasks::vision::hand_landmarker::proto:: @@ -58,6 +64,7 @@ using ::mediapipe::tasks::vision::hand_landmarker::proto:: HandLandmarksDetectorGraphOptions; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; @@ -65,6 +72,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kPalmRectsTag[] = "PALM_RECTS"; constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator"; +constexpr char kHandDetectorTFLiteName[] = "hand_detector.tflite"; +constexpr char kHandLandmarksDetectorTFLiteName[] = + "hand_landmarks_detector.tflite"; struct HandLandmarkerOutputs { Source> landmark_lists; @@ -76,6 +86,27 @@ struct HandLandmarkerOutputs { Source image; }; +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + HandLandmarkerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto hand_detector_file, + resources.GetModelFile(kHandDetectorTFLiteName)); + SetExternalFile(hand_detector_file, + options->mutable_hand_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset(), + is_copy); + ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, + resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + SetExternalFile(hand_landmarks_detector_file, + options->mutable_hand_landmarks_detector_graph_options() + ->mutable_base_options() + ->mutable_model_asset(), + is_copy); + return absl::OkStatus(); +} + } // namespace // A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand @@ -92,6 +123,9 @@ struct HandLandmarkerOutputs { // Inputs: // IMAGE - Image // Image to perform hand landmarks detection on. +// NORM_RECT - NormalizedRect +// Describes image rotation and region of image to perform landmarks +// detection on. // // Outputs: // LANDMARKS: - std::vector @@ -110,11 +144,14 @@ struct HandLandmarkerOutputs { // IMAGE - Image // The input image that the hand landmarker runs on and has the pixel data // stored on the target storage (CPU vs GPU). 
+// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. // // Example: // node { // calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" // input_stream: "IMAGE:image_in" +// input_stream: "NORM_RECT:norm_rect" // output_stream: "LANDMARKS:hand_landmarks" // output_stream: "WORLD_LANDMARKS:world_hand_landmarks" // output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" @@ -154,10 +191,25 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; - ASSIGN_OR_RETURN( - auto hand_landmarker_outputs, - BuildHandLandmarkerGraph(sc->Options(), - graph[Input(kImageTag)], graph)); + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // Copies the file content instead of passing the pointer of file in + // memory if the subgraph model resource service is not available. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN(auto hand_landmarker_outputs, + BuildHandLandmarkerGraph( + sc->Options(), + graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); hand_landmarker_outputs.landmark_lists >> graph[Output>(kLandmarksTag)]; hand_landmarker_outputs.world_landmark_lists >> @@ -196,7 +248,7 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { // graph: the mediapipe graph instance to be updated. absl::StatusOr BuildHandLandmarkerGraph( const HandLandmarkerGraphOptions& tasks_options, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { const int max_num_hands = tasks_options.hand_detector_graph_options().num_hands(); @@ -214,12 +266,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { auto image_for_hand_detector = DisallowIf(image_in, has_enough_hands, graph); + auto norm_rect_in_for_hand_detector = + DisallowIf(norm_rect_in, has_enough_hands, graph); auto& hand_detector = graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); hand_detector.GetOptions().CopyFrom( tasks_options.hand_detector_graph_options()); image_for_hand_detector >> hand_detector.In("IMAGE"); + norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT"); auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); auto& hand_association = graph.AddNode("HandAssociationCalculator"); diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index bce5613ff..08beb1a1b 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include #include #include +#include #include "absl/flags/flag.h" #include "absl/status/statusor.h" @@ -65,12 +67,14 @@ using ::testing::proto::Approximately; using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; -constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite"; -constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite"; +constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task"; constexpr char kLeftHandsImage[] = "left_hands.jpg"; +constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageName[] = "image_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect_in"; constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kLandmarksName[] = "landmarks"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; @@ -85,6 +89,11 @@ constexpr char kExpectedLeftUpHandLandmarksFilename[] = "expected_left_up_hand_landmarks.prototxt"; constexpr char kExpectedLeftDownHandLandmarksFilename[] = "expected_left_down_hand_landmarks.prototxt"; +// Same but for the rotated image. +constexpr char kExpectedLeftUpHandRotatedLandmarksFilename[] = + "expected_left_up_hand_rotated_landmarks.prototxt"; +constexpr char kExpectedLeftDownHandRotatedLandmarksFilename[] = + "expected_left_down_hand_rotated_landmarks.prototxt"; constexpr float kFullModelFractionDiff = 0.03; // percentage constexpr float kAbsMargin = 0.03; @@ -105,21 +114,15 @@ absl::StatusOr> CreateTaskRunner() { "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"); auto& options = hand_landmarker_graph.GetOptions(); - options.mutable_hand_detector_graph_options() - ->mutable_base_options() - ->mutable_model_asset() - ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel)); - options.mutable_hand_detector_graph_options()->mutable_base_options(); + options.mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, kHandLandmarkerModelBundle)); options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands); - options.mutable_hand_landmarks_detector_graph_options() - ->mutable_base_options() - ->mutable_model_asset() - ->set_file_name( - JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel)); options.set_min_tracking_confidence(kMinTrackingConfidence); graph[Input(kImageTag)].SetName(kImageName) >> hand_landmarker_graph.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + hand_landmarker_graph.In(kNormRectTag); hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >> graph[Output>(kLandmarksTag)]; hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >> @@ -139,9 +142,16 @@ TEST_F(HandLandmarkerTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); - auto output_packets = - task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); const auto& 
landmarks = (*output_packets)[kLandmarksName] .Get>(); ASSERT_EQ(landmarks.size(), kMaxNumHands); @@ -159,6 +169,38 @@ TEST_F(HandLandmarkerTest, Succeeds) { /*fraction=*/kFullModelFractionDiff)); } +TEST_F(HandLandmarkerTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + kLeftHandsRotatedImage))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); + input_norm_rect.set_rotation(M_PI / 2.0); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + const auto& landmarks = (*output_packets)[kLandmarksName] + .Get>(); + ASSERT_EQ(landmarks.size(), kMaxNumHands); + std::vector expected_landmarks = { + GetExpectedLandmarkList(kExpectedLeftUpHandRotatedLandmarksFilename), + GetExpectedLandmarkList(kExpectedLeftDownHandRotatedLandmarksFilename)}; + + EXPECT_THAT(landmarks[0], + Approximately(Partially(EqualsProto(expected_landmarks[0])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); + EXPECT_THAT(landmarks[1], + Approximately(Partially(EqualsProto(expected_landmarks[1])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); +} + } // namespace } // namespace hand_landmarker diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto index c985fc7fa..51e4e129a 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -29,8 +29,8 @@ message HandLandmarkerGraphOptions { extend mediapipe.CalculatorOptions { optional HandLandmarkerGraphOptions ext = 462713202; } - // Base options for configuring MediaPipe Tasks, such as specifying the TfLite - // model file with metadata, accelerator options, etc. + // Base options for configuring MediaPipe Tasks, such as specifying the model + // asset bundle file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; // Options for hand detector graph. diff --git a/mediapipe/tasks/cc/vision/utils/BUILD b/mediapipe/tasks/cc/vision/utils/BUILD index c796798df..fda33bea5 100644 --- a/mediapipe/tasks/cc/vision/utils/BUILD +++ b/mediapipe/tasks/cc/vision/utils/BUILD @@ -94,7 +94,7 @@ cc_library( name = "landmarks_utils", srcs = ["landmarks_utils.cc"], hdrs = ["landmarks_utils.h"], - deps = ["//mediapipe/tasks/cc/components/containers:landmarks_detection"], + deps = ["//mediapipe/tasks/cc/components/containers:rect"], ) cc_test( @@ -103,6 +103,6 @@ cc_test( deps = [ ":landmarks_utils", "//mediapipe/framework/port:gtest_main", - "//mediapipe/tasks/cc/components/containers:landmarks_detection", + "//mediapipe/tasks/cc/components/containers:rect", ], ) diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc index 5ec898f15..2ce9e2454 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc @@ -18,15 +18,17 @@ limitations under the License. 
#include #include +#include "mediapipe/tasks/cc/components/containers/rect.h" + namespace mediapipe::tasks::vision::utils { -using ::mediapipe::tasks::components::containers::Bound; +using ::mediapipe::tasks::components::containers::Rect; -float CalculateArea(const Bound& bound) { - return (bound.right - bound.left) * (bound.bottom - bound.top); +float CalculateArea(const Rect& rect) { + return (rect.right - rect.left) * (rect.bottom - rect.top); } -float CalculateIntersectionArea(const Bound& a, const Bound& b) { +float CalculateIntersectionArea(const Rect& a, const Rect& b) { const float intersection_left = std::max(a.left, b.left); const float intersection_top = std::max(a.top, b.top); const float intersection_right = std::min(a.right, b.right); @@ -36,7 +38,7 @@ float CalculateIntersectionArea(const Bound& a, const Bound& b) { std::max(intersection_right - intersection_left, 0.0); } -float CalculateIOU(const Bound& a, const Bound& b) { +float CalculateIOU(const Rect& a, const Rect& b) { const float area_a = CalculateArea(a); const float area_b = CalculateArea(b); if (area_a <= 0 || area_b <= 0) return 0.0; diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h index b42eae0b6..73114d2ef 100644 --- a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -22,20 +22,20 @@ limitations under the License. #include #include -#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" namespace mediapipe::tasks::vision::utils { // Calculates intersection over union for two bounds. -float CalculateIOU(const components::containers::Bound& a, - const components::containers::Bound& b); +float CalculateIOU(const components::containers::Rect& a, + const components::containers::Rect& b); // Calculates area for face bound -float CalculateArea(const components::containers::Bound& bound); +float CalculateArea(const components::containers::Rect& rect); // Calucates intersection area of two face bounds -float CalculateIntersectionArea(const components::containers::Bound& a, - const components::containers::Bound& b); +float CalculateIntersectionArea(const components::containers::Rect& a, + const components::containers::Rect& b); } // namespace mediapipe::tasks::vision::utils #endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java index 3f96d7779..e45866190 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers/Landmark.java @@ -30,13 +30,13 @@ public abstract class Landmark { return new AutoValue_Landmark(x, y, z, normalized); } - // The x coordniates of the landmark. + // The x coordinates of the landmark. public abstract float x(); - // The y coordniates of the landmark. + // The y coordinates of the landmark. public abstract float y(); - // The z coordniates of the landmark. + // The z coordinates of the landmark. public abstract float z(); // Whether this landmark is normalized with respect to the image size. 
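The landmarks_utils change above is a rename of the container type (Bound to Rect); the area/intersection/IoU math itself is unchanged. As a quick reference, here is a minimal self-contained sketch of that computation, assuming a stand-in Rect with normalized left/top/right/bottom fields rather than the actual MediaPipe header:

```cpp
// Sketch only: mirrors the CalculateArea / CalculateIntersectionArea /
// CalculateIOU logic in landmarks_utils.cc, using a local stand-in Rect.
#include <algorithm>
#include <iostream>

struct Rect {  // stand-in for components::containers::Rect
  float left, top, right, bottom;  // normalized [0, 1] coordinates
};

float Area(const Rect& r) { return (r.right - r.left) * (r.bottom - r.top); }

float IntersectionArea(const Rect& a, const Rect& b) {
  const float left = std::max(a.left, b.left);
  const float top = std::max(a.top, b.top);
  const float right = std::min(a.right, b.right);
  const float bottom = std::min(a.bottom, b.bottom);
  // Clamp to zero when the rectangles do not overlap.
  return std::max(bottom - top, 0.0f) * std::max(right - left, 0.0f);
}

float IoU(const Rect& a, const Rect& b) {
  const float area_a = Area(a);
  const float area_b = Area(b);
  if (area_a <= 0 || area_b <= 0) return 0.0f;
  const float inter = IntersectionArea(a, b);
  return inter / (area_a + area_b - inter);
}

int main() {
  Rect a{0.1f, 0.1f, 0.5f, 0.5f};
  Rect b{0.3f, 0.3f, 0.7f, 0.7f};
  // intersection = 0.04, union = 0.28, so IoU is roughly 0.14.
  std::cout << IoU(a, b) << "\n";
  return 0;
}
```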
diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java index 3fa7c2bcc..6a83c7296 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/OutputHandler.java @@ -117,7 +117,7 @@ public class OutputHandler { if (errorListener != null) { errorListener.onError(e); } else { - Log.e(TAG, "Error occurs when getting MediaPipe vision task result. " + e); + Log.e(TAG, "Error occurs when getting MediaPipe task result. " + e); } } finally { for (Packet packet : packets) { diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD new file mode 100644 index 000000000..1719707d8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -0,0 +1,63 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +# The native library of all MediaPipe text tasks. +cc_binary( + name = "libmediapipe_tasks_text_jni.so", + linkshared = 1, + linkstatic = 1, + deps = [ + "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", + ], +) + +cc_library( + name = "libmediapipe_tasks_text_jni_lib", + srcs = [":libmediapipe_tasks_text_jni.so"], + alwayslink = 1, +) + +android_library( + name = "textclassifier", + srcs = [ + "textclassifier/TextClassificationResult.java", + "textclassifier/TextClassifier.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "textclassifier/AndroidManifest.xml", + deps = [ + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", + "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib", + "//third_party:autovalue", + "@maven//:com_google_guava_guava", + ], +) diff --git 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/AndroidManifest.xml new file mode 100644 index 000000000..22de57ae3 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java new file mode 100644 index 000000000..dd9b9a1b3 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassificationResult.java @@ -0,0 +1,103 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textclassifier; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.tasks.components.container.proto.CategoryProto; +import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.components.containers.ClassificationEntry; +import com.google.mediapipe.tasks.components.containers.Classifications; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** Represents the classification results generated by {@link TextClassifier}. */ +@AutoValue +public abstract class TextClassificationResult implements TaskResult { + + /** + * Creates an {@link TextClassificationResult} instance from a {@link + * ClassificationsProto.ClassificationResult} protobuf message. + * + * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf + * message. + * @param timestampMs a timestamp for this result. + */ + // TODO: consolidate output formats across platforms. + static TextClassificationResult create( + ClassificationsProto.ClassificationResult classificationResult, long timestampMs) { + List classifications = new ArrayList<>(); + for (ClassificationsProto.Classifications classificationsProto : + classificationResult.getClassificationsList()) { + classifications.add(classificationsFromProto(classificationsProto)); + } + return new AutoValue_TextClassificationResult( + timestampMs, Collections.unmodifiableList(classifications)); + } + + @Override + public abstract long timestampMs(); + + /** Contains one set of results per classifier head. */ + @SuppressWarnings("AutoValueImmutableFields") + public abstract List classifications(); + + /** + * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object. + * + * @param category the {@link CategoryProto.Category} protobuf message to convert. 
+ */ + static Category categoryFromProto(CategoryProto.Category category) { + return Category.create( + category.getScore(), + category.getIndex(), + category.getCategoryName(), + category.getDisplayName()); + } + + /** + * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link + * ClassificationEntry} object. + * + * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert. + */ + static ClassificationEntry classificationEntryFromProto( + ClassificationsProto.ClassificationEntry entry) { + List categories = new ArrayList<>(); + for (CategoryProto.Category category : entry.getCategoriesList()) { + categories.add(categoryFromProto(category)); + } + return ClassificationEntry.create(categories, entry.getTimestampMs()); + } + + /** + * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link + * Classifications} object. + * + * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to + * convert. + */ + static Classifications classificationsFromProto( + ClassificationsProto.Classifications classifications) { + List entries = new ArrayList<>(); + for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) { + entries.add(classificationEntryFromProto(entry)); + } + return Classifications.create( + entries, classifications.getHeadIndex(), classifications.getHeadName()); + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java new file mode 100644 index 000000000..76117d2e4 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/textclassifier/TextClassifier.java @@ -0,0 +1,253 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.text.textclassifier; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.tasks.components.container.proto.ClassificationsProto; +import com.google.mediapipe.tasks.components.processors.ClassifierOptions; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.text.textclassifier.proto.TextClassifierGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Performs classification on text. + * + *

<p>This API expects a TFLite model with (optional) TFLite Model Metadata that contains
+ * the mandatory (described below) input tensors, output tensor, and the optional (but recommended)
+ * label items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
+ *
+ * <p>Metadata is required for models with int32 input tensors because it contains the input process
+ * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
+ *
+ * <ul>
+ *   <li>Input tensors
+ *       <ul>
+ *         <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
+ *             bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
+ *             signature requires a Bert Tokenizer process unit in the model metadata.
+ *         <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
+ *             max_seq_len]} representing the input ids. This input signature requires a Regex
+ *             Tokenizer process unit in the model metadata.
+ *         <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
+ *             [1]} containing the input string.
+ *       </ul>
+ *   <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kBool}) with:
+ *       <ul>
+ *         <li>{@code N} classes and shape {@code [1 x N]}
+ *         <li>optional (but recommended) label map(s) as AssociatedFile-s with type
+ *             TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
+ *             any) is used to fill the {@code class_name} field of the results. The {@code
+ *             display_name} field is filled from the AssociatedFile (if any) whose locale matches
+ *             the {@code display_names_locale} field of the {@code TextClassifierOptions} used at
+ *             creation time ("en" by default, i.e. English). If none of these are available, only
+ *             the {@code index} field of the results will be filled.
+ *       </ul>
+ * </ul>
+ */ +public final class TextClassifier implements AutoCloseable { + private static final String TAG = TextClassifier.class.getSimpleName(); + private static final String TEXT_IN_STREAM_NAME = "text_in"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME)); + + @SuppressWarnings("ConstantCaseForConstants") + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("CLASSIFICATION_RESULT:classification_result_out")); + + private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + private final TaskRunner runner; + + static { + System.loadLibrary("mediapipe_tasks_text_jni"); + ProtoUtil.registerTypeName( + ClassificationsProto.ClassificationResult.class, + "mediapipe.tasks.components.containers.proto.ClassificationResult"); + } + + /** + * Creates a {@link TextClassifier} instance from a model file and the default {@link + * TextClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the text model with metadata in the assets. + * @throws MediaPipeException if there is is an error during {@link TextClassifier} creation. + */ + public static TextClassifier createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link TextClassifier} instance from a model file and the default {@link + * TextClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the text model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link TextClassifier} creation. + */ + public static TextClassifier createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link TextClassifier} instance from {@link TextClassifierOptions}. + * + * @param context an Android {@link Context}. + * @param options a {@link TextClassifierOptions} instance. + * @throws MediaPipeException if there is an error during {@link TextClassifier} creation. 
+ */ + public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) { + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public TextClassificationResult convertToTaskResult(List packets) { + try { + return TextClassificationResult.create( + PacketGetter.getProto( + packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), + ClassificationsProto.ClassificationResult.getDefaultInstance()), + packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); + } catch (IOException e) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); + } + } + + @Override + public Void convertToTaskInput(List packets) { + return null; + } + }); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(options) + .setEnableFlowLimiting(false) + .build(), + handler); + return new TextClassifier(runner); + } + + /** + * Constructor to initialize a {@link TextClassifier} from a {@link TaskRunner}. + * + * @param runner a {@link TaskRunner}. + */ + private TextClassifier(TaskRunner runner) { + this.runner = runner; + } + + /** + * Performs classification on the input text. + * + * @param inputText a {@link String} for processing. + */ + public TextClassificationResult classify(String inputText) { + Map inputPackets = new HashMap<>(); + inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText)); + return (TextClassificationResult) runner.process(inputPackets); + } + + /** Closes and cleans up the {@link TextClassifier}. */ + @Override + public void close() { + runner.close(); + } + + /** Options for setting up a {@link TextClassifier}. */ + @AutoValue + public abstract static class TextClassifierOptions extends TaskOptions { + + /** Builder for {@link TextClassifierOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the text classifier task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the optional {@link ClassifierOptions} controling classification behavior, such as + * score threshold, number of results, etc. + */ + public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions); + + public abstract TextClassifierOptions build(); + } + + abstract BaseOptions baseOptions(); + + abstract Optional classifierOptions(); + + public static Builder builder() { + return new AutoValue_TextClassifier_TextClassifierOptions.Builder(); + } + + /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. 
*/ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder = + TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + if (classifierOptions().isPresent()) { + taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto()); + } + return CalculatorOptions.newBuilder() + .setExtension( + TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java index cc8346d80..660645d9c 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/gesturerecognizer/GestureRecognizer.java @@ -15,6 +15,7 @@ package com.google.mediapipe.tasks.vision.gesturerecognizer; import android.content.Context; +import android.graphics.RectF; import android.os.ParcelFileDescriptor; import com.google.auto.value.AutoValue; import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; @@ -71,8 +72,10 @@ import java.util.Optional; public final class GestureRecognizer extends BaseVisionTaskApi { private static final String TAG = GestureRecognizer.class.getSimpleName(); private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; private static final List INPUT_STREAMS = - Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME)); + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); private static final List OUTPUT_STREAMS = Collections.unmodifiableList( Arrays.asList( @@ -205,7 +208,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { * @param runningMode a mediapipe vision task {@link RunningMode}. */ private GestureRecognizer(TaskRunner taskRunner, RunningMode runningMode) { - super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME); + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); } /** @@ -223,7 +226,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public GestureRecognitionResult recognize(Image inputImage) { - return (GestureRecognitionResult) processImageData(inputImage); + // TODO: add proper support for rotations. + return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF()); } /** @@ -244,7 +248,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) { - return (GestureRecognitionResult) processVideoData(inputImage, inputTimestampMs); + // TODO: add proper support for rotations. 
+ return (GestureRecognitionResult) + processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs); } /** @@ -266,7 +272,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi { * @throws MediaPipeException if there is an internal error. */ public void recognizeAsync(Image inputImage, long inputTimestampMs) { - sendLiveStreamData(inputImage, inputTimestampMs); + // TODO: add proper support for rotations. + sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs); } /** Options for setting up an {@link GestureRecognizer}. */ @@ -303,18 +310,18 @@ public final class GestureRecognizer extends BaseVisionTaskApi { /** Sets the maximum number of hands can be detected by the GestureRecognizer. */ public abstract Builder setNumHands(Integer value); - /** Sets minimum confidence score for the hand detection to be considered successfully */ + /** Sets minimum confidence score for the hand detection to be considered successful */ public abstract Builder setMinHandDetectionConfidence(Float value); /** Sets minimum confidence score of hand presence score in the hand landmark detection. */ public abstract Builder setMinHandPresenceConfidence(Float value); - /** Sets the minimum confidence score for the hand tracking to be considered successfully. */ + /** Sets the minimum confidence score for the hand tracking to be considered successful. */ public abstract Builder setMinTrackingConfidence(Float value); /** - * Sets the minimum confidence score for the gestures to be considered successfully. If < 0, - * the gesture confidence threshold=0.5 for the model is used. + * Sets the minimum confidence score for the gestures to be considered successful. If < 0, the + * gesture confidence threshold=0.5 for the model is used. * *
<p>
TODO Note this option is subject to change, after scoring merging * calculator is implemented. @@ -433,8 +440,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi { HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder() .setBaseOptions( BaseOptionsProto.BaseOptions.newBuilder() - .setUseStreamMode(runningMode() != RunningMode.IMAGE) - .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker()))); + .setUseStreamMode(runningMode() != RunningMode.IMAGE)); minTrackingConfidence() .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence); handLandmarkerGraphOptionsBuilder @@ -465,4 +471,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi { .build(); } } + + /** Creates a RectF covering the full image. */ + private static RectF buildFullImageRectF() { + return new RectF(0, 0, 1, 1); + } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java index 68cae151f..e8e263b71 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imageclassifier/ImageClassifier.java @@ -39,7 +39,6 @@ import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; import com.google.mediapipe.tasks.vision.core.RunningMode; import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto; -import com.google.protobuf.InvalidProtocolBufferException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; @@ -176,7 +175,7 @@ public final class ImageClassifier extends BaseVisionTaskApi { packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX), ClassificationsProto.ClassificationResult.getDefaultInstance()), packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp()); - } catch (InvalidProtocolBufferException e) { + } catch (IOException e) { throw new MediaPipeException( MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage()); } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/AndroidManifest.xml new file mode 100644 index 000000000..2cf08b5fc --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/BUILD new file mode 100644 index 000000000..a7f804c64 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java new file mode 100644 index 000000000..bfca79ced --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/text/textclassifier/TextClassifierTest.java @@ -0,0 +1,154 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.text.textclassifier; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.tasks.components.containers.Category; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.text.textclassifier.TextClassifier.TextClassifierOptions; +import java.util.Arrays; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** Test for {@link TextClassifier}/ */ +@RunWith(AndroidJUnit4.class) +public class TextClassifierTest { + private static final String BERT_MODEL_FILE = "bert_text_classifier.tflite"; + private static final String REGEX_MODEL_FILE = + "test_model_text_classifier_with_regex_tokenizer.tflite"; + private static final String STRING_TO_BOOL_MODEL_FILE = + "test_model_text_classifier_bool_output.tflite"; + private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate"; + private static final String POSITIVE_TEXT = "it's a charming and often affecting journey"; + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonExistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + TextClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), nonExistentFile)); + assertThat(exception).hasMessageThat().contains(nonExistentFile); + } + + @Test + public void create_failsWithMissingOpResolver() throws Exception { + TextClassifierOptions options = + TextClassifierOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(STRING_TO_BOOL_MODEL_FILE).build()) + .build(); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + TextClassifier.createFromOptions( + ApplicationProvider.getApplicationContext(), options)); + // TODO: Make MediaPipe InferenceCalculator report the detailed. + // interpreter errors (e.g., "Encountered unresolved custom op"). 
+ assertThat(exception) + .hasMessageThat() + .contains("interpreter_builder(&interpreter) == kTfLiteOk"); + } + + @Test + public void classify_succeedsWithBert() throws Exception { + TextClassifier textClassifier = + TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE); + TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + assertHasOneHead(negativeResults); + assertCategoriesAre( + negativeResults, + Arrays.asList( + Category.create(0.95630914f, 0, "negative", ""), + Category.create(0.04369091f, 1, "positive", ""))); + + TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + assertHasOneHead(positiveResults); + assertCategoriesAre( + positiveResults, + Arrays.asList( + Category.create(0.99997187f, 1, "positive", ""), + Category.create(2.8132641E-5f, 0, "negative", ""))); + } + + @Test + public void classify_succeedsWithFileObject() throws Exception { + TextClassifier textClassifier = + TextClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), + TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE)); + TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + assertHasOneHead(negativeResults); + assertCategoriesAre( + negativeResults, + Arrays.asList( + Category.create(0.95630914f, 0, "negative", ""), + Category.create(0.04369091f, 1, "positive", ""))); + + TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + assertHasOneHead(positiveResults); + assertHasOneHead(positiveResults); + assertCategoriesAre( + positiveResults, + Arrays.asList( + Category.create(0.99997187f, 1, "positive", ""), + Category.create(2.8132641E-5f, 0, "negative", ""))); + } + + @Test + public void classify_succeedsWithRegex() throws Exception { + TextClassifier textClassifier = + TextClassifier.createFromFile( + ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE); + TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT); + assertHasOneHead(negativeResults); + assertCategoriesAre( + negativeResults, + Arrays.asList( + Category.create(0.6647746f, 0, "Negative", ""), + Category.create(0.33522537f, 1, "Positive", ""))); + + TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT); + assertHasOneHead(positiveResults); + assertCategoriesAre( + positiveResults, + Arrays.asList( + Category.create(0.5120041f, 0, "Negative", ""), + Category.create(0.48799595f, 1, "Positive", ""))); + } + + private static void assertHasOneHead(TextClassificationResult results) { + assertThat(results.classifications()).hasSize(1); + assertThat(results.classifications().get(0).headIndex()).isEqualTo(0); + assertThat(results.classifications().get(0).headName()).isEqualTo("probability"); + assertThat(results.classifications().get(0).entries()).hasSize(1); + } + + private static void assertCategoriesAre( + TextClassificationResult results, List categories) { + assertThat(results.classifications().get(0).entries().get(0).categories()) + .isEqualTo(categories); + } +} diff --git a/mediapipe/tasks/python/test/BUILD b/mediapipe/tasks/python/test/BUILD index d4ef3a35b..8e5b91cf9 100644 --- a/mediapipe/tasks/python/test/BUILD +++ b/mediapipe/tasks/python/test/BUILD @@ -23,5 +23,9 @@ py_library( testonly = 1, srcs = ["test_utils.py"], srcs_version = "PY3", + visibility = [ + "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__", + "//mediapipe/tasks:internal", + ], deps = 
["//mediapipe/python:_framework_bindings"], ) diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 764b93c91..ffb4760d9 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -35,9 +35,11 @@ mediapipe_files(srcs = [ "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", + "hand_landmark.task", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "left_hands.jpg", + "left_hands_rotated.jpg", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", "mobilenet_v1_0.25_224_1_default_1.tflite", "mobilenet_v1_0.25_224_1_metadata_1.tflite", @@ -51,7 +53,9 @@ mediapipe_files(srcs = [ "multi_objects_rotated.jpg", "palm_detection_full.tflite", "pointing_up.jpg", + "pointing_up_rotated.jpg", "right_hands.jpg", + "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", "segmentation_input_rotation0.jpg", "selfie_segm_128_128_3.tflite", @@ -64,7 +68,9 @@ mediapipe_files(srcs = [ exports_files( srcs = [ "expected_left_down_hand_landmarks.prototxt", + "expected_left_down_hand_rotated_landmarks.prototxt", "expected_left_up_hand_landmarks.prototxt", + "expected_left_up_hand_rotated_landmarks.prototxt", "expected_right_down_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt", ], @@ -84,11 +90,14 @@ filegroup( "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "left_hands.jpg", + "left_hands_rotated.jpg", "mozart_square.jpg", "multi_objects.jpg", "multi_objects_rotated.jpg", "pointing_up.jpg", + "pointing_up_rotated.jpg", "right_hands.jpg", + "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", "segmentation_input_rotation0.jpg", "selfie_segm_128_128_3_expected_mask.jpg", @@ -109,6 +118,7 @@ filegroup( "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite", "deeplabv3.tflite", + "hand_landmark.task", "hand_landmark_full.tflite", "hand_landmark_lite.tflite", "mobilenet_v1_0.25_192_quantized_1_default_1.tflite", @@ -129,12 +139,17 @@ filegroup( name = "test_protos", srcs = [ "expected_left_down_hand_landmarks.prototxt", + "expected_left_down_hand_rotated_landmarks.prototxt", "expected_left_up_hand_landmarks.prototxt", + "expected_left_up_hand_rotated_landmarks.prototxt", "expected_right_down_hand_landmarks.prototxt", "expected_right_up_hand_landmarks.prototxt", "hand_detector_result_one_hand.pbtxt", + "hand_detector_result_one_hand_rotated.pbtxt", "hand_detector_result_two_hands.pbtxt", "pointing_up_landmarks.pbtxt", + "pointing_up_rotated_landmarks.pbtxt", "thumb_up_landmarks.pbtxt", + "thumb_up_rotated_landmarks.pbtxt", ], ) diff --git a/mediapipe/tasks/testdata/vision/expected_left_down_hand_rotated_landmarks.prototxt b/mediapipe/tasks/testdata/vision/expected_left_down_hand_rotated_landmarks.prototxt new file mode 100644 index 000000000..3cbf8804f --- /dev/null +++ b/mediapipe/tasks/testdata/vision/expected_left_down_hand_rotated_landmarks.prototxt @@ -0,0 +1,84 @@ +landmark { + x: 0.9259716 + y: 0.18969846 +} +landmark { + x: 0.88135517 + y: 0.28856543 +} +landmark { + x: 0.7600651 + y: 0.3578236 +} +landmark { + x: 0.62631166 + y: 0.40490413 +} +landmark { + x: 0.5374573 + y: 0.45170194 +} +landmark { + x: 0.57372385 + y: 0.29924914 +} +landmark { + x: 0.36731184 + y: 0.33081773 +} +landmark { + x: 0.24132833 + y: 0.34759054 +} +landmark { + x: 0.13690609 + y: 0.35727677 +} +landmark { + x: 
0.5535803 + y: 0.2398035 +} +landmark { + x: 0.31834763 + y: 0.24999242 +} +landmark { + x: 0.16748133 + y: 0.25625145 +} +landmark { + x: 0.050747424 + y: 0.25991398 +} +landmark { + x: 0.56593156 + y: 0.1867483 +} +landmark { + x: 0.3543046 + y: 0.17923892 +} +landmark { + x: 0.21360746 + y: 0.17454882 +} +landmark { + x: 0.11110917 + y: 0.17232567 +} +landmark { + x: 0.5948908 + y: 0.14024714 +} +landmark { + x: 0.42692152 + y: 0.11949824 +} +landmark { + x: 0.32239118 + y: 0.106370345 +} +landmark { + x: 0.23672739 + y: 0.09432885 +} diff --git a/mediapipe/tasks/testdata/vision/expected_left_up_hand_rotated_landmarks.prototxt b/mediapipe/tasks/testdata/vision/expected_left_up_hand_rotated_landmarks.prototxt new file mode 100644 index 000000000..42eccbcc5 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/expected_left_up_hand_rotated_landmarks.prototxt @@ -0,0 +1,84 @@ +landmark { + x: 0.06676084 + y: 0.8095678 +} +landmark { + x: 0.11359626 + y: 0.71148247 +} +landmark { + x: 0.23572624 + y: 0.6414506 +} +landmark { + x: 0.37323278 + y: 0.5959156 +} +landmark { + x: 0.46243322 + y: 0.55125874 +} +landmark { + x: 0.4205411 + y: 0.69531494 +} +landmark { + x: 0.62798893 + y: 0.66715276 +} +landmark { + x: 0.7568023 + y: 0.65208924 +} +landmark { + x: 0.86370826 + y: 0.6437276 +} +landmark { + x: 0.445136 + y: 0.75394773 +} +landmark { + x: 0.6787485 + y: 0.745853 +} +landmark { + x: 0.8290694 + y: 0.7412988 +} +landmark { + x: 0.94454145 + y: 0.7384017 +} +landmark { + x: 0.43516788 + y: 0.8082166 +} +landmark { + x: 0.6459554 + y: 0.81768996 +} +landmark { + x: 0.7875173 + y: 0.825062 +} +landmark { + x: 0.89249825 + y: 0.82850707 +} +landmark { + x: 0.40665048 + y: 0.8567925 +} +landmark { + x: 0.57228816 + y: 0.8802181 +} +landmark { + x: 0.6762071 + y: 0.8941581 +} +landmark { + x: 0.76453924 + y: 0.90583205 +} diff --git a/mediapipe/tasks/testdata/vision/hand_detector_result_one_hand_rotated.pbtxt b/mediapipe/tasks/testdata/vision/hand_detector_result_one_hand_rotated.pbtxt new file mode 100644 index 000000000..cec4d6166 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/hand_detector_result_one_hand_rotated.pbtxt @@ -0,0 +1,33 @@ +detections { + label: "Palm" + score: 0.97115 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.5198178 + ymin: 0.6467485 + width: 0.42467535 + height: 0.22546273 + } + } +} +detections { + label: "Palm" + score: 0.96701413 + location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.024490356 + ymin: 0.12620124 + width: 0.43832153 + height: 0.23269764 + } + } +} +hand_rects { + x_center: 0.5760683 + y_center: 0.6829921 + height: 0.5862031 + width: 1.1048855 + rotation: -0.8250832 +} diff --git a/mediapipe/tasks/testdata/vision/hand_landmark.task b/mediapipe/tasks/testdata/vision/hand_landmark.task new file mode 100644 index 000000000..b6eedf324 Binary files /dev/null and b/mediapipe/tasks/testdata/vision/hand_landmark.task differ diff --git a/mediapipe/tasks/testdata/vision/pointing_up_rotated_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/pointing_up_rotated_landmarks.pbtxt new file mode 100644 index 000000000..65bb11bc8 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/pointing_up_rotated_landmarks.pbtxt @@ -0,0 +1,223 @@ +classifications { + classification { + score: 1.0 + label: "Left" + display_name: "Left" + } +} + +landmarks { + landmark { + x: 0.25546086 + y: 0.47584262 + z: 1.835341e-07 + } + landmark { + x: 0.3363011 + y: 0.54135 + z: -0.041144375 + } + landmark { + x: 
0.4375146 + y: 0.57881975 + z: -0.06807727 + } + landmark { + x: 0.49603376 + y: 0.5263966 + z: -0.09387612 + } + landmark { + x: 0.5022822 + y: 0.4413827 + z: -0.1189948 + } + landmark { + x: 0.5569452 + y: 0.4724485 + z: -0.05138246 + } + landmark { + x: 0.6687125 + y: 0.47918057 + z: -0.09121969 + } + landmark { + x: 0.73666537 + y: 0.48318353 + z: -0.11703273 + } + landmark { + x: 0.7998315 + y: 0.4741413 + z: -0.1386424 + } + landmark { + x: 0.5244063 + y: 0.39292705 + z: -0.061040796 + } + landmark { + x: 0.57215345 + y: 0.41514704 + z: -0.11967233 + } + landmark { + x: 0.4724468 + y: 0.45553637 + z: -0.13287684 + } + landmark { + x: 0.43794966 + y: 0.45210314 + z: -0.13210714 + } + landmark { + x: 0.47838163 + y: 0.33329 + z: -0.07421263 + } + landmark { + x: 0.51081127 + y: 0.35479474 + z: -0.13596693 + } + landmark { + x: 0.42433846 + y: 0.40486792 + z: -0.121291734 + } + landmark { + x: 0.40280548 + y: 0.39977497 + z: -0.09928809 + } + landmark { + x: 0.42269367 + y: 0.2798249 + z: -0.09064263 + } + landmark { + x: 0.45849988 + y: 0.3069861 + z: -0.12894689 + } + landmark { + x: 0.40754712 + y: 0.35153976 + z: -0.109160855 + } + landmark { + x: 0.38855004 + y: 0.3467068 + z: -0.08820164 + } +} + +world_landmarks { + landmark { + x: -0.08568013 + y: 0.016593203 + z: 0.036527164 + } + landmark { + x: -0.0565372 + y: 0.041761592 + z: 0.019493781 + } + landmark { + x: -0.031365488 + y: 0.05031186 + z: 0.0025481891 + } + landmark { + x: -0.008534161 + y: 0.04286737 + z: -0.024755282 + } + landmark { + x: -0.0047254 + y: 0.015748458 + z: -0.035581928 + } + landmark { + x: 0.013083893 + y: 0.024668094 + z: 0.0035934823 + } + landmark { + x: 0.04149521 + y: 0.024621274 + z: -0.0030611698 + } + landmark { + x: 0.06257473 + y: 0.025388625 + z: -0.010340984 + } + landmark { + x: 0.08009179 + y: 0.023082614 + z: -0.03162942 + } + landmark { + x: 0.006135068 + y: 0.000696786 + z: 0.0048212176 + } + landmark { + x: 0.01678449 + y: 0.0067061195 + z: -0.029920919 + } + landmark { + x: -0.008948593 + y: 0.016808286 + z: -0.03755109 + } + landmark { + x: -0.01789449 + y: 0.0153161455 + z: -0.012059977 + } + landmark { + x: -0.0061980113 + y: -0.017872887 + z: -0.002366997 + } + landmark { + x: -0.004643807 + y: -0.0108282855 + z: -0.034515083 + } + landmark { + x: -0.027603384 + y: 0.003529715 + z: -0.033665676 + } + landmark { + x: -0.035679806 + y: 0.0038255951 + z: -0.008094264 + } + landmark { + x: -0.02957782 + y: -0.031701155 + z: -0.008180461 + } + landmark { + x: -0.020741666 + y: -0.02506058 + z: -0.026839724 + } + landmark { + x: -0.0310834 + y: -0.009496164 + z: -0.032422185 + } + landmark { + x: -0.037420202 + y: -0.012883307 + z: -0.017971724 + } +} diff --git a/mediapipe/tasks/testdata/vision/thumb_up_rotated_landmarks.pbtxt b/mediapipe/tasks/testdata/vision/thumb_up_rotated_landmarks.pbtxt new file mode 100644 index 000000000..3636e2e4d --- /dev/null +++ b/mediapipe/tasks/testdata/vision/thumb_up_rotated_landmarks.pbtxt @@ -0,0 +1,223 @@ +classifications { + classification { + score: 1.0 + label: "Left" + display_name: "Left" + } +} + +landmarks { + landmark { + x: 0.3283601 + y: 0.63773525 + z: -3.2280354e-07 + } + landmark { + x: 0.46280807 + y: 0.6339767 + z: -0.06408348 + } + landmark { + x: 0.5831279 + y: 0.57430106 + z: -0.08583106 + } + landmark { + x: 0.6689471 + y: 0.49959752 + z: -0.09886064 + } + landmark { + x: 0.74378216 + y: 0.47357544 + z: -0.09680563 + } + landmark { + x: 0.5233122 + y: 0.41020474 + z: -0.038088404 + } + landmark { + x: 0.5296913 + y: 0.3372598 + z: 
-0.08874837
+  }
+  landmark {
+    x: 0.49039274
+    y: 0.43994758
+    z: -0.102315836
+  }
+  landmark {
+    x: 0.4824569
+    y: 0.47969607
+    z: -0.1030014
+  }
+  landmark {
+    x: 0.4451338
+    y: 0.39520803
+    z: -0.02177739
+  }
+  landmark {
+    x: 0.4410001
+    y: 0.34107083
+    z: -0.07294245
+  }
+  landmark {
+    x: 0.4162798
+    y: 0.46102384
+    z: -0.07746907
+  }
+  landmark {
+    x: 0.43492994
+    y: 0.47154287
+    z: -0.07404131
+  }
+  landmark {
+    x: 0.37671578
+    y: 0.39535576
+    z: -0.016277775
+  }
+  landmark {
+    x: 0.36978847
+    y: 0.34265152
+    z: -0.07346253
+  }
+  landmark {
+    x: 0.3559884
+    y: 0.44905427
+    z: -0.057693005
+  }
+  landmark {
+    x: 0.37711847
+    y: 0.46414754
+    z: -0.03662908
+  }
+  landmark {
+    x: 0.3142985
+    y: 0.3942253
+    z: -0.0152847925
+  }
+  landmark {
+    x: 0.30000874
+    y: 0.35543376
+    z: -0.046002634
+  }
+  landmark {
+    x: 0.30002704
+    y: 0.42357764
+    z: -0.032671776
+  }
+  landmark {
+    x: 0.31079838
+    y: 0.44218025
+    z: -0.016200554
+  }
+}
+
+world_landmarks {
+  landmark {
+    x: -0.030687196
+    y: 0.0678545
+    z: 0.051061403
+  }
+  landmark {
+    x: 0.0047719833
+    y: 0.06330968
+    z: 0.018945374
+  }
+  landmark {
+    x: 0.039799504
+    y: 0.054109577
+    z: 0.007930638
+  }
+  landmark {
+    x: 0.069374144
+    y: 0.035063196
+    z: 2.2522348e-05
+  }
+  landmark {
+    x: 0.087818466
+    y: 0.018390425
+    z: 0.004055788
+  }
+  landmark {
+    x: 0.02810654
+    y: 0.0043561812
+    z: -0.0038672548
+  }
+  landmark {
+    x: 0.025270049
+    y: -0.0039896416
+    z: -0.032991238
+  }
+  landmark {
+    x: 0.020414166
+    y: 0.006768506
+    z: -0.032724563
+  }
+  landmark {
+    x: 0.016415983
+    y: 0.024563588
+    z: -0.0058115427
+  }
+  landmark {
+    x: 0.0038743173
+    y: -0.0044466974
+    z: 0.0024876352
+  }
+  landmark {
+    x: 0.0041790796
+    y: -0.0115309935
+    z: -0.03532454
+  }
+  landmark {
+    x: -0.0016900161
+    y: 0.015519895
+    z: -0.03596156
+  }
+  landmark {
+    x: 0.004309217
+    y: 0.01917039
+    z: 0.003907912
+  }
+  landmark {
+    x: -0.016969737
+    y: -0.005584497
+    z: 0.0034258277
+  }
+  landmark {
+    x: -0.016737012
+    y: -0.01159037
+    z: -0.02876696
+  }
+  landmark {
+    x: -0.018165365
+    y: 0.01376111
+    z: -0.026835402
+  }
+  landmark {
+    x: -0.012430167
+    y: 0.02064222
+    z: -0.00087265146
+  }
+  landmark {
+    x: -0.043247573
+    y: 0.0011161827
+    z: 0.0056269006
+  }
+  landmark {
+    x: -0.038128495
+    y: -0.011477032
+    z: -0.016374081
+  }
+  landmark {
+    x: -0.034920715
+    y: 0.005510211
+    z: -0.029714659
+  }
+  landmark {
+    x: -0.03815982
+    y: 0.011989757
+    z: -0.014853194
+  }
+}
diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl
index b42019a17..2c92293ff 100644
--- a/third_party/external_files.bzl
+++ b/third_party/external_files.bzl
@@ -151,7 +151,7 @@ def external_files():
     http_file(
         name = "com_google_mediapipe_dummy_gesture_recognizer_task",
         sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665524417056146"],
+        urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665707319890725"],
     )
 
     http_file(
@@ -166,12 +166,24 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_landmarks.prototxt?generation=1661875720230540"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
+        sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
+    )
+
     http_file(
         name = "com_google_mediapipe_expected_left_up_hand_landmarks_prototxt",
         sha256 = "1353ba617c4f048083618587cd23a8a22115f634521c153d4e1bd1ebd4f49dd7",
         urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_landmarks.prototxt?generation=1661875726008879"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
+        sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
+    )
+
     http_file(
         name = "com_google_mediapipe_expected_right_down_hand_landmarks_prototxt",
         sha256 = "f281b745175aaa7f458def6cf4c89521fb56302dd61a05642b3b4a4f237ffaa3",
@@ -250,6 +262,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand.pbtxt?generation=1662745351291628"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
+        sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
+    )
+
     http_file(
         name = "com_google_mediapipe_hand_detector_result_two_hands_pbtxt",
         sha256 = "2589cb08b0ee027dc24649fe597adcfa2156a21d12ea2480f83832714ebdf95f",
@@ -268,6 +286,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark_lite.tflite?generation=1661875766398729"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_hand_landmark_task",
+        sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"],
+    )
+
     http_file(
         name = "com_google_mediapipe_hand_recrop_tflite",
         sha256 = "67d996ce96f9d36fe17d2693022c6da93168026ab2f028f9e2365398d8ac7d5d",
@@ -346,6 +370,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands.jpg?generation=1661875796949017"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_left_hands_rotated_jpg",
+        sha256 = "8609c6202bca43a99bbf23fa8e687e49fa525e89481152e4c0987f46d60d7931",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"],
+    )
+
     http_file(
         name = "com_google_mediapipe_mobilebert_embedding_with_metadata_tflite",
         sha256 = "fa47142dcc6f446168bc672f2df9605b6da5d0c0d6264e9be62870282365b95c",
@@ -538,6 +568,18 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1665174976408451"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_pointing_up_rotated_jpg",
+        sha256 = "50ff66f50281207072a038e5bb6648c43f4aacbfb8204a4d2591868756aaeff1",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated.jpg?generation=1666037072219697"],
+    )
+
+    http_file(
+        name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
+        sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
+    )
+
     http_file(
         name = "com_google_mediapipe_pose_detection_tflite",
         sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",
@@ -574,6 +616,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands.jpg?generation=1661875908672404"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_right_hands_rotated_jpg",
+        sha256 = "b3bdf692f0d54b86c8b67e6d1286dd0078fbe6e9dfcd507b187e3bd8b398c0f9",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands_rotated.jpg?generation=1666037076873345"],
+    )
+
     http_file(
         name = "com_google_mediapipe_score_calibration_file_meta_json",
         sha256 = "6a3c305620371f662419a496f75be5a10caebca7803b1e99d8d5d22ba51cda94",
@@ -718,6 +766,12 @@ def external_files():
         urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1665174979747784"],
     )
 
+    http_file(
+        name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
+        sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
+    )
+
     http_file(
         name = "com_google_mediapipe_two_heads_16000_hz_mono_wav",
         sha256 = "a291a9c22c39bba30138a26915e154a96286ba6ca3b413053123c504a58cce3b",