diff --git a/Dockerfile b/Dockerfile index 462dacbd4..4d6c68e7e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,7 +53,7 @@ RUN pip3 install wheel RUN pip3 install future RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1 RUN pip3 install six==1.14.0 -RUN pip3 install tensorflow==2.2.0 +RUN pip3 install tensorflow RUN pip3 install tf_slim RUN ln -s /usr/bin/python3 /usr/bin/python diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index f951b506d..b20a87467 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`. } ``` +## Graph Options + +It is possible to specify a "graph options" protobuf for a MediaPipe graph +similar to the [`Calculator Options`](calculators.md#calculator-options) +protobuf specified for a MediaPipe calculator. These "graph options" can be +specified where a graph is invoked, and used to populate calculator options and +subgraph options within the graph. + +In a CalculatorGraphConfig, graph options can be specified for a subgraph +exactly like calculator options, as shown below: + +``` +node { + calculator: "FlowLimiterCalculator" + input_stream: "image" + output_stream: "throttled_image" + node_options: { + [type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] { + max_in_flight: 1 + } + } +} + +node { + calculator: "FaceDetectionSubgraph" + input_stream: "IMAGE:throttled_image" + node_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] { + tensor_width: 192 + tensor_height: 192 + } + } +} +``` + +In a CalculatorGraphConfig, graph options can be accepted and used to populate +calculator options, as shown below: + +``` +graph_options: { + [type.googleapis.com/mediapipe.FaceDetectionOptions] {} +} + +node: { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:multi_backend_image" + node_options: { + [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] { + keep_aspect_ratio: true + border_mode: BORDER_ZERO + } + } + option_value: "output_tensor_width:options/tensor_width" + option_value: "output_tensor_height:options/tensor_height" +} + +node { + calculator: "InferenceCalculator" + node_options: { + [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {} + } + option_value: "delegate:options/delegate" + option_value: "model_path:options/model_path" +} +``` + +In this example, the `FaceDetectionSubgraph` accepts graph option protobuf +`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field +values in the calculator options `ImageToTensorCalculatorOptions` and some field +values in the subgraph options `InferenceCalculatorOptions`. The field values +are defined using the `option_value:` syntax. + +In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and +`option_value:` together define the option values for a calculator such as +`ImageToTensorCalculator`. The `node_options:` field defines a set of literal +constant values using the text protobuf syntax. Each `option_value:` field +defines the value for one protobuf field using information from the enclosing +graph, specifically from field values of the graph options of the enclosing +graph. In the example above, the `option_value:` +`"output_tensor_width:options/tensor_width"` defines the field +`ImageToTensorCalculatorOptions.output_tensor_width` using the value of +`FaceDetectionOptions.tensor_width`. 
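+
+For reference, a graph options protobuf is an ordinary proto2 message. A
+minimal sketch of such a message, carrying only the two fields referenced
+above, might look like this (the real `FaceDetectionOptions` in MediaPipe may
+declare additional fields, and the field numbers here are illustrative):
+
+```
+message FaceDetectionOptions {
+  // Width and height of the input tensor expected by the detection model.
+  // Field names follow the option_value references above; field numbers are
+  // illustrative only.
+  optional int32 tensor_width = 1;
+  optional int32 tensor_height = 2;
+}
+```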
+ +The syntax of `option_value:` is similar to the syntax of `input_stream:`. The +syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option +field and the RHS identifies a graph option field. More specifically, the LHS +and RHS each consists of a series of protobuf field names identifying nested +protobuf messages and fields separated by '/'. This is known as the "ProtoPath" +syntax. Nested messages that are referenced in the LHS or RHS must already be +defined in the enclosing protobuf in order to be traversed using +`option_value:`. + ## Cycles diff --git a/docs/getting_started/cpp.md b/docs/getting_started/cpp.md index 8fc091fea..47fb697b0 100644 --- a/docs/getting_started/cpp.md +++ b/docs/getting_started/cpp.md @@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow ```bash GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \ - --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt + --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt ``` This will open up your webcam as long as it is connected and on. Any errors diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index 26ada44bf..8fbc99829 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -1410,3 +1410,45 @@ cc_library( ], alwayslink = 1, ) + +mediapipe_proto_library( + name = "bypass_calculator_proto", + srcs = ["bypass_calculator.proto"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "bypass_calculator", + srcs = ["bypass_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":bypass_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + +cc_test( + name = "bypass_calculator_test", + srcs = ["bypass_calculator_test.cc"], + deps = [ + ":bypass_calculator", + ":pass_through_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:switch_container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) diff --git a/mediapipe/calculators/core/bypass_calculator.cc b/mediapipe/calculators/core/bypass_calculator.cc new file mode 100644 index 000000000..86dcfc0e1 --- /dev/null +++ b/mediapipe/calculators/core/bypass_calculator.cc @@ -0,0 +1,161 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "mediapipe/calculators/core/bypass_calculator.pb.h"
+#include "mediapipe/framework/api2/node.h"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection_item_id.h"
+#include "mediapipe/framework/port/status.h"
+
+namespace mediapipe {
+namespace api2 {
+
+using mediapipe::BypassCalculatorOptions;
+
+// Defines a "bypass" channel to use in place of a disabled feature subgraph.
+// By default, all inputs are discarded and all outputs are ignored.
+// Certain input streams can be passed to corresponding output streams
+// by specifying them in "pass_input_stream" and "pass_output_stream" options.
+// All output streams are updated with timestamp bounds indicating completed
+// output.
+//
+// Note that this calculator is designed for use as a contained_node in a
+// SwitchContainer. For this reason, any input and output tags are accepted,
+// and stream semantics are specified through BypassCalculatorOptions.
+//
+// Example config:
+// node {
+//   calculator: "BypassCalculator"
+//   input_stream: "APPEARANCES:appearances_post_facenet"
+//   input_stream: "VIDEO:video_frame"
+//   input_stream: "FEATURE_CONFIG:feature_config"
+//   input_stream: "ENABLE:gaze_enabled"
+//   output_stream: "APPEARANCES:analyzed_appearances"
+//   output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
+//   node_options: {
+//     [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
+//       pass_input_stream: "APPEARANCES"
+//       pass_output_stream: "APPEARANCES"
+//     }
+//   }
+// }
+//
+class BypassCalculator : public Node {
+ public:
+  static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
+  MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
+  using IdMap = std::map<CollectionItemId, CollectionItemId>;
+
+  // Returns the map of passthrough input and output stream ids.
+  static absl::StatusOr<IdMap> GetPassMap(
+      const BypassCalculatorOptions& options, const tool::TagMap& input_map,
+      const tool::TagMap& output_map) {
+    IdMap result;
+    auto& input_streams = options.pass_input_stream();
+    auto& output_streams = options.pass_output_stream();
+    int size = std::min(input_streams.size(), output_streams.size());
+    for (int i = 0; i < size; ++i) {
+      std::pair<std::string, int> in_tag, out_tag;
+      MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
+                                             &in_tag.first, &in_tag.second));
+      MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
+                                             &out_tag.first, &out_tag.second));
+      auto input_id = input_map.GetId(in_tag.first, in_tag.second);
+      auto output_id = output_map.GetId(out_tag.first, out_tag.second);
+      result[input_id] = output_id;
+    }
+    return result;
+  }
+
+  // Identifies all specified streams as "Any" packet type.
+  // Identifies passthrough streams as "Same" packet type.
+  static absl::Status UpdateContract(CalculatorContract* cc) {
+    auto options = cc->Options<BypassCalculatorOptions>();
+    RET_CHECK_EQ(options.pass_input_stream().size(),
+                 options.pass_output_stream().size());
+    ASSIGN_OR_RETURN(
+        auto pass_streams,
+        GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
+    std::set<CollectionItemId> pass_out;
+    for (auto entry : pass_streams) {
+      pass_out.insert(entry.second);
+      cc->Inputs().Get(entry.first).SetAny();
+      cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
+    }
+    for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
+      if (pass_streams.count(id) == 0) {
+        cc->Inputs().Get(id).SetAny();
+      }
+    }
+    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
+      if (pass_out.count(id) == 0) {
+        cc->Outputs().Get(id).SetAny();
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  // Saves the map of passthrough input and output stream ids.
+  absl::Status Open(CalculatorContext* cc) override {
+    auto options = cc->Options<BypassCalculatorOptions>();
+    ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
+                                               *cc->Outputs().TagMap()));
+    return absl::OkStatus();
+  }
+
+  // Copies packets between passthrough input and output streams.
+  // Updates timestamp bounds on all output streams.
+  absl::Status Process(CalculatorContext* cc) override {
+    std::set<CollectionItemId> pass_out;
+    for (auto entry : pass_streams_) {
+      pass_out.insert(entry.second);
+      auto& packet = cc->Inputs().Get(entry.first).Value();
+      if (packet.Timestamp() == cc->InputTimestamp()) {
+        cc->Outputs().Get(entry.first).AddPacket(packet);
+      }
+    }
+    Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
+    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
+      if (pass_out.count(id) == 0) {
+        cc->Outputs().Get(id).SetNextTimestampBound(
+            std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
+      }
+    }
+    return absl::OkStatus();
+  }
+
+  // Closes all output streams.
+  absl::Status Close(CalculatorContext* cc) override {
+    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
+      cc->Outputs().Get(id).Close();
+    }
+    return absl::OkStatus();
+  }
+
+ private:
+  IdMap pass_streams_;
+};
+
+MEDIAPIPE_REGISTER_NODE(BypassCalculator);
+
+}  // namespace api2
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/core/bypass_calculator.proto b/mediapipe/calculators/core/bypass_calculator.proto
new file mode 100644
index 000000000..1a273edb2
--- /dev/null
+++ b/mediapipe/calculators/core/bypass_calculator.proto
@@ -0,0 +1,31 @@
+// Copyright 2022 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+
+package mediapipe;
+
+import "mediapipe/framework/calculator.proto";
+
+message BypassCalculatorOptions {
+  extend mediapipe.CalculatorOptions {
+    optional BypassCalculatorOptions ext = 481259677;
+  }
+
+  // Names an input stream or streams to pass through, by "TAG:index".
+  repeated string pass_input_stream = 1;
+
+  // Names an output stream or streams to pass through, by "TAG:index".
+ repeated string pass_output_stream = 2; +} diff --git a/mediapipe/calculators/core/bypass_calculator_test.cc b/mediapipe/calculators/core/bypass_calculator_test.cc new file mode 100644 index 000000000..4d1cd8f79 --- /dev/null +++ b/mediapipe/calculators/core/bypass_calculator_test.cc @@ -0,0 +1,302 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { + +namespace { + +// A graph with using a BypassCalculator to pass through and ignore +// most of its inputs and outputs. +constexpr char kTestGraphConfig1[] = R"pb( + type: "AppearancesPassThroughSubgraph" + input_stream: "APPEARANCES:appearances" + input_stream: "VIDEO:video_frame" + input_stream: "FEATURE_CONFIG:feature_config" + output_stream: "APPEARANCES:passthrough_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output" + + node { + calculator: "BypassCalculator" + input_stream: "PASS:appearances" + input_stream: "TRUNCATE:0:video_frame" + input_stream: "TRUNCATE:1:feature_config" + output_stream: "PASS:passthrough_appearances" + output_stream: "TRUNCATE:passthrough_federated_gaze_output" + node_options: { + [type.googleapis.com/mediapipe.BypassCalculatorOptions] { + pass_input_stream: "PASS" + pass_output_stream: "PASS" + } + } + } +)pb"; + +// A graph with using AppearancesPassThroughSubgraph as a do-nothing channel +// for input frames and appearances. +constexpr char kTestGraphConfig2[] = R"pb( + input_stream: "VIDEO_FULL_RES:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "GAZE_ENABLED:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + + node { + calculator: "SwitchContainer" + input_stream: "VIDEO:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "ENABLE:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { calculator: "AppearancesPassThroughSubgraph" } + } + } + } +)pb"; + +// A graph with using BypassCalculator as a do-nothing channel +// for input frames and appearances. 
+constexpr char kTestGraphConfig3[] = R"pb( + input_stream: "VIDEO_FULL_RES:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "GAZE_ENABLED:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + + node { + calculator: "SwitchContainer" + input_stream: "VIDEO:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "ENABLE:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { + calculator: "BypassCalculator" + node_options: { + [type.googleapis.com/mediapipe.BypassCalculatorOptions] { + pass_input_stream: "APPEARANCES" + pass_output_stream: "APPEARANCES" + } + } + } + } + } + } +)pb"; + +// A graph with using BypassCalculator as a disabled-gate +// for input frames and appearances. +constexpr char kTestGraphConfig4[] = R"pb( + input_stream: "VIDEO_FULL_RES:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + input_stream: "GAZE_ENABLED:gaze_enabled" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output" + + node { + calculator: "SwitchContainer" + input_stream: "ENABLE:gaze_enabled" + input_stream: "VIDEO:video_frame" + input_stream: "APPEARANCES:input_appearances" + input_stream: "FEATURE_CONFIG:feature_config" + output_stream: "VIDEO:video_frame_out" + output_stream: "APPEARANCES:analyzed_appearances" + output_stream: "FEATURE_CONFIG:feature_config_out" + options { + [mediapipe.SwitchContainerOptions.ext] { + contained_node: { calculator: "BypassCalculator" } + contained_node: { calculator: "PassThroughCalculator" } + } + } + } +)pb"; + +// Reports packet timestamp and string contents, or """. +std::string DebugString(Packet p) { + return absl::StrCat(p.Timestamp().DebugString(), ":", + p.IsEmpty() ? "" : p.Get()); +} + +// Shows a bypass subgraph that passes through one stream. 
+TEST(BypassCalculatorTest, SubgraphChannel) { + CalculatorGraphConfig config_1 = + mediapipe::ParseTextProtoOrDie(kTestGraphConfig1); + CalculatorGraphConfig config_2 = + mediapipe::ParseTextProtoOrDie(kTestGraphConfig2); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {})); + + std::vector analyzed_appearances; + MP_ASSERT_OK(graph.ObserveOutputStream( + "analyzed_appearances", + [&](const Packet& p) { + analyzed_appearances.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + std::vector federated_gaze_output; + MP_ASSERT_OK(graph.ObserveOutputStream( + "federated_gaze_output", + [&](const Packet& p) { + federated_gaze_output.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket("a1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket("v1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket("f1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1")); + EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:")); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Shows a BypassCalculator that passes through one stream. +TEST(BypassCalculatorTest, CalculatorChannel) { + CalculatorGraphConfig config_3 = + mediapipe::ParseTextProtoOrDie(kTestGraphConfig3); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize({config_3}, {})); + + std::vector analyzed_appearances; + MP_ASSERT_OK(graph.ObserveOutputStream( + "analyzed_appearances", + [&](const Packet& p) { + analyzed_appearances.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + std::vector federated_gaze_output; + MP_ASSERT_OK(graph.ObserveOutputStream( + "federated_gaze_output", + [&](const Packet& p) { + federated_gaze_output.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket("a1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket("v1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket("f1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1")); + EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:")); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +// Shows a BypassCalculator that discards all inputs when ENABLED is false. +TEST(BypassCalculatorTest, GatedChannel) { + CalculatorGraphConfig config_3 = + mediapipe::ParseTextProtoOrDie(kTestGraphConfig4); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize({config_3}, {})); + + std::vector analyzed_appearances; + MP_ASSERT_OK(graph.ObserveOutputStream( + "analyzed_appearances", + [&](const Packet& p) { + analyzed_appearances.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + std::vector video_frame; + MP_ASSERT_OK(graph.ObserveOutputStream( + "video_frame_out", + [&](const Packet& p) { + video_frame.push_back(DebugString(p)); + return absl::OkStatus(); + }, + true)); + MP_ASSERT_OK(graph.StartRun({})); + + // Close the gate. 
+ MP_ASSERT_OK(graph.AddPacketToInputStream( + "gaze_enabled", MakePacket(false).At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send packets at timestamp 200. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket("a1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket("v1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket("f1").At(Timestamp(200)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Only timestamps arrive from the BypassCalculator. + EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:")); + EXPECT_THAT(video_frame, testing::ElementsAre("200:")); + + // Open the gate. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "gaze_enabled", MakePacket(true).At(Timestamp(300)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Send packets at timestamp 300. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "input_appearances", MakePacket("a2").At(Timestamp(300)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "video_frame", MakePacket("v2").At(Timestamp(300)))); + MP_ASSERT_OK(graph.AddPacketToInputStream( + "feature_config", MakePacket("f2").At(Timestamp(300)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + + // Packets arrive from the PassThroughCalculator. + EXPECT_THAT(analyzed_appearances, + testing::ElementsAre("200:", "300:a2")); + EXPECT_THAT(video_frame, testing::ElementsAre("200:", "300:v2")); + + MP_ASSERT_OK(graph.CloseAllInputStreams()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +} // namespace + +} // namespace mediapipe diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 458c5368b..89e2d371c 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -209,11 +209,18 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "rotation_mode_proto", + srcs = ["rotation_mode.proto"], + visibility = ["//visibility:public"], +) + mediapipe_proto_library( name = "image_transformation_calculator_proto", srcs = ["image_transformation_calculator.proto"], visibility = ["//visibility:public"], deps = [ + ":rotation_mode_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/gpu:scale_mode_proto", @@ -238,6 +245,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + ":rotation_mode_cc_proto", ":image_transformation_calculator_cc_proto", "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index bc7fd8df7..84697cc62 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h" +#include "mediapipe/calculators/image/rotation_mode.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" diff --git a/mediapipe/calculators/image/image_transformation_calculator.proto b/mediapipe/calculators/image/image_transformation_calculator.proto index c90e03be9..739c5bfbb 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.proto +++ b/mediapipe/calculators/image/image_transformation_calculator.proto @@ -16,20 +16,10 @@ syntax = "proto2"; package mediapipe; +import "mediapipe/calculators/image/rotation_mode.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/gpu/scale_mode.proto"; -// Counterclockwise rotation. -message RotationMode { - enum Mode { - UNKNOWN = 0; - ROTATION_0 = 1; - ROTATION_90 = 2; - ROTATION_180 = 3; - ROTATION_270 = 4; - } -} - message ImageTransformationCalculatorOptions { extend CalculatorOptions { optional ImageTransformationCalculatorOptions ext = 251952830; diff --git a/mediapipe/calculators/image/rotation_mode.proto b/mediapipe/calculators/image/rotation_mode.proto new file mode 100644 index 000000000..7fa4a8eda --- /dev/null +++ b/mediapipe/calculators/image/rotation_mode.proto @@ -0,0 +1,31 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +option java_package = "com.google.mediapipe.calculator.proto"; +option java_outer_classname = "RotationModeProto"; + +// Counterclockwise rotation. 
+message RotationMode { + enum Mode { + UNKNOWN = 0; + ROTATION_0 = 1; + ROTATION_90 = 2; + ROTATION_180 = 3; + ROTATION_270 = 4; + } +} diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 93f2dbd06..99b5b3e91 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -161,6 +161,193 @@ cc_test( ], ) +mediapipe_proto_library( + name = "bert_preprocessor_calculator_proto", + srcs = ["bert_preprocessor_calculator.proto"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "bert_preprocessor_calculator", + srcs = ["bert_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + ":bert_preprocessor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/tokenizers:tokenizer", + "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "bert_preprocessor_calculator_test", + srcs = ["bert_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"], + linkopts = ["-ldl"], + deps = [ + ":bert_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +mediapipe_proto_library( + name = "regex_preprocessor_calculator_proto", + srcs = ["regex_preprocessor_calculator.proto"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "regex_preprocessor_calculator", + srcs = ["regex_preprocessor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + ":regex_preprocessor_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer", + "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_test( + name = "regex_preprocessor_calculator_test", + srcs = ["regex_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:text_classifier_models"], + linkopts = ["-ldl"], + deps = [ + ":regex_preprocessor_calculator", + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:sink", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "text_to_tensor_calculator", + srcs = ["text_to_tensor_calculator.cc"], + visibility = [ + "//mediapipe/framework:mediapipe_internal", + ], + deps = [ + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "text_to_tensor_calculator_test", + srcs = ["text_to_tensor_calculator_test.cc"], + deps = [ + ":text_to_tensor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/tool:options_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "universal_sentence_encoder_preprocessor_calculator", + srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"], + deps = [ + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "universal_sentence_encoder_preprocessor_calculator_test", + srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"], + data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"], + deps = [ + ":universal_sentence_encoder_preprocessor_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "//mediapipe/framework/tool:options_map", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + mediapipe_proto_library( name = "inference_calculator_proto", srcs = ["inference_calculator.proto"], @@ -320,6 +507,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:c_api_types", "@org_tensorflow//tensorflow/lite/core/api:op_resolver", ], ) diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc b/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc new file mode 100644 index 000000000..d464c3929 --- /dev/null +++ 
b/mediapipe/calculators/tensor/bert_preprocessor_calculator.cc @@ -0,0 +1,251 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace api2 { + +using ::mediapipe::tasks::core::FindTensorIndexByMetadataName; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kTokenizerProcessUnitIndex = 0; +constexpr absl::string_view kInputIdsTensorName = "ids"; +constexpr absl::string_view kInputMasksTensorName = "mask"; +constexpr absl::string_view kSegmentIdsTensorName = "segment_ids"; +constexpr absl::string_view kClassifierToken = "[CLS]"; +constexpr absl::string_view kSeparatorToken = "[SEP]"; + +// Preprocesses input text into three int32 input tensors for a BERT model using +// a tokenizer. +// The associated BERT model is expected to contain input tensors with names: +// +// Tensor | Metadata Name +// ---------------- | -------------- +// IDs | "ids" +// Segment IDs | "segment_ids" +// Mask | "mask" +// +// This calculator will return an error if the model does not have three input +// tensors or if the tensors do not have names corresponding to the above +// metadata names in some order. Additional details regarding these input +// tensors are given in the Calculator "Outputs" section below. +// +// This calculator is currently configured for the TextClassifier Task but it +// will eventually be generalized for other Text Tasks. +// TODO: Handle preprocessing for other Text Tasks too. +// +// Inputs: +// TEXT - std::string +// The input text. +// Side Inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the BERT model. Used to determine the order of +// the three input Tensors for the BERT model and to extract the metadata to +// construct the tokenizer. +// +// Outputs: +// TENSORS - std::vector +// Vector containing the three input Tensors for the BERT model: +// (1): the token ids of the tokenized input string. 
A classifier token +// ("[CLS]") will be prepended to the input tokens and a separator +// token ("[SEP]") will be appended to the input tokens. +// (2): the segment ids, which are all 0 for now but will have different +// values to distinguish between different sentences in the input +// text for other Text tasks. +// (3): the input mask ids, which are 1 at each of the input token indices +// and 0 elsewhere. +// The Tensors will have size equal to the max sequence length for the BERT +// model. +// +// Example: +// node { +// calculator: "BertPreprocessorCalculator" +// input_stream: "TEXT:text" +// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" +// output_stream: "TENSORS:tensors" +// options { +// [mediapipe.BertPreprocessorCalculatorOptions.ext] { +// bert_max_seq_len: 128 +// } +// } +// } +class BertPreprocessorCalculator : public Node { + public: + static constexpr Input kTextIn{"TEXT"}; + static constexpr SideInput kMetadataExtractorSideIn{ + "METADATA_EXTRACTOR"}; + static constexpr Output> kTensorsOut{"TENSORS"}; + + MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut); + + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::unique_ptr tokenizer_; + // The max sequence length accepted by the BERT model. + int bert_max_seq_len_ = 2; + // Indices of the three input tensors for the BERT model. They should form the + // set {0, 1, 2}. + int input_ids_tensor_index_ = 0; + int segment_ids_tensor_index_ = 1; + int input_masks_tensor_index_ = 2; + + // Applies `tokenizer_` to the `input_text` to generate a vector of tokens. + // This util prepends "[CLS]" and appends "[SEP]" to the input tokens and + // clips the vector of tokens to have length at most `bert_max_seq_len_`. + std::vector TokenizeInputText(absl::string_view input_text); + // Processes the `input_tokens` to generate the three input tensors for the + // BERT model. 
+ std::vector GenerateInputTensors( + const std::vector& input_tokens); +}; + +absl::Status BertPreprocessorCalculator::UpdateContract( + CalculatorContract* cc) { + const auto& options = + cc->Options(); + RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required"; + RET_CHECK_GE(options.bert_max_seq_len(), 2) + << "bert_max_seq_len must be at least 2"; + return absl::OkStatus(); +} + +absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) { + const ModelMetadataExtractor* metadata_extractor = + &kMetadataExtractorSideIn(cc).Get(); + const tflite::ProcessUnit* tokenizer_metadata = + metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex); + ASSIGN_OR_RETURN(tokenizer_, + tasks::text::tokenizers::CreateTokenizerFromProcessUnit( + tokenizer_metadata, metadata_extractor)); + + auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata(); + input_ids_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kInputIdsTensorName); + segment_ids_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kSegmentIdsTensorName); + input_masks_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kInputMasksTensorName); + absl::flat_hash_set tensor_indices = {input_ids_tensor_index_, + segment_ids_tensor_index_, + input_masks_tensor_index_}; + if (tensor_indices != absl::flat_hash_set({0, 1, 2})) { + return absl::InvalidArgumentError(absl::Substitute( + "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}", + input_ids_tensor_index_, segment_ids_tensor_index_, + input_masks_tensor_index_)); + } + + const auto& options = + cc->Options(); + bert_max_seq_len_ = options.bert_max_seq_len(); + return absl::OkStatus(); +} + +absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) { + kTensorsOut(cc).Send( + GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get()))); + return absl::OkStatus(); +} + +std::vector BertPreprocessorCalculator::TokenizeInputText( + absl::string_view input_text) { + std::string processed_input = std::string(input_text); + absl::AsciiStrToLower(&processed_input); + + tasks::text::tokenizers::TokenizerResult tokenizer_result = + tokenizer_->Tokenize(processed_input); + + // Offset by 2 to account for [CLS] and [SEP] + int input_tokens_size = + std::min(bert_max_seq_len_, + static_cast(tokenizer_result.subwords.size()) + 2); + std::vector input_tokens; + input_tokens.reserve(input_tokens_size); + input_tokens.push_back(std::string(kClassifierToken)); + for (int i = 0; i < input_tokens_size - 2; ++i) { + input_tokens.push_back(std::move(tokenizer_result.subwords[i])); + } + input_tokens.push_back(std::string(kSeparatorToken)); + return input_tokens; +} + +std::vector BertPreprocessorCalculator::GenerateInputTensors( + const std::vector& input_tokens) { + std::vector input_ids(bert_max_seq_len_, 0); + std::vector segment_ids(bert_max_seq_len_, 0); + std::vector input_masks(bert_max_seq_len_, 0); + // Convert tokens back into ids and set mask + for (int i = 0; i < input_tokens.size(); ++i) { + tokenizer_->LookupId(input_tokens[i], &input_ids[i]); + input_masks[i] = 1; + } + // |<--------bert_max_seq_len_--------->| + // input_ids [CLS] s1 s2... sn [SEP] 0 0... 0 + // segment_ids 0 0 0... 0 0 0 0... 0 + // input_masks 1 1 1... 1 1 0 0... 
0 + + std::vector input_tensors; + input_tensors.reserve(kNumInputTensorsForBert); + for (int i = 0; i < kNumInputTensorsForBert; ++i) { + input_tensors.push_back( + {Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})}); + } + std::memcpy(input_tensors[input_ids_tensor_index_] + .GetCpuWriteView() + .buffer(), + input_ids.data(), input_ids.size() * sizeof(int32_t)); + std::memcpy(input_tensors[segment_ids_tensor_index_] + .GetCpuWriteView() + .buffer(), + segment_ids.data(), segment_ids.size() * sizeof(int32_t)); + std::memcpy(input_tensors[input_masks_tensor_index_] + .GetCpuWriteView() + .buffer(), + input_masks.data(), input_masks.size() * sizeof(int32_t)); + return input_tensors; +} + +MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto b/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto new file mode 100644 index 000000000..72b569143 --- /dev/null +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator.proto @@ -0,0 +1,29 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message BertPreprocessorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional BertPreprocessorCalculatorOptions ext = 462509271; + } + + // The maximum input sequence length for the calculator's BERT model. + optional int32 bert_max_seq_len = 1; +} diff --git a/mediapipe/calculators/tensor/bert_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/bert_preprocessor_calculator_test.cc new file mode 100644 index 000000000..b497a6168 --- /dev/null +++ b/mediapipe/calculators/tensor/bert_preprocessor_calculator_test.cc @@ -0,0 +1,154 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::testing::ElementsAreArray; + +constexpr int kNumInputTensorsForBert = 3; +constexpr int kBertMaxSeqLen = 128; +constexpr absl::string_view kTestModelPath = + "mediapipe/tasks/testdata/text/bert_text_classifier.tflite"; + +absl::StatusOr>> RunBertPreprocessorCalculator( + absl::string_view text, absl::string_view model_path) { + auto graph_config = ParseTextProtoOrDie( + absl::Substitute(R"( + input_stream: "text" + output_stream: "tensors" + node { + calculator: "BertPreprocessorCalculator" + input_stream: "TEXT:text" + input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" + output_stream: "TENSORS:tensors" + options { + [mediapipe.BertPreprocessorCalculatorOptions.ext] { + bert_max_seq_len: $0 + } + } + } + )", + kBertMaxSeqLen)); + std::vector output_packets; + tool::AddVectorSink("tensors", &graph_config, &output_packets); + + std::string model_buffer = tasks::core::LoadBinaryContent(model_path.data()); + ASSIGN_OR_RETURN(std::unique_ptr metadata_extractor, + ModelMetadataExtractor::CreateFromModelBuffer( + model_buffer.data(), model_buffer.size())); + // Run the graph. 
+ CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize( + graph_config, + {{"metadata_extractor", + MakePacket(std::move(*metadata_extractor))}})); + MP_RETURN_IF_ERROR(graph.StartRun({})); + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( + "text", MakePacket(text).At(Timestamp(0)))); + MP_RETURN_IF_ERROR(graph.WaitUntilIdle()); + + if (output_packets.size() != 1) { + return absl::InvalidArgumentError(absl::Substitute( + "output_packets has size $0, expected 1", output_packets.size())); + } + const std::vector& tensor_vec = + output_packets[0].Get>(); + if (tensor_vec.size() != kNumInputTensorsForBert) { + return absl::InvalidArgumentError( + absl::Substitute("tensor_vec has size $0, expected $1", + tensor_vec.size(), kNumInputTensorsForBert)); + } + + std::vector> results; + for (int i = 0; i < kNumInputTensorsForBert; i++) { + const Tensor& tensor = tensor_vec[i]; + if (tensor.element_type() != Tensor::ElementType::kInt32) { + return absl::InvalidArgumentError("Expected tensor element type kInt32"); + } + auto* buffer = tensor.GetCpuReadView().buffer(); + std::vector buffer_view(buffer, buffer + kBertMaxSeqLen); + results.push_back(buffer_view); + } + MP_RETURN_IF_ERROR(graph.CloseAllPacketSources()); + MP_RETURN_IF_ERROR(graph.WaitUntilDone()); + return results; +} + +TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) { + std::vector> expected_result = { + {101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}}; + // segment_ids + expected_result.push_back(std::vector(kBertMaxSeqLen, 0)); + // input_masks + expected_result.push_back(std::vector(expected_result[0].size(), 1)); + expected_result[2].resize(kBertMaxSeqLen); + // padding input_ids + expected_result[0].resize(kBertMaxSeqLen); + + MP_ASSERT_OK_AND_ASSIGN( + std::vector> processed_tensor_values, + RunBertPreprocessorCalculator( + "it's a charming and often affecting journey", kTestModelPath)); + EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result)); +} + +TEST(BertPreprocessorCalculatorTest, LongInput) { + std::stringstream long_input; + long_input + << "it's a charming and often affecting journey and this is a long"; + for (int i = 0; i < kBertMaxSeqLen; ++i) { + long_input << " long"; + } + long_input << " movie review"; + std::vector> expected_result = { + {101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023, + 2003, 1037}}; + // "long" id + expected_result[0].resize(kBertMaxSeqLen - 1, 2146); + // "[SEP]" id + expected_result[0].push_back(102); + // segment_ids + expected_result.push_back(std::vector(kBertMaxSeqLen, 0)); + // input_masks + expected_result.push_back(std::vector(kBertMaxSeqLen, 1)); + + MP_ASSERT_OK_AND_ASSIGN( + std::vector> processed_tensor_values, + RunBertPreprocessorCalculator(long_input.str(), kTestModelPath)); + EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index afd260347..ec7d4afa8 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node { } ASSIGN_OR_RETURN(auto image, GetInputImage(cc)); - const Size size{image->width(), image->height()}; - RotatedRect roi = GetRoi(size.width, size.height, norm_rect); + + RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect); 
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(), options_.output_tensor_height(), options_.keep_aspect_ratio(), &roi)); @@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node { } if (kOutMatrix(cc).IsConnected()) { std::array matrix; - GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height, - /*flip_horizontaly=*/false, - &matrix); + GetRotatedSubRectToRectTransformMatrix( + roi, image->width(), image->height(), + /*flip_horizontaly=*/false, &matrix); kOutMatrix(cc).Send(std::move(matrix)); } // Lazy initialization of the GPU or CPU converter. MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get())); - ASSIGN_OR_RETURN(Tensor tensor, - (image->UsesGpu() ? gpu_converter_ : cpu_converter_) - ->Convert(*image, roi, {output_width_, output_height_}, - range_min_, range_max_)); + Tensor::ElementType output_tensor_type = + GetOutputTensorType(image->UsesGpu()); + Tensor tensor(output_tensor_type, {1, output_height_, output_width_, + GetNumOutputChannels(*image)}); + MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_) + ->Convert(*image, roi, range_min_, range_max_, + /*tensor_buffer_offset=*/0, tensor)); auto result = std::make_unique>(); result->push_back(std::move(tensor)); @@ -292,15 +295,31 @@ class ImageToTensorCalculator : public Node { } } - Tensor::ElementType GetOutputTensorType() { - if (is_float_output_) { - return Tensor::ElementType::kFloat32; + Tensor::ElementType GetOutputTensorType(bool uses_gpu) { + if (!uses_gpu) { + if (is_float_output_) { + return Tensor::ElementType::kFloat32; + } + if (range_min_ < 0) { + return Tensor::ElementType::kInt8; + } else { + return Tensor::ElementType::kUInt8; + } } - if (range_min_ < 0) { - return Tensor::ElementType::kInt8; - } else { - return Tensor::ElementType::kUInt8; + // Always use float32 when GPU is enabled. + return Tensor::ElementType::kFloat32; + } + + int GetNumOutputChannels(const Image& image) { +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_METAL_ENABLED + if (image.UsesGpu()) { + return 4; } +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + // All of the processors except for Metal expect 3 channels. + return 3; } absl::StatusOr> GetInputImage( @@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node { #if !MEDIAPIPE_DISABLE_OPENCV ASSIGN_OR_RETURN( cpu_converter_, - CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType())); + CreateOpenCvConverter(cc, GetBorderMode(), + GetOutputTensorType(/*uses_gpu=*/false))); #else LOG(FATAL) << "Cannot create image to tensor opencv converter since " "MEDIAPIPE_DISABLE_OPENCV is defined."; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter.h b/mediapipe/calculators/tensor/image_to_tensor_converter.h index 39fd1ee0d..870ebc300 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter.h @@ -42,13 +42,16 @@ class ImageToTensorConverter { // @image contains image to extract from. // @roi describes region of interest within the image to extract (absolute // values). - // @output_dims dimensions of output tensor. // @range_min/max describes output tensor range image pixels should converted // to. - virtual absl::StatusOr Convert(const mediapipe::Image& input, - const RotatedRect& roi, - const Size& output_dims, - float range_min, float range_max) = 0; + // @tensor_buffer_offset an inteter representing the offset of the tensor + // buffer the result should be written to. 
+ // @output_tensor a tensor with pre-defined shape. The "Convert" is + // responsible of populating the content into the output tensor. + virtual absl::Status Convert(const mediapipe::Image& input, + const RotatedRect& roi, float range_min, + float range_max, int tensor_buffer_offset, + Tensor& output_tensor) = 0; }; } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index ddc7ff85e..14de410ff 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -264,57 +264,58 @@ class GlProcessor : public ImageToTensorConverter { }); } - absl::StatusOr Convert(const mediapipe::Image& input, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { + absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi, + float range_min, float range_max, + int tensor_buffer_offset, + Tensor& output_tensor) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && - input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 && + input.format() != mediapipe::GpuBufferFormat::kRGB24) { return InvalidArgumentError(absl::StrCat( - "Only 4-channel texture input formats are supported, passed format: ", - static_cast(input.format()))); + "Unsupported format: ", static_cast(input.format()))); } + const auto& output_shape = output_tensor.shape(); + MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); - constexpr int kNumChannels = 3; - Tensor tensor(Tensor::ElementType::kFloat32, - {1, output_dims.height, output_dims.width, kNumChannels}); + MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext( + [this, &output_tensor, &input, &roi, &output_shape, range_min, + range_max, tensor_buffer_offset]() -> absl::Status { + const int input_num_channels = input.channels(); + auto source_texture = gl_helper_.CreateSourceTexture(input); + tflite::gpu::gl::GlTexture input_texture( + GL_TEXTURE_2D, source_texture.name(), + input_num_channels == 4 ? 
GL_RGB : GL_RGBA, + source_texture.width() * source_texture.height() * + input_num_channels * sizeof(uint8_t), + /*layer=*/0, + /*owned=*/false); - MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi, - &output_dims, range_min, - range_max]() -> absl::Status { - constexpr int kRgbaNumChannels = 4; - auto source_texture = gl_helper_.CreateSourceTexture(input); - tflite::gpu::gl::GlTexture input_texture( - GL_TEXTURE_2D, source_texture.name(), GL_RGBA, - source_texture.width() * source_texture.height() * kRgbaNumChannels * - sizeof(uint8_t), - /*layer=*/0, - /*owned=*/false); + constexpr float kInputImageRangeMin = 0.0f; + constexpr float kInputImageRangeMax = 1.0f; + ASSIGN_OR_RETURN(auto transform, + GetValueRangeTransformation(kInputImageRangeMin, + kInputImageRangeMax, + range_min, range_max)); - constexpr float kInputImageRangeMin = 0.0f; - constexpr float kInputImageRangeMax = 1.0f; - ASSIGN_OR_RETURN( - auto transform, - GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, - range_min, range_max)); + const int output_size = output_tensor.bytes() / output_shape.dims[0]; + auto buffer_view = output_tensor.GetOpenGlBufferWriteView(); + tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER, + buffer_view.name(), output_size, + /*offset=*/tensor_buffer_offset, + /*has_ownership=*/false); + MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer( + input_texture, + tflite::gpu::HW(source_texture.height(), source_texture.width()), + roi, + /*flip_horizontaly=*/false, transform.scale, transform.offset, + tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]), + command_queue_.get(), &output)); - auto buffer_view = tensor.GetOpenGlBufferWriteView(); - tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER, - buffer_view.name(), tensor.bytes(), - /*offset=*/0, - /*has_ownership=*/false); - MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer( - input_texture, - tflite::gpu::HW(source_texture.height(), source_texture.width()), roi, - /*flip_horizontaly=*/false, transform.scale, transform.offset, - tflite::gpu::HW(output_dims.height, output_dims.width), - command_queue_.get(), &output)); + return absl::OkStatus(); + })); - return absl::OkStatus(); - })); - - return tensor; + return absl::OkStatus(); } ~GlProcessor() override { @@ -326,6 +327,17 @@ class GlProcessor : public ImageToTensorConverter { } private: + absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { + RET_CHECK_EQ(output_shape.dims.size(), 4) + << "Wrong output dims size: " << output_shape.dims.size(); + RET_CHECK_EQ(output_shape.dims[0], 1) + << "Handling batch dimension not equal to 1 is not implemented in this " + "converter."; + RET_CHECK_EQ(output_shape.dims[3], 3) + << "Wrong output channel: " << output_shape.dims[3]; + return absl::OkStatus(); + } + std::unique_ptr command_queue_; std::unique_ptr extractor_; mediapipe::GlCalculatorHelper gl_helper_; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 6f035e67b..5efd34041 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -168,26 +168,26 @@ class GlProcessor : public ImageToTensorConverter { }); } - absl::StatusOr Convert(const mediapipe::Image& input, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { + absl::Status Convert(const mediapipe::Image& 
input, const RotatedRect& roi, + float range_min, float range_max, + int tensor_buffer_offset, + Tensor& output_tensor) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && - input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { + input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 && + input.format() != mediapipe::GpuBufferFormat::kRGB24) { return InvalidArgumentError(absl::StrCat( - "Only 4-channel texture input formats are supported, passed format: ", - static_cast(input.format()))); + "Unsupported format: ", static_cast(input.format()))); } + // TODO: support tensor_buffer_offset > 0 scenario. + RET_CHECK_EQ(tensor_buffer_offset, 0) + << "The non-zero tensor_buffer_offset input is not supported yet."; + const auto& output_shape = output_tensor.shape(); + MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); - constexpr int kNumChannels = 3; - Tensor tensor( - Tensor::ElementType::kFloat32, - Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels}); - - MP_RETURN_IF_ERROR( - gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims, - range_min, range_max]() -> absl::Status { + MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext( + [this, &output_tensor, &input, &roi, &output_shape, range_min, + range_max]() -> absl::Status { auto input_texture = gl_helper_.CreateSourceTexture(input); constexpr float kInputImageRangeMin = 0.0f; @@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter { GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); - auto tensor_view = tensor.GetOpenGlTexture2dWriteView(); + auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView(); MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi, /*flip_horizontaly=*/false, transform.scale, transform.offset, - output_dims, &tensor_view)); + output_shape, &tensor_view)); return absl::OkStatus(); })); - return tensor; + return absl::OkStatus(); } absl::Status ExtractSubRect(const mediapipe::GlTexture& texture, const RotatedRect& sub_rect, bool flip_horizontaly, float alpha, float beta, - const Size& output_dims, + const Tensor::Shape& output_shape, Tensor::OpenGlTexture2dView* output) { + const int output_height = output_shape.dims[1]; + const int output_width = output_shape.dims[2]; std::array transform_mat; glDisable(GL_DEPTH_TEST); glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); - glViewport(0, 0, output_dims.width, output_dims.height); + glViewport(0, 0, output_width, output_height); glActiveTexture(GL_TEXTURE0); glBindTexture(GL_TEXTURE_2D, output->name()); @@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter { } private: + absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { + RET_CHECK_EQ(output_shape.dims.size(), 4) + << "Wrong output dims size: " << output_shape.dims.size(); + RET_CHECK_EQ(output_shape.dims[0], 1) + << "Handling batch dimension not equal to 1 is not implemented in this " + "converter."; + RET_CHECK_EQ(output_shape.dims[3], 3) + << "Wrong output channel: " << output_shape.dims[3]; + return absl::OkStatus(); + } + mediapipe::GlCalculatorHelper gl_helper_; bool use_custom_zero_border_ = false; BorderMode border_mode_ = BorderMode::kReplicate; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index cfabae333..a8211d39b 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ 
b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -262,7 +262,6 @@ class SubRectExtractorMetal { RET_CHECK(pipeline_state != nil); std::string output_type_def; - MTLPixelFormat pixel_format; switch (output_format) { case OutputFormat::kF16C4: output_type_def = R"( @@ -348,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter { return absl::OkStatus(); } - absl::StatusOr Convert(const mediapipe::Image& input, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { + absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi, + float range_min, float range_max, + int tensor_buffer_offset, + Tensor& output_tensor) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 && input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 && input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) { @@ -359,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter { "Only 4-channel texture input formats are supported, passed format: ", static_cast(input.format()))); } + RET_CHECK_EQ(tensor_buffer_offset, 0) + << "The non-zero tensor_buffer_offset input is not supported yet."; + const auto& output_shape = output_tensor.shape(); + MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); @autoreleasepool { id texture = [metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()]; - constexpr int kNumChannels = 4; - Tensor tensor(Tensor::ElementType::kFloat32, - Tensor::Shape{1, output_dims.height, output_dims.width, - kNumChannels}); - constexpr float kInputImageRangeMin = 0.0f; constexpr float kInputImageRangeMax = 1.0f; ASSIGN_OR_RETURN( @@ -377,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter { range_min, range_max)); id command_buffer = [metal_helper_ commandBuffer]; - const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer); + const auto& buffer_view = + output_tensor.GetMtlBufferWriteView(command_buffer); MP_RETURN_IF_ERROR(extractor_->Execute( texture, roi, /*flip_horizontaly=*/false, transform.scale, transform.offset, - tflite::gpu::HW(output_dims.height, output_dims.width), + tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]), command_buffer, buffer_view.buffer())); [command_buffer commit]; - return tensor; + return absl::OkStatus(); } } private: + absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { + RET_CHECK_EQ(output_shape.dims.size(), 4) + << "Wrong output dims size: " << output_shape.dims.size(); + RET_CHECK_EQ(output_shape.dims[0], 1) + << "Handling batch dimension not equal to 1 is not implemented in this " + "converter."; + RET_CHECK_EQ(output_shape.dims[3], 4) + << "Wrong output channel: " << output_shape.dims[3]; + return absl::OkStatus(); + } + MPPMetalHelper* metal_helper_ = nil; std::unique_ptr extractor_; }; diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index 6d36e5878..f910b59f3 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter { } } - absl::StatusOr Convert(const mediapipe::Image& input, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { + absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi, + float range_min, float range_max, + int tensor_buffer_offset, + Tensor& 
output_tensor) override { if (input.image_format() != mediapipe::ImageFormat::SRGB && input.image_format() != mediapipe::ImageFormat::SRGBA) { return InvalidArgumentError( absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", static_cast(input.image_format()))); } - auto src = mediapipe::formats::MatView(&input); + // TODO: Remove the check once tensor_buffer_offset > 0 is + // supported. + RET_CHECK_EQ(tensor_buffer_offset, 0) + << "The non-zero tensor_buffer_offset input is not supported yet."; + const auto& output_shape = output_tensor.shape(); + MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape)); - constexpr int kNumChannels = 3; - Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height, - output_dims.width, kNumChannels}); - auto buffer_view = tensor.GetCpuWriteView(); + const int output_height = output_shape.dims[1]; + const int output_width = output_shape.dims[2]; + const int output_channels = output_shape.dims[3]; + auto buffer_view = output_tensor.GetCpuWriteView(); cv::Mat dst; switch (tensor_type_) { case Tensor::ElementType::kInt8: - dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + dst = cv::Mat(output_height, output_width, mat_type_, buffer_view.buffer()); break; case Tensor::ElementType::kFloat32: - dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + dst = cv::Mat(output_height, output_width, mat_type_, buffer_view.buffer()); break; case Tensor::ElementType::kUInt8: - dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, + dst = cv::Mat(output_height, output_width, mat_type_, buffer_view.buffer()); break; default: @@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter { cv::Mat src_points; cv::boxPoints(rotated_rect, src_points); - const float dst_width = output_dims.width; - const float dst_height = output_dims.height; + const float dst_width = output_width; + const float dst_height = output_height; /* clang-format off */ float dst_corners[8] = {0.0f, dst_height, 0.0f, 0.0f, @@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter { dst_width, dst_height}; /* clang-format on */ + auto src = mediapipe::formats::MatView(&input); cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners); cv::Mat projection_matrix = cv::getPerspectiveTransform(src_points, dst_points); @@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter { /*flags=*/cv::INTER_LINEAR, /*borderMode=*/border_mode_); - if (transformed.channels() > kNumChannels) { + if (transformed.channels() > output_channels) { cv::Mat proper_channels_mat; cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB); transformed = proper_channels_mat; @@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter { GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, range_min, range_max)); transformed.convertTo(dst, mat_type_, transform.scale, transform.offset); - return tensor; + return absl::OkStatus(); } private: + absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) { + RET_CHECK_EQ(output_shape.dims.size(), 4) + << "Wrong output dims size: " << output_shape.dims.size(); + RET_CHECK_EQ(output_shape.dims[0], 1) + << "Handling batch dimension not equal to 1 is not implemented in this " + "converter."; + RET_CHECK_EQ(output_shape.dims[3], 3) + << "Wrong output channel: " << output_shape.dims[3]; + return absl::OkStatus(); + } + enum cv::BorderTypes border_mode_; Tensor::ElementType tensor_type_; int mat_type_; diff --git 
a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index 1f3768ee0..bd8eb3eed 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -26,6 +26,8 @@ #include "mediapipe/gpu/gl_calculator_helper.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe + namespace mediapipe { namespace api2 { @@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process( CalculatorContext* cc, const std::vector& input_tensors, std::vector& output_tensors) { return gpu_helper_.RunInGlContext( - [this, &input_tensors, &output_tensors]() -> absl::Status { + [this, cc, &input_tensors, &output_tensors]() -> absl::Status { // Explicitly copy input. for (int i = 0; i < input_tensors.size(); ++i) { glBindBuffer(GL_COPY_READ_BUFFER, @@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process( } // Run inference. - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + { + MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } output_tensors.reserve(output_size_); for (int i = 0; i < output_size_; ++i) { diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 7e11ee072..52359f7f5 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -32,6 +32,8 @@ #include "mediapipe/util/android/file/base/helpers.h" #endif // MEDIAPIPE_ANDROID +#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe + namespace mediapipe { namespace api2 { @@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl const mediapipe::InferenceCalculatorOptions::Delegate& delegate); absl::StatusOr> Process( - const std::vector& input_tensors); + CalculatorContext* cc, const std::vector& input_tensors); absl::Status Close(); @@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init( absl::StatusOr> InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( - const std::vector& input_tensors) { + CalculatorContext* cc, const std::vector& input_tensors) { std::vector output_tensors; MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors, &output_tensors]() -> absl::Status { + [this, cc, &input_tensors, &output_tensors]() -> absl::Status { for (int i = 0; i < input_tensors.size(); ++i) { MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( input_tensors[i].GetOpenGlBufferReadView().name(), i)); @@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( output_tensors.back().GetOpenGlBufferWriteView().name(), i)); } // Run inference. 
- return tflite_gpu_runner_->Invoke(); + { + MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); + return tflite_gpu_runner_->Invoke(); + } })); return output_tensors; @@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) { auto output_tensors = absl::make_unique>(); ASSIGN_OR_RETURN(*output_tensors, - gpu_inference_runner_->Process(input_tensors)); + gpu_inference_runner_->Process(cc, input_tensors)); kOutTensors(cc).Send(std::move(output_tensors)); return absl::OkStatus(); diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index ff8ebe149..a85071f3e 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -224,9 +224,6 @@ absl::Status InferenceCalculatorMetalImpl::InitInterpreter( void InferenceCalculatorMetalImpl::AddDelegate( CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) { - const auto& calculator_opts = - cc->Options(); - // Configure and create the delegate. TFLGpuDelegateOptions options; // `enable_quantization` enables the run of sparse models i.e. the models with diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc index 81edb34e0..1d216daf3 100644 --- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc @@ -21,8 +21,10 @@ #include "absl/status/statusor.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/ret_check.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/string_util.h" namespace mediapipe { @@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor, std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes()); } +template <> +void CopyTensorBufferToInterpreter(const Tensor& input_tensor, + tflite::Interpreter* interpreter, + int input_tensor_index) { + const char* input_tensor_buffer = + input_tensor.GetCpuReadView().buffer(); + tflite::DynamicBuffer dynamic_buffer; + dynamic_buffer.AddString(input_tensor_buffer, + input_tensor.shape().num_elements()); + dynamic_buffer.WriteToTensorAsVector( + interpreter->tensor(interpreter->inputs()[input_tensor_index])); +} + template void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter, int output_tensor_index, @@ -87,13 +102,13 @@ absl::StatusOr> InferenceInterpreterDelegateRunner::Run( break; } case TfLiteType::kTfLiteUInt8: { - CopyTensorBufferToInterpreter(input_tensors[i], - interpreter_.get(), i); + CopyTensorBufferToInterpreter(input_tensors[i], + interpreter_.get(), i); break; } case TfLiteType::kTfLiteInt8: { - CopyTensorBufferToInterpreter(input_tensors[i], - interpreter_.get(), i); + CopyTensorBufferToInterpreter(input_tensors[i], + interpreter_.get(), i); break; } case TfLiteType::kTfLiteInt32: { @@ -101,6 +116,14 @@ absl::StatusOr> InferenceInterpreterDelegateRunner::Run( interpreter_.get(), i); break; } + case TfLiteType::kTfLiteString: { + CopyTensorBufferToInterpreter(input_tensors[i], + interpreter_.get(), i); + break; + } + case TfLiteType::kTfLiteBool: + // No current use-case for copying MediaPipe Tensors with bool type to + // TfLiteTensors. 
default: return absl::InvalidArgumentError( absl::StrCat("Unsupported input tensor type:", input_tensor_type)); @@ -146,6 +169,15 @@ absl::StatusOr> InferenceInterpreterDelegateRunner::Run( CopyTensorBufferFromInterpreter(interpreter_.get(), i, &output_tensors.back()); break; + case TfLiteType::kTfLiteBool: + output_tensors.emplace_back(Tensor::ElementType::kBool, shape, + Tensor::QuantizationParameters{1.0f, 0}); + CopyTensorBufferFromInterpreter(interpreter_.get(), i, + &output_tensors.back()); + break; + case TfLiteType::kTfLiteString: + // No current use-case for copying TfLiteTensors with string type to + // MediaPipe Tensors. default: return absl::InvalidArgumentError( absl::StrCat("Unsupported output tensor type:", diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc b/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc new file mode 100644 index 000000000..92a5f0266 --- /dev/null +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator.cc @@ -0,0 +1,174 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" +#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h" +#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace api2 { + +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +// Preprocesses input text into one int32 input tensor for a text model using +// a RegexTokenizer. +// +// Inputs: +// TEXT - std::string +// The input text. +// Side Inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the text model. Used to extract the metadata +// to construct the RegexTokenizer. +// +// Outputs: +// TENSORS - std::vector +// Vector containing a single Tensor which is the text model's input tensor. +// Depending on the tokenizer metadata, the tensor may start with +// the id of the tokenizer's token. The following tensor values will +// be the ids of the tokens of the input text. Any out-of-vocab tokens will +// have the id of the token. The tensor will be padded with the +// token id to have size equal to the max sequence length for the text +// model. 
+// +// Example: +// node { +// calculator: "RegexPreprocessorCalculator" +// input_stream: "TEXT:text" +// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" +// output_stream: "TENSORS:tensors" +// options { +// [mediapipe.RegexPreprocessorCalculatorOptions.ext] { +// max_seq_len: 256 +// } +// } +// } +class RegexPreprocessorCalculator : public Node { + public: + static constexpr Input kTextIn{"TEXT"}; + static constexpr SideInput kMetadataExtractorSideIn{ + "METADATA_EXTRACTOR"}; + static constexpr Output> kTensorsOut{"TENSORS"}; + + MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut); + + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + std::unique_ptr tokenizer_; + // The max sequence length accepted by the text model. + int max_seq_len_ = 0; +}; + +absl::Status RegexPreprocessorCalculator::UpdateContract( + CalculatorContract* cc) { + const auto& options = + cc->Options(); + RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required"; + RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive"; + return absl::OkStatus(); +} + +absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) { + const ModelMetadataExtractor* metadata_extractor = + &kMetadataExtractorSideIn(cc).Get(); + const tflite::TensorMetadata* tensor_metadata = + metadata_extractor->GetInputTensorMetadata(0); + if (tensor_metadata == nullptr) { + return absl::InvalidArgumentError("No tensor metadata found"); + } + + ASSIGN_OR_RETURN( + const auto* tokenizer_metadata, + metadata_extractor->FindFirstProcessUnit( + *tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions)); + if (tokenizer_metadata == nullptr) { + return absl::InvalidArgumentError("No tokenizer metadata found"); + } + const tflite::RegexTokenizerOptions* regex_tokenizer_options = + tokenizer_metadata->options_as(); + ASSIGN_OR_RETURN(tokenizer_, + tasks::text::tokenizers::CreateRegexTokenizerFromOptions( + regex_tokenizer_options, metadata_extractor)); + + const auto& options = + cc->Options(); + max_seq_len_ = options.max_seq_len(); + return absl::OkStatus(); +} + +absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) { + tasks::text::tokenizers::TokenizerResult tokenizer_result = + tokenizer_->Tokenize(kTextIn(cc).Get()); + + int unknown_token_id = 0; + tokenizer_->GetUnknownToken(&unknown_token_id); + int pad_token_id = 0; + tokenizer_->GetPadToken(&pad_token_id); + + std::vector input_tokens(max_seq_len_, pad_token_id); + int start_token_id = 0; + int input_token_index = 0; + if (tokenizer_->GetStartToken(&start_token_id)) { + input_tokens[0] = start_token_id; + input_token_index = 1; + } + + for (int i = 0; (i < tokenizer_result.subwords.size()) && + (input_token_index < max_seq_len_); + ++i, ++input_token_index) { + const std::string& token = tokenizer_result.subwords[i]; + int token_id = 0; + if (tokenizer_->LookupId(token, &token_id)) { + input_tokens[input_token_index] = token_id; + } else { + input_tokens[input_token_index] = unknown_token_id; + } + } + + // |<-------sentence_length-------->| + // input_tensor , t1, t2... , ... + // is optional, t1, t2... will be replaced by if it's + // not found in the tokenizer vocab. 
+ std::vector result; + result.push_back( + {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})}); + std::memcpy(result[0].GetCpuWriteView().buffer(), + input_tokens.data(), input_tokens.size() * sizeof(int32_t)); + kTensorsOut(cc).Send(std::move(result)); + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto b/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto new file mode 100644 index 000000000..793067a80 --- /dev/null +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator.proto @@ -0,0 +1,29 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message RegexPreprocessorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional RegexPreprocessorCalculatorOptions ext = 463716697; + } + + // The maximum input sequence length for the calculator's text model. + optional int32 max_seq_len = 1; +} diff --git a/mediapipe/calculators/tensor/regex_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/regex_preprocessor_calculator_test.cc new file mode 100644 index 000000000..ef14ef035 --- /dev/null +++ b/mediapipe/calculators/tensor/regex_preprocessor_calculator_test.cc @@ -0,0 +1,130 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
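To make the layout RegexPreprocessorCalculator emits concrete, here is a small worked illustration rather than code from the change. It assumes a hypothetical model whose regex-tokenizer metadata defines a start token with id 1, an unknown token with id 2 and a pad token with id 0 (the tokens the calculator looks up via GetStartToken, GetUnknownToken and GetPadToken), plus `max_seq_len: 8`; all word ids are made up.

```
// Illustrative only -- hypothetical vocabulary and ids.
// Tokenizer metadata: start id = 1, unknown id = 2, pad id = 0.
// Input text:  "never seen anything like it"
// Vocab hits:  never=57, seen=12, like=31, it=6   ("anything" is out of vocab)
//
// Emitted tensor (kInt32, shape [8]):
//   [1, 57, 12, 2, 31, 6, 0, 0]
//    start-token id first, the unknown id substituted for the OOV word,
//    and the pad id filling the remainder up to max_seq_len.
```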
+ +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/sink.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::testing::ElementsAreArray; + +constexpr int kMaxSeqLen = 256; +constexpr char kTestModelPath[] = + "mediapipe/tasks/testdata/text/" + "test_model_text_classifier_with_regex_tokenizer.tflite"; + +absl::StatusOr> RunRegexPreprocessorCalculator( + absl::string_view text) { + auto graph_config = + ParseTextProtoOrDie(absl::Substitute( + R"pb( + input_stream: "text" + output_stream: "tensors" + node { + calculator: "RegexPreprocessorCalculator" + input_stream: "TEXT:text" + input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" + output_stream: "TENSORS:tensors" + options { + [mediapipe.RegexPreprocessorCalculatorOptions.ext] { + max_seq_len: $0 + } + } + } + )pb", + kMaxSeqLen)); + std::vector output_packets; + tool::AddVectorSink("tensors", &graph_config, &output_packets); + + std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath); + ASSIGN_OR_RETURN(std::unique_ptr metadata_extractor, + ModelMetadataExtractor::CreateFromModelBuffer( + model_buffer.data(), model_buffer.size())); + // Run the graph. + CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize( + graph_config, + {{"metadata_extractor", + MakePacket(std::move(*metadata_extractor))}})); + MP_RETURN_IF_ERROR(graph.StartRun({})); + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( + "text", MakePacket(text).At(Timestamp(0)))); + MP_RETURN_IF_ERROR(graph.WaitUntilIdle()); + + if (output_packets.size() != 1) { + return absl::InvalidArgumentError(absl::Substitute( + "output_packets has size $0, expected 1", output_packets.size())); + } + const std::vector& tensor_vec = + output_packets[0].Get>(); + if (tensor_vec.size() != 1) { + return absl::InvalidArgumentError(absl::Substitute( + "tensor_vec has size $0, expected $1", tensor_vec.size(), 1)); + } + if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) { + return absl::InvalidArgumentError("Expected tensor element type kInt32"); + } + auto* buffer = tensor_vec[0].GetCpuReadView().buffer(); + std::vector result(buffer, buffer + kMaxSeqLen); + MP_RETURN_IF_ERROR(graph.CloseAllPacketSources()); + MP_RETURN_IF_ERROR(graph.WaitUntilDone()); + return result; +} + +TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) { + MP_ASSERT_OK_AND_ASSIGN( + std::vector processed_tensor_values, + RunRegexPreprocessorCalculator("This is the best movie I’ve seen in " + "recent years. 
Strongly recommend it!")); + static const int expected_result[kMaxSeqLen] = { + 1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12}; + EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result)); +} + +TEST(RegexPreprocessorCalculatorTest, LongInput) { + std::stringstream long_input; + long_input << "This is the best"; + for (int i = 0; i < kMaxSeqLen; ++i) { + long_input << " best"; + } + long_input << "movie I’ve seen in recent years. Strongly recommend it!"; + MP_ASSERT_OK_AND_ASSIGN(std::vector processed_tensor_values, + RunRegexPreprocessorCalculator(long_input.str())); + std::vector expected_result = {1, 2, 9, 4, 118}; + // "best" id + expected_result.resize(kMaxSeqLen, 118); + EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index f3c7c7b09..0b750b859 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -296,7 +296,6 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { output_tensors->emplace_back(Tensor::ElementType::kFloat32, Tensor::Shape{1, height, width, channels}); #if MEDIAPIPE_METAL_ENABLED - id device = gpu_helper_.mtlDevice; id command_buffer = [gpu_helper_ commandBuffer]; command_buffer.label = @"TensorConverterCalculatorConvert"; id compute_encoder = diff --git a/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc b/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc index 0b7e6f082..3d364a53c 100644 --- a/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_dequantization_calculator.cc @@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) { case Tensor::ElementType::kInt8: Dequantize(input_tensor, &output_tensors->back()); break; + case Tensor::ElementType::kBool: + Dequantize(input_tensor, &output_tensors->back()); + break; default: return absl::InvalidArgumentError(absl::StrCat( "Unsupported input tensor type: ", input_tensor.element_type())); diff --git a/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc b/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc index fd41cc763..e0d549123 100644 --- a/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc +++ b/mediapipe/calculators/tensor/tensors_dequantization_calculator_test.cc @@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) { ValidateResult(GetOutput(), {-1.007874, 0, 1}); } +TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) { + std::vector tensor = {true, false, true}; + PushTensor(Tensor::ElementType::kBool, tensor, + Tensor::QuantizationParameters{1.0f, 0}); + + MP_ASSERT_OK(runner_.Run()); + + ValidateResult(GetOutput(), {1, 0, 1}); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc index 5bfc00ed7..76d2869e8 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { } absl::Status 
TensorsToClassificationCalculator::Process(CalculatorContext* cc) { + const auto& options = cc->Options(); const auto& input_tensors = *kInTensors(cc); RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32); @@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { auto raw_scores = view.buffer(); auto classification_list = absl::make_unique(); + if (options.has_tensor_index()) { + classification_list->set_tensor_index(options.tensor_index()); + } + if (options.has_tensor_name()) { + classification_list->set_tensor_name(options.tensor_name()); + } if (is_binary_classification_) { Classification* class_first = classification_list->add_classification(); Classification* class_second = classification_list->add_classification(); diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto index 32bc4b63a..f0f7727ba 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.proto @@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions { // that are not in the `allow_classes` field will be completely ignored. // `ignore_classes` and `allow_classes` are mutually exclusive. repeated int32 allow_classes = 8 [packed = true]; + + // The optional index of the tensor these classifications originate from. + optional int32 tensor_index = 10; + // The optional name of the tensor these classifications originate from. + optional string tensor_name = 11; } diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc index 9634635f0..b20f2768c 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator_test.cc @@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest, } } +TEST_F(TensorsToClassificationCalculatorTest, + CorrectOutputWithTensorNameAndIndex) { + mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "TensorsToClassificationCalculator" + input_stream: "TENSORS:tensors" + output_stream: "CLASSIFICATIONS:classifications" + options { + [mediapipe.TensorsToClassificationCalculatorOptions.ext] { + tensor_index: 1 + tensor_name: "foo" + } + } + )pb")); + + BuildGraph(&runner, {0, 0.5, 1}); + MP_ASSERT_OK(runner.Run()); + + const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets; + + EXPECT_EQ(1, output_packets_.size()); + + const auto& classification_list = + output_packets_[0].Get(); + EXPECT_EQ(3, classification_list.classification_size()); + + // Verify that the tensor_index and tensor_name fields are correctly set. 
+ EXPECT_EQ(classification_list.tensor_index(), 1); + EXPECT_EQ(classification_list.tensor_name(), "foo"); +} + TEST_F(TensorsToClassificationCalculatorTest, ClassNameAllowlistWithLabelItems) { mediapipe::CalculatorRunner runner(ParseTextProtoOrDie(R"pb( diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 11c1341d4..97ef01b4c 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -532,7 +532,6 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU( detection_classes.data(), output_detections)); #elif MEDIAPIPE_METAL_ENABLED - id device = gpu_helper_.mtlDevice; if (!anchors_init_) { if (input_tensors.size() == kNumInputTensorsWithAnchors) { RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); diff --git a/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc index 51c3a9a09..5c1c70aa4 100644 --- a/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/text_to_tensor_calculator_test.cc @@ -67,9 +67,7 @@ absl::StatusOr RunTextToTensorCalculator(absl::string_view text) { "tensor_vec has size $0, expected 1", tensor_vec.size())); } if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { - return absl::InvalidArgumentError(absl::Substitute( - "tensor has element type $0, expected $1", tensor_vec[0].element_type(), - Tensor::ElementType::kChar)); + return absl::InvalidArgumentError("Expected tensor element type kChar"); } const char* buffer = tensor_vec[0].GetCpuReadView().buffer(); return std::string(buffer, text.length()); diff --git a/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator.cc b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator.cc new file mode 100644 index 000000000..e589289f6 --- /dev/null +++ b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator.cc @@ -0,0 +1,167 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
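The tensor_index/tensor_name additions to TensorsToClassificationCalculator above are aimed at multi-head models, where several output tensors each feed their own instance of the calculator: stamping provenance onto the ClassificationList lets downstream code tell the heads apart. A rough sketch of such a consumer follows; the head name is purely illustrative.

```
// Sketch only: distinguishing multi-head outputs via the new fields.
#include "mediapipe/framework/formats/classification.pb.h"

void HandleClassifications(const mediapipe::ClassificationList& list) {
  // tensor_index()/tensor_name() mirror the calculator options that produced
  // this list; "age_head" is a made-up head name.
  if (list.tensor_name() == "age_head") {
    // e.g. route age scores, knowing they came from the model's output
    // tensor at list.tensor_index().
  }
}
```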
+ +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace api2 { + +using ::mediapipe::tasks::core::FindTensorIndexByMetadataName; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; + +constexpr absl::string_view kQueryTextMetadataName = "inp_text"; +constexpr absl::string_view kResponseContextMetadataName = "res_context"; +constexpr absl::string_view kResponseTextMetadataName = "res_text"; + +constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3; + +// Preprocesses input text into three kTfLiteString input tensors for a +// Universal Sentence Encoder (USE) model. +// +// The associated USE model is expected to contain input tensors with metadata +// names: +// +// Tensor | Metadata Name +// ---------------- | ------------------ +// Query text | "inp_text" +// Response context | "res_context" +// Response text | "res_text" +// +// This calculator will return an error if the model does not have three input +// tensors or if the tensors do not have metadata names corresponding to the +// above names in some order. Additional details regarding these input +// tensors are given in the Calculator "Outputs" section below. +// +// Inputs: +// TEXT - std::string +// The text to be embedded. +// Side Inputs: +// METADATA_EXTRACTOR - ModelMetadataExtractor +// The metadata extractor for the USE model. Used to determine the order of +// the three input Tensors for the USE model. +// +// Outputs: +// TENSORS - std::vector +// Vector containing the three input Tensors for the USE model. The tensors +// fit a question-answering setting and store a query text, a response +// context, and a response text. This calculator will just be preprocessing +// a single input text that will be stored in the response text tensor. The +// query text and response context tensors will store empty strings. +// +// Example: +// node { +// calculator: "UniversalSentenceEncoderPreprocessorCalculator" +// input_stream: "TEXT:text" +// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" +// output_stream: "TENSORS:tensors" +// } +class UniversalSentenceEncoderPreprocessorCalculator : public Node { + public: + static constexpr Input kTextIn{"TEXT"}; + static constexpr SideInput kMetadataExtractorSideIn{ + "METADATA_EXTRACTOR"}; + static constexpr Output> kTensorsOut{"TENSORS"}; + + MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + // Indices of the three input tensors for the USE model. They should form the + // set {0, 1, 2}. + int query_text_tensor_index_ = 0; + int response_context_tensor_index_ = 1; + int response_text_tensor_index_ = 2; + + // Tensor shapes for the model's input tensors. + // The query text and response context tensors will only hold the empty + // string, so their tensors will have shape [0], but the Universal Sentence + // Encoder model's input signature requires them to be present. 
The response + // text tensor will store the embedding text and have shape + // [embedding_text_len]. + std::array tensor_shapes_; +}; + +absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open( + CalculatorContext* cc) { + const ModelMetadataExtractor* metadata_extractor = + &kMetadataExtractorSideIn(cc).Get(); + auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata(); + query_text_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kQueryTextMetadataName); + response_context_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kResponseContextMetadataName); + response_text_tensor_index_ = FindTensorIndexByMetadataName( + input_tensors_metadata, kResponseTextMetadataName); + + absl::flat_hash_set tensor_indices = absl::flat_hash_set( + {query_text_tensor_index_, response_context_tensor_index_, + response_text_tensor_index_}); + if (tensor_indices != absl::flat_hash_set({0, 1, 2})) { + return absl::InvalidArgumentError(absl::Substitute( + "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}", + query_text_tensor_index_, response_context_tensor_index_, + response_text_tensor_index_)); + } + return absl::OkStatus(); +} + +absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process( + CalculatorContext* cc) { + absl::string_view text = kTextIn(cc).Get(); + const int text_len = static_cast(text.length()); + tensor_shapes_[response_text_tensor_index_] = text_len; + + std::vector input_tensors; + input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder); + for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) { + input_tensors.push_back( + {Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})}); + } + + std::memcpy( + input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer(), + "", 0); + std::memcpy(input_tensors[response_context_tensor_index_] + .GetCpuWriteView() + .buffer(), + "", 0); + std::memcpy(input_tensors[response_text_tensor_index_] + .GetCpuWriteView() + .buffer(), + text.data(), text_len * sizeof(char)); + kTensorsOut(cc).Send(std::move(input_tensors)); + return absl::OkStatus(); +} + +MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator); + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc new file mode 100644 index 000000000..0f4744c90 --- /dev/null +++ b/mediapipe/calculators/tensor/universal_sentence_encoder_preprocessor_calculator_test.cc @@ -0,0 +1,109 @@ +// Copyright 2022 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
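The three kChar tensors this calculator emits are what the new kTfLiteString branch in inference_interpreter_delegate_runner.cc (earlier in this change) consumes: each tensor's character buffer is packed into the corresponding TFLite string input with tflite::DynamicBuffer. Below is a rough sketch of producing such a tensor from a plain string, mirroring what the calculator does for the response text; the helper name is illustrative.

```
// Sketch only: a kChar Tensor whose whole buffer is later copied into a
// single kTfLiteString input by the inference runner.
#include <cstring>

#include "absl/strings/string_view.h"
#include "mediapipe/framework/formats/tensor.h"

mediapipe::Tensor MakeStringTensor(absl::string_view text) {
  mediapipe::Tensor tensor(
      mediapipe::Tensor::ElementType::kChar,
      mediapipe::Tensor::Shape({static_cast<int>(text.size())}));
  std::memcpy(tensor.GetCpuWriteView().buffer<char>(), text.data(),
              text.size());
  return tensor;
}
```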
+ +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/options_map.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace { + +using ::mediapipe::IsOkAndHolds; +using ::mediapipe::tasks::metadata::ModelMetadataExtractor; +using ::testing::ElementsAreArray; + +constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3; + +constexpr absl::string_view kTestModelPath = + "mediapipe/tasks/testdata/text/" + "universal_sentence_encoder_qa_with_metadata.tflite"; + +absl::StatusOr> +RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) { + auto graph_config = ParseTextProtoOrDie(R"pb( + input_stream: "text" + output_stream: "tensors" + node { + calculator: "UniversalSentenceEncoderPreprocessorCalculator" + input_stream: "TEXT:text" + input_side_packet: "METADATA_EXTRACTOR:metadata_extractor" + output_stream: "TENSORS:tensors" + } + )pb"); + std::vector output_packets; + tool::AddVectorSink("tensors", &graph_config, &output_packets); + + std::string model_buffer = + tasks::core::LoadBinaryContent(kTestModelPath.data()); + ASSIGN_OR_RETURN(std::unique_ptr metadata_extractor, + ModelMetadataExtractor::CreateFromModelBuffer( + model_buffer.data(), model_buffer.size())); + // Run the graph. 
+ CalculatorGraph graph; + MP_RETURN_IF_ERROR(graph.Initialize( + graph_config, + {{"metadata_extractor", + MakePacket(std::move(*metadata_extractor))}})); + MP_RETURN_IF_ERROR(graph.StartRun({})); + MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( + "text", MakePacket(text).At(Timestamp(0)))); + MP_RETURN_IF_ERROR(graph.WaitUntilIdle()); + + if (output_packets.size() != 1) { + return absl::InvalidArgumentError(absl::Substitute( + "output_packets has size $0, expected 1", output_packets.size())); + } + + const std::vector& tensor_vec = + output_packets[0].Get>(); + if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) { + return absl::InvalidArgumentError(absl::Substitute( + "tensor_vec has size $0, expected $1", tensor_vec.size(), + kNumInputTensorsForUniversalSentenceEncoder)); + } + if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { + return absl::InvalidArgumentError("Expected tensor element type kChar"); + } + std::vector results; + for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) { + results.push_back( + {tensor_vec[i].GetCpuReadView().buffer(), + static_cast(tensor_vec[i].shape().num_elements())}); + } + return results; +} + +TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) { + ASSERT_THAT( + RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"), + IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"}))); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 1d2f279aa..2007a4fe1 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -331,6 +331,7 @@ cc_library( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/status", "@org_tensorflow//tensorflow/lite:framework", ], alwayslink = 1, diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index d9dfd1526..f2a2f68b6 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -499,7 +499,6 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { gpu_data_out_ = absl::make_unique(); gpu_data_out_->elements = input.height() * input.width() * max_num_channels_; const bool include_alpha = (max_num_channels_ == 4); - const bool single_channel = (max_num_channels_ == 1); if (!(format == mediapipe::ImageFormat::GRAY8 || format == mediapipe::ImageFormat::SRGB || format == mediapipe::ImageFormat::SRGBA)) @@ -509,6 +508,7 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { #endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED #if MEDIAPIPE_TFLITE_GL_INFERENCE + const bool single_channel = (max_num_channels_ == 1); MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( [this, &include_alpha, &input, &single_channel]() -> absl::Status { // Device memory. 
diff --git a/mediapipe/calculators/tflite/tflite_model_calculator.cc b/mediapipe/calculators/tflite/tflite_model_calculator.cc index d118e878c..891a9f731 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator.cc @@ -16,6 +16,7 @@ #include #include +#include "absl/status/status.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/ret_check.h" @@ -81,6 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase { } if (cc->InputSidePackets().HasTag("MODEL_FD")) { +#ifdef ABSL_HAVE_MMAP model_packet = cc->InputSidePackets().Tag("MODEL_FD"); const auto& model_fd = model_packet.Get>(); @@ -89,6 +91,10 @@ class TfLiteModelCalculator : public CalculatorBase { tflite::DefaultErrorReporter()); model = tflite::FlatBufferModel::BuildFromAllocation( std::move(model_allocation), tflite::DefaultErrorReporter()); +#else + return absl::FailedPreconditionError( + "Loading by file descriptor is not supported on this platform."); +#endif } RET_CHECK(model) << "Failed to load TfLite model from blob."; diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 2ed158f89..3a9ddc36f 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -143,9 +143,7 @@ mediapipe_proto_library( cc_library( name = "packet_frequency_calculator", srcs = ["packet_frequency_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto", "//mediapipe/calculators/util:packet_frequency_cc_proto", @@ -190,9 +188,7 @@ cc_test( cc_library( name = "packet_latency_calculator", srcs = ["packet_latency_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/util:latency_cc_proto", "//mediapipe/calculators/util:packet_latency_calculator_cc_proto", diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index 4aab3b676..dcd76d47b 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { text->set_left(label_left_px_); text->set_baseline(label_baseline_px + i * label_height_px_); text->set_font_face(options_.font_face()); + if (options_.outline_thickness() > 0) { + text->set_outline_thickness(options_.outline_thickness()); + if (options_.outline_color_size() > 0) { + *(text->mutable_outline_color()) = + options_.outline_color(i % options_.outline_color_size()); + } else { + text->mutable_outline_color()->set_r(0); + text->mutable_outline_color()->set_g(0); + text->mutable_outline_color()->set_b(0); + } + } } cc->Outputs() .Tag(kRenderDataTag) diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index cf0ada9c2..7946ff683 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions { // Thickness for drawing the label(s). optional double thickness = 2 [default = 2]; + // Color of outline around each character, if any. 
One per label, as with + // color attribute. + repeated Color outline_color = 12; + + // Thickness of outline around each character. + optional double outline_thickness = 11; + // The font height in absolute pixels. optional int32 font_height_px = 3 [default = 50]; diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java index 1b733ed82..10e6422ba 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/GIFEditText.java @@ -18,7 +18,7 @@ import android.content.ClipDescription; import android.content.Context; import android.net.Uri; import android.os.Bundle; -import androidx.appcompat.widget.AppCompatEditText; +import android.support.v7.widget.AppCompatEditText; import android.util.AttributeSet; import android.util.Log; import android.view.inputmethod.EditorInfo; diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index f6487a17a..19c51853c 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -1685,10 +1685,3 @@ cc_test( "@com_google_absl//absl/strings:str_format", ], ) - -# Expose the proto source files for building mediapipe AAR. -filegroup( - name = "protos_src", - srcs = glob(["*.proto"]), - visibility = ["//mediapipe:__subpackages__"], -) diff --git a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index 6de444438..76aace6f5 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -14,15 +14,10 @@ cc_library( name = "builder", hdrs = ["builder.h"], deps = [ - ":const_str", - ":contract", - ":node", - ":packet", ":port", "//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_contract", "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index b78014155..82905d2f5 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -5,12 +5,7 @@ #include #include "absl/container/btree_map.h" -#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" -#include "mediapipe/framework/api2/const_str.h" -#include "mediapipe/framework/api2/contract.h" -#include "mediapipe/framework/api2/node.h" -#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_contract.h" @@ -112,6 +107,17 @@ class MultiPort : public Single { std::vector<std::unique_ptr<Base>>& vec_; }; +namespace internal_builder { + +template <typename T, typename U> +using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> && + !std::is_same_v<T, U>>; + +} // namespace internal_builder + +template <bool IsSide, typename T> +class SourceImpl; + // These classes wrap references to the underlying source/destination // endpoints, adding type information and the user-visible API.
template @@ -122,16 +128,21 @@ class DestinationImpl { explicit DestinationImpl(std::vector>* vec) : DestinationImpl(&GetWithAutoGrow(vec, 0)) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {} + + template {}, int> = 0> + DestinationImpl Cast() { + return DestinationImpl(&base_); + } + + private: DestinationBase& base_; + + template + friend class SourceImpl; }; template -class MultiDestinationImpl : public MultiPort> { - public: - using MultiPort>::MultiPort; -}; - -template class SourceImpl { public: using Base = SourceBase; @@ -171,12 +182,8 @@ class SourceImpl { return AddTarget(dest); } - template - struct AllowCast - : public std::integral_constant && - !std::is_same_v> {}; - - template {}, int> = 0> + template {}, int> = 0> SourceImpl Cast() { return SourceImpl(base_); } @@ -186,12 +193,6 @@ class SourceImpl { SourceBase* base_; }; -template -class MultiSourceImpl : public MultiPort> { - public: - using MultiPort>::MultiPort; -}; - // A source and a destination correspond to an output/input stream on a node, // and a side source and side destination correspond to an output/input side // packet. @@ -201,20 +202,20 @@ class MultiSourceImpl : public MultiPort> { template using Source = SourceImpl; template -using MultiSource = MultiSourceImpl; +using MultiSource = MultiPort>; template using SideSource = SourceImpl; template -using MultiSideSource = MultiSourceImpl; +using MultiSideSource = MultiPort>; template using Destination = DestinationImpl; template using SideDestination = DestinationImpl; template -using MultiDestination = MultiDestinationImpl; +using MultiDestination = MultiPort>; template -using MultiSideDestination = MultiDestinationImpl; +using MultiSideDestination = MultiPort>; class NodeBase { public: @@ -439,8 +440,9 @@ class Graph { // Creates a node of a specific type. Should be used for pure interfaces, // which do not have a built-in type string. template - Node& AddNode(const std::string& type) { - auto node = std::make_unique>(type); + Node& AddNode(absl::string_view type) { + auto node = + std::make_unique>(std::string(type.data(), type.size())); auto node_p = node.get(); nodes_.emplace_back(std::move(node)); return *node_p; @@ -448,16 +450,18 @@ class Graph { // Creates a generic node, with no compile-time checking of inputs and // outputs. This can be used for calculators whose contract is not visible. - GenericNode& AddNode(const std::string& type) { - auto node = std::make_unique(type); + GenericNode& AddNode(absl::string_view type) { + auto node = + std::make_unique(std::string(type.data(), type.size())); auto node_p = node.get(); nodes_.emplace_back(std::move(node)); return *node_p; } // For legacy PacketGenerators. 
- PacketGenerator& AddPacketGenerator(const std::string& type) { - auto node = std::make_unique(type); + PacketGenerator& AddPacketGenerator(absl::string_view type) { + auto node = std::make_unique( + std::string(type.data(), type.size())); auto node_p = node.get(); packet_gens_.emplace_back(std::move(node)); return *node_p; diff --git a/mediapipe/framework/api2/builder_test.cc b/mediapipe/framework/api2/builder_test.cc index 3244e092d..810c52527 100644 --- a/mediapipe/framework/api2/builder_test.cc +++ b/mediapipe/framework/api2/builder_test.cc @@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) { node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast(); any_type_output.SetName("any_type_output"); + any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast(); + CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie(R"pb( node { @@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) { output_stream: "ANY_OUTPUT:any_type_output" } input_stream: "GRAPH_ANY_INPUT:__stream_0" + output_stream: "GRAPH_ANY_OUTPUT:any_type_output" )pb"); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); } diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h index f9f0d7a8a..19f37f9de 100644 --- a/mediapipe/framework/calculator_base.h +++ b/mediapipe/framework/calculator_base.h @@ -185,7 +185,7 @@ class CalculatorBaseFactory { // Functions for checking that the calculator has the required GetContract. template constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { - typedef absl::Status (*GetContractType)(CalculatorContract * cc); + typedef absl::Status (*GetContractType)(CalculatorContract* cc); return std::is_same::value; } template diff --git a/mediapipe/framework/calculator_profile.proto b/mediapipe/framework/calculator_profile.proto index 06ec678a9..1512da6af 100644 --- a/mediapipe/framework/calculator_profile.proto +++ b/mediapipe/framework/calculator_profile.proto @@ -133,7 +133,12 @@ message GraphTrace { TPU_TASK = 13; GPU_CALIBRATION = 14; PACKET_QUEUED = 15; + GPU_TASK_INVOKE = 16; + TPU_TASK_INVOKE = 17; } + // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags, + // //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list, + // ) // The timing for one packet set being processed at one caclulator node. message CalculatorTrace { diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index 4a509ab69..b967b27fb 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -334,13 +334,6 @@ mediapipe_register_type( deps = [":landmark_cc_proto"], ) -# Expose the proto source files for building mediapipe AAR. -filegroup( - name = "protos_src", - srcs = glob(["*.proto"]), - visibility = ["//mediapipe:__subpackages__"], -) - cc_library( name = "image", srcs = ["image.cc"], @@ -469,6 +462,10 @@ cc_library( ], "//conditions:default": [], }), + defines = select({ + "//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"], + "//conditions:default": [], + }), linkopts = select({ "//mediapipe:ios": [ "-framework CoreVideo", diff --git a/mediapipe/framework/formats/annotation/BUILD b/mediapipe/framework/formats/annotation/BUILD index 2e33f7668..328001e85 100644 --- a/mediapipe/framework/formats/annotation/BUILD +++ b/mediapipe/framework/formats/annotation/BUILD @@ -33,10 +33,3 @@ mediapipe_proto_library( srcs = ["rasterization.proto"], visibility = ["//visibility:public"], ) - -# Expose the proto source files for building mediapipe AAR. 
-filegroup( - name = "protos_src", - srcs = glob(["*.proto"]), - visibility = ["//mediapipe:__subpackages__"], -) diff --git a/mediapipe/framework/formats/classification.proto b/mediapipe/framework/formats/classification.proto index 7efd9074d..c3eea07ff 100644 --- a/mediapipe/framework/formats/classification.proto +++ b/mediapipe/framework/formats/classification.proto @@ -37,6 +37,10 @@ message Classification { // Group of Classification protos. message ClassificationList { repeated Classification classification = 1; + // Optional index of the tensor that produced these classifications. + optional int32 tensor_index = 2; + // Optional name of the tensor that produced these classifications. + optional string tensor_name = 3; } // Group of ClassificationList protos. diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index 2f2859837..ff9da3ec6 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -31,11 +31,12 @@ #if MEDIAPIPE_METAL_ENABLED #import #endif // MEDIAPIPE_METAL_ENABLED - +#ifndef MEDIAPIPE_NO_JNI #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) #define MEDIAPIPE_TENSOR_USE_AHWB 1 #endif // __ANDROID_API__ >= 26 || // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) +#endif // MEDIAPIPE_NO_JNI #ifdef MEDIAPIPE_TENSOR_USE_AHWB #include @@ -43,7 +44,6 @@ #include "third_party/GL/gl/include/EGL/egl.h" #include "third_party/GL/gl/include/EGL/eglext.h" #endif // MEDIAPIPE_TENSOR_USE_AHWB - #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_context.h" @@ -97,8 +97,8 @@ class Tensor { kUInt8, kInt8, kInt32, - // TODO: Update the inference runner to handle kTfLiteString. - kChar + kChar, + kBool }; struct Shape { Shape() = default; @@ -330,6 +330,8 @@ class Tensor { return sizeof(int32_t); case ElementType::kChar: return sizeof(char); + case ElementType::kBool: + return sizeof(bool); } } int bytes() const { return shape_.num_elements() * element_size(); } diff --git a/mediapipe/framework/formats/tensor_ahwb.cc b/mediapipe/framework/formats/tensor_ahwb.cc index c839cf5a2..b11f6b55b 100644 --- a/mediapipe/framework/formats/tensor_ahwb.cc +++ b/mediapipe/framework/formats/tensor_ahwb.cc @@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const { if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) { // EGLSync is failed. Use another synchronization method. // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. 
- glFinish(); + gl_context_->Run([]() { glFinish(); }); } else if (valid_ & kValidAHardwareBuffer) { CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the " "completion function to be set"; diff --git a/mediapipe/framework/formats/tensor_test.cc b/mediapipe/framework/formats/tensor_test.cc index fe702f66b..44468cb8f 100644 --- a/mediapipe/framework/formats/tensor_test.cc +++ b/mediapipe/framework/formats/tensor_test.cc @@ -29,6 +29,9 @@ TEST(General, TestDataTypes) { Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); + + Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3}); + EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool)); } TEST(Cpu, TestMemoryAllocation) { diff --git a/mediapipe/framework/profiler/trace_buffer.h b/mediapipe/framework/profiler/trace_buffer.h index 069f09610..60352c705 100644 --- a/mediapipe/framework/profiler/trace_buffer.h +++ b/mediapipe/framework/profiler/trace_buffer.h @@ -109,6 +109,11 @@ struct TraceEvent { static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK; static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION; static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED; + static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE; + static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE; + // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags, + // //depot/mediapipe/framework/calculator_profile.proto:event_type, + // ) }; // Packet trace log buffer. diff --git a/mediapipe/framework/profiler/trace_builder.cc b/mediapipe/framework/profiler/trace_builder.cc index 10ce879ff..ce5bf1e25 100644 --- a/mediapipe/framework/profiler/trace_builder.cc +++ b/mediapipe/framework/profiler/trace_builder.cc @@ -64,7 +64,7 @@ void BasicTraceEventTypes(TraceEventRegistry* result) { std::vector basic_types = { {TraceEvent::UNKNOWN, "An uninitialized trace-event."}, {TraceEvent::OPEN, "A call to Calculator::Open.", true, true}, - {TraceEvent::PROCESS, "A call to Calculator::Open.", true, true}, + {TraceEvent::PROCESS, "A call to Calculator::Process.", true, true}, {TraceEvent::CLOSE, "A call to Calculator::Close.", true, true}, {TraceEvent::NOT_READY, "A calculator cannot process packets yet."}, diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 106738a49..e54fb2177 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -150,7 +150,7 @@ cc_library( name = "executor_util", srcs = ["executor_util.cc"], hdrs = ["executor_util.h"], - visibility = ["//mediapipe/framework:mediapipe_internal"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto", diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 9b5de0235..aec2445b9 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -378,8 +378,11 @@ cc_library( ], }), deps = [ + ":gl_texture_buffer", ":gpu_buffer_format", ":gpu_buffer_storage", + ":image_frame_view", + "//mediapipe/framework/formats:image_frame", "@com_google_absl//absl/strings:str_format", ], ) @@ -1050,7 +1053,7 @@ objc_library( alwayslink = 1, ) -MIN_IOS_VERSION = "9.0" # For thread_local. 
+MIN_IOS_VERSION = "11.0" test_suite( name = "ios", diff --git a/mediapipe/gpu/MPPGraphGPUData.mm b/mediapipe/gpu/MPPGraphGPUData.mm index 001d4e888..8ac1eefa5 100644 --- a/mediapipe/gpu/MPPGraphGPUData.mm +++ b/mediapipe/gpu/MPPGraphGPUData.mm @@ -111,7 +111,8 @@ typedef CVOpenGLESTextureCacheRef CVTextureCacheType; - (CVMetalTextureCacheRef)mtlTextureCache { @synchronized(self) { if (!_mtlTextureCache) { - CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); + CVReturn __unused err = + CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); // TODO: register and flush metal caches too. } diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 75eeeb936..78b196b08 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -47,6 +47,8 @@ static void EglThreadExitCallback(void* key_value) { // implementations, and should be considered as an undocumented vendor // extension. // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml + // + // NOTE: crashes on some Android devices (occurs with libGLES_meow.so). eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, EGL_NO_SURFACE, EGL_NO_CONTEXT); #endif diff --git a/mediapipe/gpu/gl_context_nsgl.cc b/mediapipe/gpu/gl_context_nsgl.cc index dda74f0ce..561474ad8 100644 --- a/mediapipe/gpu/gl_context_nsgl.cc +++ b/mediapipe/gpu/gl_context_nsgl.cc @@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) { 16, 0}; - pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs]; + pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1]; } if (!pixel_format_) { // On several Forge machines, the default config fails. For now let's do diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 69d2fab7a..fbb91a8f5 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -144,14 +144,23 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { context](std::shared_ptr sync_token) { CHECK_NE(name_, 0); GLuint name_to_delete = name_; - context->RunWithoutWaiting([name_to_delete, sync_token]() { - if (sync_token) { - // TODO: maybe we do not actually have to wait for the - // consumer sync here. Check docs. - sync_token->WaitOnGpu(); - } else { - LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback"; - } + context->RunWithoutWaiting([name_to_delete]() { + // Note that we do not wait for consumers to be done before deleting the + // texture. Based on a reading of the GLES 3.0 spec, appendix D: + // - when a texture is deleted, it is _not_ automatically unbound from + // bind points in other contexts; + // - when a texture is deleted, its name becomes immediately invalid, but + // the actual object is not deleted until it is no longer in use, i.e. + // attached to a container object or bound to a context; + // - deleting an object is not an operation that changes its contents; + // - within each context, commands are executed sequentially, so it seems + // like an unbind that follows a command that reads a texture should not + // take effect until the GPU has actually finished executing the + // previous commands. + // The final point is the least explicit in the docs, but it is implied by + // normal single-context behavior. E.g. 
if you do bind, delete, render, + // unbind, the object is not deleted until the unbind, and it waits for + // the render to finish. DLOG_IF(ERROR, !glIsTexture(name_to_delete)) << "Deleting invalid texture id: " << name_to_delete; glDeleteTextures(1, &name_to_delete); @@ -185,7 +194,10 @@ void GlTextureBuffer::Updated(std::shared_ptr prod_token) { << "Updated existing texture which had not been marked for reuse!"; CHECK(prod_token); producer_sync_ = std::move(prod_token); - producer_context_ = producer_sync_->GetContext(); + const auto& synced_context = producer_sync_->GetContext(); + if (synced_context) { + producer_context_ = synced_context; + } } void GlTextureBuffer::DidRead(std::shared_ptr cons_token) const { diff --git a/mediapipe/gpu/gl_texture_view.h b/mediapipe/gpu/gl_texture_view.h index 1f0a23f31..8b47d620b 100644 --- a/mediapipe/gpu/gl_texture_view.h +++ b/mediapipe/gpu/gl_texture_view.h @@ -65,6 +65,7 @@ class GlTextureView { friend class GpuBuffer; friend class GlTextureBuffer; friend class GpuBufferStorageCvPixelBuffer; + friend class GpuBufferStorageAhwb; GlTextureView(GlContext* context, GLenum target, GLuint name, int width, int height, std::shared_ptr gpu_buffer, int plane, DetachFn detach, DoneWritingFn done_writing) diff --git a/mediapipe/gpu/gpu_buffer_test.cc b/mediapipe/gpu/gpu_buffer_test.cc index c207acf60..3fd519b21 100644 --- a/mediapipe/gpu/gpu_buffer_test.cc +++ b/mediapipe/gpu/gpu_buffer_test.cc @@ -18,6 +18,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/tool/test_util.h" +#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_test_base.h" #include "stb_image.h" diff --git a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java index 4af9dae78..e3a878f91 100644 --- a/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java +++ b/mediapipe/java/com/google/mediapipe/framework/AndroidPacketCreator.java @@ -17,8 +17,8 @@ package com.google.mediapipe.framework; import android.graphics.Bitmap; import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor; -import com.google.mediapipe.framework.image.Image; -import com.google.mediapipe.framework.image.ImageProperties; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.framework.image.MPImageProperties; import java.nio.ByteBuffer; // TODO: use Preconditions in this file. @@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator { } /** - * Creates an Image packet from an {@link Image}. + * Creates a MediaPipe Image packet from a {@link MPImage}. * *

The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. */ - public Packet createImage(Image image) { + public Packet createImage(MPImage image) { // TODO: Choose the best storage from multiple containers. - ImageProperties properties = image.getContainedImageProperties().get(0); - if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { + MPImageProperties properties = image.getContainedImageProperties().get(0); + if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) { ByteBuffer buffer = ByteBufferExtractor.extract(image); int numChannels = 0; switch (properties.getImageFormat()) { - case Image.IMAGE_FORMAT_RGBA: + case MPImage.IMAGE_FORMAT_RGBA: numChannels = 4; break; - case Image.IMAGE_FORMAT_RGB: + case MPImage.IMAGE_FORMAT_RGB: numChannels = 3; break; - case Image.IMAGE_FORMAT_ALPHA: + case MPImage.IMAGE_FORMAT_ALPHA: numChannels = 1; break; default: // fall out @@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator { int height = image.getHeight(); return createImage(buffer, width, height, numChannels); } - if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { + if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) { Bitmap bitmap = BitmapExtractor.extract(image); if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BUILD b/mediapipe/java/com/google/mediapipe/framework/image/BUILD index abf82a892..bb3be318d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/image/BUILD @@ -30,3 +30,10 @@ android_library( "@maven//:com_google_guava_guava", ], ) + +# Expose the java source files for building mediapipe AAR. +filegroup( + name = "java_src", + srcs = glob(["*.java"]), + visibility = ["//mediapipe:__subpackages__"], +) diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java index 4c6cebd4d..d6f50bf30 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapExtractor.java @@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image; import android.graphics.Bitmap; /** - * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. + * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}. * - *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise + *

Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise * {@link IllegalArgumentException} will be thrown. */ public final class BitmapExtractor { /** - * Extracts a {@link android.graphics.Bitmap} from an {@link Image}. + * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}. * * @param image the image to extract {@link android.graphics.Bitmap} from. - * @return the {@link android.graphics.Bitmap} stored in {@link Image} + * @return the {@link android.graphics.Bitmap} stored in {@link MPImage} * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - public static Bitmap extract(Image image) { - ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); + public static Bitmap extract(MPImage image) { + MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP); if (imageContainer != null) { return ((BitmapImageContainer) imageContainer).getBitmap(); } else { // TODO: Support ByteBuffer -> Bitmap conversion. throw new IllegalArgumentException( - "Extracting Bitmap from an Image created by objects other than Bitmap is not" + "Extracting Bitmap from a MPImage created by objects other than Bitmap is not" + " supported"); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java index ea2ca6b1f..988cdf542 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageBuilder.java @@ -22,7 +22,7 @@ import android.provider.MediaStore; import java.io.IOException; /** - * Builds {@link Image} from {@link android.graphics.Bitmap}. + * Builds {@link MPImage} from {@link android.graphics.Bitmap}. * *

You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content @@ -49,7 +49,7 @@ public class BitmapImageBuilder { } /** - * Creates the builder to build {@link Image} from a file. + * Creates the builder to build {@link MPImage} from a file. * * @param context the application context. * @param uri the path to the resource file. @@ -58,15 +58,15 @@ public class BitmapImageBuilder { this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ BitmapImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image( + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage( new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java index 0457e1e9b..6fbcac214 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/BitmapImageContainer.java @@ -16,19 +16,19 @@ limitations under the License. package com.google.mediapipe.framework.image; import android.graphics.Bitmap; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; -class BitmapImageContainer implements ImageContainer { +class BitmapImageContainer implements MPImageContainer { private final Bitmap bitmap; - private final ImageProperties properties; + private final MPImageProperties properties; public BitmapImageContainer(Bitmap bitmap) { this.bitmap = bitmap; this.properties = - ImageProperties.builder() + MPImageProperties.builder() .setImageFormat(convertFormatCode(bitmap.getConfig())) - .setStorageType(Image.STORAGE_TYPE_BITMAP) + .setStorageType(MPImage.STORAGE_TYPE_BITMAP) .build(); } @@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer { } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } @@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer { bitmap.recycle(); } - @ImageFormat + @MPImageFormat static int convertFormatCode(Bitmap.Config config) { switch (config) { case ALPHA_8: - return Image.IMAGE_FORMAT_ALPHA; + return MPImage.IMAGE_FORMAT_ALPHA; case ARGB_8888: - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; default: - return Image.IMAGE_FORMAT_UNKNOWN; + return MPImage.IMAGE_FORMAT_UNKNOWN; } } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java index a0e8c3dff..748a10667 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferExtractor.java @@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config; import android.os.Build.VERSION; import android.os.Build.VERSION_CODES; import com.google.auto.value.AutoValue; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import 
com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Locale; /** - * Utility for extracting {@link ByteBuffer} from {@link Image}. + * Utility for extracting {@link ByteBuffer} from {@link MPImage}. * - *

Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise - * {@link IllegalArgumentException} will be thrown. + *

Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER}, + * otherwise {@link IllegalArgumentException} will be thrown. */ public class ByteBufferExtractor { /** - * Extracts a {@link ByteBuffer} from an {@link Image}. + * Extracts a {@link ByteBuffer} from a {@link MPImage}. * *

The returned {@link ByteBuffer} is a read-only view, with the first available {@link - * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. + * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}. * - * @see Image#getContainedImageProperties() + * @see MPImage#getContainedImageProperties() * @return A read-only {@link ByteBuffer}. * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. */ @SuppressLint("SwitchIntDef") - public static ByteBuffer extract(Image image) { - ImageContainer container = image.getContainer(); + public static ByteBuffer extract(MPImage image) { + MPImageContainer container = image.getContainer(); switch (container.getImageProperties().getStorageType()) { - case Image.STORAGE_TYPE_BYTEBUFFER: + case MPImage.STORAGE_TYPE_BYTEBUFFER: ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); default: throw new IllegalArgumentException( - "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" + "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not" + " supported"); } } /** - * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. + * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}. * *

Format conversion spec: * @@ -70,26 +70,26 @@ public class ByteBufferExtractor { * * @param image the image to extract buffer from. * @param targetFormat the image format of the result bytebuffer. - * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @return the readonly {@link ByteBuffer} stored in {@link MPImage} * @throws IllegalArgumentException when the extraction requires unsupported format or data type * conversions. */ - static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { - ImageContainer container; - ImageProperties byteBufferProperties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) { + MPImageContainer container; + MPImageProperties byteBufferProperties = + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER) .setImageFormat(targetFormat) .build(); if ((container = image.getContainer(byteBufferProperties)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; - @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); + @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) .asReadOnlyBuffer(); - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; ByteBuffer byteBuffer = extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) @@ -98,85 +98,89 @@ public class ByteBufferExtractor { return byteBuffer; } else { throw new IllegalArgumentException( - "Extracting ByteBuffer from an Image created by objects other than Bitmap or" + "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or" + " Bytebuffer is not supported"); } } - /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ + /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */ @AutoValue abstract static class Result { - /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ + /** + * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}. + */ public abstract ByteBuffer buffer(); - /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ - @ImageFormat + /** + * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}. + */ + @MPImageFormat public abstract int format(); - static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { + static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) { return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); } } /** - * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. + * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}. * *

It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. * - * @return the readonly {@link ByteBuffer} stored in {@link Image} + * @return the readonly {@link ByteBuffer} stored in {@link MPImage} * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with * given {@code imageFormat} */ - static Result extractInRecommendedFormat(Image image) { - ImageContainer container; - if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { + static Result extractInRecommendedFormat(MPImage image) { + MPImageContainer container; + if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) { Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); - @ImageFormat int format = adviseImageFormat(bitmap); + @MPImageFormat int format = adviseImageFormat(bitmap); Result result = Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); boolean unused = image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); return result; - } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { + } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) { ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; return Result.create( byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), byteBufferImageContainer.getImageFormat()); } else { throw new IllegalArgumentException( - "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" + "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer" + " is not supported"); } } - @ImageFormat + @MPImageFormat private static int adviseImageFormat(Bitmap bitmap) { if (bitmap.getConfig() == Config.ARGB_8888) { - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; } else { throw new IllegalArgumentException( String.format( - "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" + "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not" + " supported", bitmap.getConfig())); } } private static ByteBuffer extractByteBufferFromBitmap( - Bitmap bitmap, @ImageFormat int imageFormat) { + Bitmap bitmap, @MPImageFormat int imageFormat) { if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { throw new IllegalArgumentException( - "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" + "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not" + " supported"); } if (bitmap.getConfig() == Config.ARGB_8888) { - if (imageFormat == Image.IMAGE_FORMAT_RGBA) { + if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) { ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); bitmap.copyPixelsToBuffer(buffer); buffer.rewind(); return buffer; - } else if (imageFormat == Image.IMAGE_FORMAT_RGB) { + } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) { // TODO: Try Use RGBA buffer to create RGB buffer which might be faster. 
int w = bitmap.getWidth(); int h = bitmap.getHeight(); @@ -196,14 +200,14 @@ public class ByteBufferExtractor { } throw new IllegalArgumentException( String.format( - "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" + "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format" + " %d is not supported", bitmap.getConfig(), imageFormat)); } private static ByteBuffer convertByteBuffer( - ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { - if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { + ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) { + if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) { ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); // Extend the buffer when the target is longer than the source. Use two cursors and sweep the // array reversely to convert in-place. @@ -221,7 +225,8 @@ public class ByteBufferExtractor { target.put(array, 0, target.capacity()); target.rewind(); return target; - } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { + } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA + && targetFormat == MPImage.IMAGE_FORMAT_RGB) { ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the // array to convert in-place. diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java index 07871da38..a650e4c33 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageBuilder.java @@ -15,11 +15,11 @@ limitations under the License. package com.google.mediapipe.framework.image; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; /** - * Builds a {@link Image} from a {@link ByteBuffer}. + * Builds a {@link MPImage} from a {@link ByteBuffer}. * *

You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. @@ -32,7 +32,7 @@ public class ByteBufferImageBuilder { private final ByteBuffer buffer; private final int width; private final int height; - @ImageFormat private final int imageFormat; + @MPImageFormat private final int imageFormat; // Optional fields. private long timestamp; @@ -49,7 +49,7 @@ public class ByteBufferImageBuilder { * @param imageFormat how the data encode the image. */ public ByteBufferImageBuilder( - ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { + ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) { this.buffer = byteBuffer; this.width = width; this.height = height; @@ -58,14 +58,14 @@ public class ByteBufferImageBuilder { this.timestamp = 0; } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ ByteBufferImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java index 1c24c1dfd..82dbe32ca 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/ByteBufferImageContainer.java @@ -15,21 +15,19 @@ limitations under the License. package com.google.mediapipe.framework.image; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; import java.nio.ByteBuffer; -class ByteBufferImageContainer implements ImageContainer { +class ByteBufferImageContainer implements MPImageContainer { private final ByteBuffer buffer; - private final ImageProperties properties; + private final MPImageProperties properties; - public ByteBufferImageContainer( - ByteBuffer buffer, - @ImageFormat int imageFormat) { + public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) { this.buffer = buffer; this.properties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER) .setImageFormat(imageFormat) .build(); } @@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer { } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } - /** - * Returns the image format. - */ - @ImageFormat + /** Returns the image format. 
*/ + @MPImageFormat public int getImageFormat() { return properties.getImageFormat(); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/Image.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java similarity index 76% rename from mediapipe/java/com/google/mediapipe/framework/image/Image.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImage.java index 49e63bcc0..e17cc4d30 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/Image.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImage.java @@ -29,10 +29,10 @@ import java.util.Map.Entry; /** * The wrapper class for image objects. * - *

{@link Image} is designed to be an immutable image container, which could be shared + *

{@link MPImage} is designed to be an immutable image container, which could be shared * cross-platforms. * - *

To construct an {@link Image}, use the provided builders: + *

To construct a {@link MPImage}, use the provided builders: * *

    *
  • {@link ByteBufferImageBuilder} @@ -40,7 +40,7 @@ import java.util.Map.Entry; *
  • {@link MediaImageBuilder} *
* - *

{@link Image} uses reference counting to maintain internal storage. When it is created the + *

{@link MPImage} uses reference counting to maintain internal storage. When it is created the * reference count is 1. Developer can call {@link #close()} to reduce reference count to release * internal storage earlier, otherwise Java garbage collection will release the storage eventually. * @@ -53,7 +53,7 @@ import java.util.Map.Entry; *

  • {@link MediaImageExtractor} * */ -public class Image implements Closeable { +public class MPImage implements Closeable { /** Specifies the image format of an image. */ @IntDef({ @@ -69,7 +69,7 @@ public class Image implements Closeable { IMAGE_FORMAT_JPEG, }) @Retention(RetentionPolicy.SOURCE) - public @interface ImageFormat {} + public @interface MPImageFormat {} public static final int IMAGE_FORMAT_UNKNOWN = 0; public static final int IMAGE_FORMAT_RGBA = 1; @@ -98,14 +98,14 @@ public class Image implements Closeable { public static final int STORAGE_TYPE_IMAGE_PROXY = 4; /** - * Returns a list of supported image properties for this {@link Image}. + * Returns a list of supported image properties for this {@link MPImage}. * - *

    Currently {@link Image} only support single storage type so the size of return list will + *

    Currently {@link MPImage} only support single storage type so the size of return list will * always be 1. * - * @see ImageProperties + * @see MPImageProperties */ - public List getContainedImageProperties() { + public List getContainedImageProperties() { return Collections.singletonList(getContainer().getImageProperties()); } @@ -124,7 +124,7 @@ public class Image implements Closeable { return height; } - /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ + /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */ private synchronized void acquire() { referenceCount += 1; } @@ -132,7 +132,7 @@ public class Image implements Closeable { /** * Removes a reference that was previously acquired or init. * - *

    When {@link Image} is created, it has 1 reference count. + *

    When {@link MPImage} is created, it has 1 reference count. * *

    When the reference count becomes 0, it will release the resource under the hood. */ @@ -141,24 +141,24 @@ public class Image implements Closeable { public synchronized void close() { referenceCount -= 1; if (referenceCount == 0) { - for (ImageContainer imageContainer : containerMap.values()) { + for (MPImageContainer imageContainer : containerMap.values()) { imageContainer.close(); } } } - /** Advanced API access for {@link Image}. */ + /** Advanced API access for {@link MPImage}. */ static final class Internal { /** - * Acquires a reference on this {@link Image}. This will increase the reference count by 1. + * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. * *

    This method is more useful for image consumer to acquire a reference so image resource * will not be closed accidentally. As image creator, normal developer doesn't need to call this * method. * - *

    The reference count is 1 when {@link Image} is created. Developer can call {@link - * #close()} to indicate it doesn't need this {@link Image} anymore. + *

    The reference count is 1 when {@link MPImage} is created. Developer can call {@link + * #close()} to indicate it doesn't need this {@link MPImage} anymore. * * @see #close() */ @@ -166,10 +166,10 @@ public class Image implements Closeable { image.acquire(); } - private final Image image; + private final MPImage image; - // Only Image creates the internal helper. - private Internal(Image image) { + // Only MPImage creates the internal helper. + private Internal(MPImage image) { this.image = image; } } @@ -179,15 +179,15 @@ public class Image implements Closeable { return new Internal(this); } - private final Map containerMap; + private final Map containerMap; private final long timestamp; private final int width; private final int height; private int referenceCount; - /** Constructs an {@link Image} with a built container. */ - Image(ImageContainer container, long timestamp, int width, int height) { + /** Constructs a {@link MPImage} with a built container. */ + MPImage(MPImageContainer container, long timestamp, int width, int height) { this.containerMap = new HashMap<>(); containerMap.put(container.getImageProperties(), container); this.timestamp = timestamp; @@ -201,10 +201,10 @@ public class Image implements Closeable { * * @return the current container. */ - ImageContainer getContainer() { + MPImageContainer getContainer() { // According to the design, in the future we will support multiple containers in one image. // Currently just return the original container. - // TODO: Cache multiple containers in Image. + // TODO: Cache multiple containers in MPImage. return containerMap.values().iterator().next(); } @@ -214,8 +214,8 @@ public class Image implements Closeable { *

    If there are multiple containers with required {@code storageType}, returns the first one. */ @Nullable - ImageContainer getContainer(@StorageType int storageType) { - for (Entry entry : containerMap.entrySet()) { + MPImageContainer getContainer(@StorageType int storageType) { + for (Entry entry : containerMap.entrySet()) { if (entry.getKey().getStorageType() == storageType) { return entry.getValue(); } @@ -225,13 +225,13 @@ public class Image implements Closeable { /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ @Nullable - ImageContainer getContainer(ImageProperties imageProperties) { + MPImageContainer getContainer(MPImageProperties imageProperties) { return containerMap.get(imageProperties); } /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ - boolean addContainer(ImageContainer container) { - ImageProperties imageProperties = container.getImageProperties(); + boolean addContainer(MPImageContainer container) { + MPImageProperties imageProperties = container.getImageProperties(); if (containerMap.containsKey(imageProperties)) { return false; } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java similarity index 87% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java index 18eed68c6..f9f343e93 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageConsumer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageConsumer.java @@ -14,14 +14,14 @@ limitations under the License. ==============================================================================*/ package com.google.mediapipe.framework.image; -/** Lightweight abstraction for an object that can receive {@link Image} */ -public interface ImageConsumer { +/** Lightweight abstraction for an object that can receive {@link MPImage} */ +public interface MPImageConsumer { /** - * Called when an {@link Image} is available. + * Called when a {@link MPImage} is available. * *

    The argument is only guaranteed to be available until this method returns. if you need to * extend its life time, acquire it, then release it when done. */ - void onNewImage(Image image); + void onNewMPImage(MPImage image); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java similarity index 93% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java index 727ec0893..674073b5b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageContainer.java @@ -16,9 +16,9 @@ limitations under the License. package com.google.mediapipe.framework.image; /** Manages internal image data storage. The interface is package-private. */ -interface ImageContainer { +interface MPImageContainer { /** Returns the properties of the contained image. */ - ImageProperties getImageProperties(); + MPImageProperties getImageProperties(); /** Close the image container and releases the image resource inside. */ void close(); diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java similarity index 75% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java index 4f3641d6f..9783935d4 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageProducer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProducer.java @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ package com.google.mediapipe.framework.image; -/** Lightweight abstraction for an object that produce {@link Image} */ -public interface ImageProducer { +/** Lightweight abstraction for an object that produce {@link MPImage} */ +public interface MPImageProducer { - /** Sets the consumer that receives the {@link Image}. */ - void setImageConsumer(ImageConsumer imageConsumer); + /** Sets the consumer that receives the {@link MPImage}. */ + void setMPImageConsumer(MPImageConsumer imageConsumer); } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java similarity index 63% rename from mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java rename to mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java index e33b33e7f..6005ce77b 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/ImageProperties.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MPImageProperties.java @@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image; import com.google.auto.value.AutoValue; import com.google.auto.value.extension.memoized.Memoized; -import com.google.mediapipe.framework.image.Image.ImageFormat; -import com.google.mediapipe.framework.image.Image.StorageType; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; +import com.google.mediapipe.framework.image.MPImage.StorageType; /** Groups a set of properties to describe how an image is stored. 
*/ @AutoValue -public abstract class ImageProperties { +public abstract class MPImageProperties { /** * Gets the pixel format of the image. * - * @see Image.ImageFormat + * @see MPImage.MPImageFormat */ - @ImageFormat + @MPImageFormat public abstract int getImageFormat(); /** * Gets the storage type of the image. * - * @see Image.StorageType + * @see MPImage.StorageType */ @StorageType public abstract int getStorageType(); @@ -45,36 +45,36 @@ public abstract class ImageProperties { public abstract int hashCode(); /** - * Creates a builder of {@link ImageProperties}. + * Creates a builder of {@link MPImageProperties}. * - * @see ImageProperties.Builder + * @see MPImageProperties.Builder */ static Builder builder() { - return new AutoValue_ImageProperties.Builder(); + return new AutoValue_MPImageProperties.Builder(); } - /** Builds a {@link ImageProperties}. */ + /** Builds a {@link MPImageProperties}. */ @AutoValue.Builder abstract static class Builder { /** - * Sets the {@link Image.ImageFormat}. + * Sets the {@link MPImage.MPImageFormat}. * - * @see ImageProperties#getImageFormat + * @see MPImageProperties#getImageFormat */ - abstract Builder setImageFormat(@ImageFormat int value); + abstract Builder setImageFormat(@MPImageFormat int value); /** - * Sets the {@link Image.StorageType}. + * Sets the {@link MPImage.StorageType}. * - * @see ImageProperties#getStorageType + * @see MPImageProperties#getStorageType */ abstract Builder setStorageType(@StorageType int value); - /** Builds the {@link ImageProperties}. */ - abstract ImageProperties build(); + /** Builds the {@link MPImageProperties}. */ + abstract MPImageProperties build(); } // Hide the constructor. - ImageProperties() {} + MPImageProperties() {} } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java index e351a87fd..9e719715d 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageBuilder.java @@ -15,11 +15,12 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; /** - * Builds {@link Image} from {@link android.media.Image}. + * Builds {@link MPImage} from {@link android.media.Image}. * *

    Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify * content in it. @@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi; public class MediaImageBuilder { // Mandatory fields. - private final android.media.Image mediaImage; + private final Image mediaImage; // Optional fields. private long timestamp; @@ -40,20 +41,20 @@ public class MediaImageBuilder { * * @param mediaImage image data object. */ - public MediaImageBuilder(android.media.Image mediaImage) { + public MediaImageBuilder(Image mediaImage) { this.mediaImage = mediaImage; this.timestamp = 0; } - /** Sets value for {@link Image#getTimestamp()}. */ + /** Sets value for {@link MPImage#getTimestamp()}. */ MediaImageBuilder setTimestamp(long timestamp) { this.timestamp = timestamp; return this; } - /** Builds an {@link Image} instance. */ - public Image build() { - return new Image( + /** Builds a {@link MPImage} instance. */ + public MPImage build() { + return new MPImage( new MediaImageContainer(mediaImage), timestamp, mediaImage.getWidth(), diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java index 144b64def..864c76df2 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageContainer.java @@ -15,33 +15,34 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build; import android.os.Build.VERSION; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; -import com.google.mediapipe.framework.image.Image.ImageFormat; +import com.google.mediapipe.framework.image.MPImage.MPImageFormat; @RequiresApi(VERSION_CODES.KITKAT) -class MediaImageContainer implements ImageContainer { +class MediaImageContainer implements MPImageContainer { - private final android.media.Image mediaImage; - private final ImageProperties properties; + private final Image mediaImage; + private final MPImageProperties properties; - public MediaImageContainer(android.media.Image mediaImage) { + public MediaImageContainer(Image mediaImage) { this.mediaImage = mediaImage; this.properties = - ImageProperties.builder() - .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) + MPImageProperties.builder() + .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE) .setImageFormat(convertFormatCode(mediaImage.getFormat())) .build(); } - public android.media.Image getImage() { + public Image getImage() { return mediaImage; } @Override - public ImageProperties getImageProperties() { + public MPImageProperties getImageProperties() { return properties; } @@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer { mediaImage.close(); } - @ImageFormat + @MPImageFormat static int convertFormatCode(int graphicsFormat) { // We only cover the format mentioned in // https://developer.android.com/reference/android/media/Image#getFormat() if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { - return Image.IMAGE_FORMAT_RGBA; + return MPImage.IMAGE_FORMAT_RGBA; } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { - return Image.IMAGE_FORMAT_RGB; + return MPImage.IMAGE_FORMAT_RGB; } } switch (graphicsFormat) { case android.graphics.ImageFormat.JPEG: - return Image.IMAGE_FORMAT_JPEG; + return MPImage.IMAGE_FORMAT_JPEG; case 
android.graphics.ImageFormat.YUV_420_888: - return Image.IMAGE_FORMAT_YUV_420_888; + return MPImage.IMAGE_FORMAT_YUV_420_888; default: - return Image.IMAGE_FORMAT_UNKNOWN; + return MPImage.IMAGE_FORMAT_UNKNOWN; } } } diff --git a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java index 718cb471f..76bb5a5ec 100644 --- a/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java +++ b/mediapipe/java/com/google/mediapipe/framework/image/MediaImageExtractor.java @@ -15,13 +15,14 @@ limitations under the License. package com.google.mediapipe.framework.image; +import android.media.Image; import android.os.Build.VERSION_CODES; import androidx.annotation.RequiresApi; /** - * Utility for extracting {@link android.media.Image} from {@link Image}. + * Utility for extracting {@link android.media.Image} from {@link MPImage}. * - *

    Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, + *

    Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE}, * otherwise {@link IllegalArgumentException} will be thrown. */ @RequiresApi(VERSION_CODES.KITKAT) @@ -30,20 +31,20 @@ public class MediaImageExtractor { private MediaImageExtractor() {} /** - * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for - * {@link Image} that built from {@link MediaImageBuilder}. + * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for + * {@link MPImage} that built from {@link MediaImageBuilder}. * * @param image the image to extract {@link android.media.Image} from. - * @return {@link android.media.Image} that stored in {@link Image}. + * @return {@link android.media.Image} that stored in {@link MPImage}. * @throws IllegalArgumentException if the extraction failed. */ - public static android.media.Image extract(Image image) { - ImageContainer container; - if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { + public static Image extract(MPImage image) { + MPImageContainer container; + if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) { return ((MediaImageContainer) container).getImage(); } throw new IllegalArgumentException( - "Extract Media Image from an Image created by objects other than Media Image" + "Extract Media Image from a MPImage created by objects other than Media Image" + " is not supported"); } } diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index ed1686954..645e8b722 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -1,4 +1,4 @@ -# Copyright 2019-2020 The MediaPipe Authors. +# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -89,10 +89,6 @@ def mediapipe_aar( calculators = calculators, ) - _mediapipe_proto( - name = name + "_proto", - ) - native.genrule( name = name + "_aar_manifest_generator", outs = ["AndroidManifest.xml"], @@ -115,19 +111,10 @@ EOF "//mediapipe/java/com/google/mediapipe/components:java_src", "//mediapipe/java/com/google/mediapipe/framework:java_src", "//mediapipe/java/com/google/mediapipe/glutil:java_src", - "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", - "com/google/mediapipe/formats/proto/ClassificationProto.java", - "com/google/mediapipe/formats/proto/DetectionProto.java", - "com/google/mediapipe/formats/proto/LandmarkProto.java", - "com/google/mediapipe/formats/proto/LocationDataProto.java", - "com/google/mediapipe/proto/CalculatorProto.java", - ] + + ] + mediapipe_java_proto_srcs() + select({ "//conditions:default": [], - "enable_stats_logging": [ - "com/google/mediapipe/proto/MediaPipeLoggingProto.java", - "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", - ], + "enable_stats_logging": mediapipe_logging_java_proto_srcs(), }), manifest = "AndroidManifest.xml", proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], @@ -177,93 +164,9 @@ EOF assets_dir = assets_dir, ) - _aar_with_jni(name, name + "_android_lib") - -def _mediapipe_proto(name): - """Generates MediaPipe java proto libraries. - - Args: - name: the name of the target. 
- """ - _proto_java_src_generator( - name = "mediapipe_log_extension_proto", - proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto", - java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java", - srcs = ["//mediapipe/util/analytics:protos_src"], - ) - - _proto_java_src_generator( - name = "mediapipe_logging_enums_proto", - proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto", - java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", - srcs = ["//mediapipe/util/analytics:protos_src"], - ) - - _proto_java_src_generator( - name = "calculator_proto", - proto_src = "mediapipe/framework/calculator.proto", - java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java", - srcs = ["//mediapipe/framework:protos_src"], - ) - - _proto_java_src_generator( - name = "landmark_proto", - proto_src = "mediapipe/framework/formats/landmark.proto", - java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", - srcs = ["//mediapipe/framework/formats:protos_src"], - ) - - _proto_java_src_generator( - name = "rasterization_proto", - proto_src = "mediapipe/framework/formats/annotation/rasterization.proto", - java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", - srcs = ["//mediapipe/framework/formats/annotation:protos_src"], - ) - - _proto_java_src_generator( - name = "location_data_proto", - proto_src = "mediapipe/framework/formats/location_data.proto", - java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - "//mediapipe/framework/formats/annotation:protos_src", - ], - ) - - _proto_java_src_generator( - name = "detection_proto", - proto_src = "mediapipe/framework/formats/detection.proto", - java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - "//mediapipe/framework/formats/annotation:protos_src", - ], - ) - - _proto_java_src_generator( - name = "classification_proto", - proto_src = "mediapipe/framework/formats/classification.proto", - java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", - srcs = [ - "//mediapipe/framework/formats:protos_src", - ], - ) - -def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []): - native.genrule( - name = name + "_proto_java_src_generator", - srcs = srcs + [ - "@com_google_protobuf//:lite_well_known_protos", - ], - outs = [java_lite_out], - cmd = "$(location @com_google_protobuf//:protoc) " + - "--proto_path=. --proto_path=$(GENDIR) " + - "--proto_path=$$(pwd)/external/com_google_protobuf/src " + - "--java_out=lite:$(GENDIR) " + proto_src + " && " + - "mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))", - tools = [ - "@com_google_protobuf//:protoc", - ], + mediapipe_build_aar_with_jni( + name = name, + android_library = name + "_android_lib", ) def _mediapipe_jni(name, gen_libmediapipe, calculators = []): @@ -303,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []): alwayslink = 1, ) -def _aar_with_jni(name, android_library): +def mediapipe_build_aar_with_jni(name, android_library): + """Builds MediaPipe AAR with jni. + + Args: + name: The bazel target name. + android_library: the android library that contains jni. 
+ """ + # Generates dummy AndroidManifest.xml for dummy apk usage # (dummy apk is generated by _dummy_app target below) native.genrule( @@ -314,7 +224,7 @@ cat > $(OUTS) < - + EOF """, @@ -341,7 +251,133 @@ chmod +w $(location :{}.aar) origdir=$$PWD cd $$(mktemp -d) unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*" +find lib -name *_dummy_app.so -delete cp -r lib jni zip -r $$origdir/$(location :{}.aar) jni/*/*.so """.format(android_library, name, name, name, name), ) + +def mediapipe_java_proto_src_extractor(target, src_out, name = ""): + """Extracts the generated MediaPipe java proto source code from the target. + + Args: + target: The java proto lite target to be built and extracted. + src_out: The output java proto src code path. + name: The optional bazel target name. + + Returns: + The output java proto src code path. + """ + + if not name: + name = target.split(":")[-1] + "_proto_java_src_extractor" + src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "") + native.genrule( + name = name + "_proto_java_src_extractor", + srcs = [target], + outs = [src_out], + cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" + + src_out + " $$(dirname $(location " + src_out + "))", + ) + return src_out + +def mediapipe_java_proto_srcs(name = ""): + """Extracts the generated MediaPipe framework java proto source code. + + Args: + name: The optional bazel target name. + + Returns: + The list of the extrated MediaPipe java proto source code. + """ + + proto_src_list = [] + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:calculator_java_proto_lite", + src_out = "com/google/mediapipe/proto/CalculatorProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:calculator_options_java_proto_lite", + src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:stream_handler_java_proto_lite", + src_out = "com/google/mediapipe/proto/StreamHandlerProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:packet_factory_java_proto_lite", + src_out = "com/google/mediapipe/proto/PacketFactoryProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:packet_generator_java_proto_lite", + src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:status_handler_java_proto_lite", + src_out = "com/google/mediapipe/proto/StatusHandlerProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework:mediapipe_options_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite", + src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:classification_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = 
"//mediapipe/framework/formats:detection_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/DetectionProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:landmark_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:location_data_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/framework/formats:rect_java_proto_lite", + src_out = "com/google/mediapipe/formats/proto/RectProto.java", + )) + return proto_src_list + +def mediapipe_logging_java_proto_srcs(name = ""): + """Extracts the generated logging-related MediaPipe java proto source code. + + Args: + name: The optional bazel target name. + + Returns: + The list of the extrated MediaPipe logging-related java proto source code. + """ + + proto_src_list = [] + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite", + src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java", + )) + return proto_src_list diff --git a/mediapipe/model_maker/BUILD b/mediapipe/model_maker/BUILD new file mode 100644 index 000000000..cb312072f --- /dev/null +++ b/mediapipe/model_maker/BUILD @@ -0,0 +1,22 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +package_group( + name = "internal", + packages = [ + "//mediapipe/model_maker/...", + ], +) diff --git a/mediapipe/model_maker/__init__.py b/mediapipe/model_maker/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/BUILD b/mediapipe/model_maker/python/BUILD new file mode 100644 index 000000000..cb312072f --- /dev/null +++ b/mediapipe/model_maker/python/BUILD @@ -0,0 +1,22 @@ +# Copyright 2022 The MediaPipe Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +package_group( + name = "internal", + packages = [ + "//mediapipe/model_maker/...", + ], +) diff --git a/mediapipe/model_maker/python/__init__.py b/mediapipe/model_maker/python/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/BUILD b/mediapipe/model_maker/python/core/BUILD new file mode 100644 index 000000000..636a1a720 --- /dev/null +++ b/mediapipe/model_maker/python/core/BUILD @@ -0,0 +1,26 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) + +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], +) diff --git a/mediapipe/model_maker/python/core/__init__.py b/mediapipe/model_maker/python/core/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
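
For readers less familiar with Bazel genrules, the `mediapipe_java_proto_src_extractor` macro introduced above works by unzipping the `*_proto-lite-src.jar` emitted by the corresponding `java_proto_lite` target and moving the requested `.java` file into place. Below is a rough, stand-alone Python sketch of that unzip-and-move step; the function name and paths are illustrative only and are not part of this change.

```
import shutil
import zipfile
from pathlib import Path


def extract_java_proto_src(src_jar: str, java_src: str, out_dir: str) -> str:
  """Unzips a proto-lite src jar and relocates one generated .java file.

  Roughly mirrors the genrule command:
    unzip <src_jar> -d <gendir> && mv <gendir>/<java_src> $(dirname <out>)
  """
  gendir = Path(out_dir) / "_unpacked"
  with zipfile.ZipFile(src_jar) as jar:
    jar.extractall(gendir)  # unzip <src_jar> -d <gendir>
  dst = Path(out_dir) / Path(java_src).name
  shutil.move(str(gendir / java_src), str(dst))  # mv <gendir>/<java_src> ...
  return str(dst)


# Hypothetical usage:
# extract_java_proto_src(
#     "bazel-bin/mediapipe/framework/calculator_proto-lite-src.jar",
#     "com/google/mediapipe/proto/CalculatorProto.java",
#     "/tmp/extracted")
```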
diff --git a/mediapipe/model_maker/python/core/data/BUILD b/mediapipe/model_maker/python/core/data/BUILD new file mode 100644 index 000000000..70a62e8f7 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/BUILD @@ -0,0 +1,61 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "data_util", + srcs = ["data_util.py"], +) + +py_test( + name = "data_util_test", + srcs = ["data_util_test.py"], + data = ["//mediapipe/model_maker/python/core/data/testdata"], + deps = [":data_util"], +) + +py_library( + name = "dataset", + srcs = ["dataset.py"], + srcs_version = "PY3", +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + deps = [ + ":dataset", + "//mediapipe/model_maker/python/core/utils:test_util", + ], +) + +py_library( + name = "classification_dataset", + srcs = ["classification_dataset.py"], + deps = [":dataset"], +) + +py_test( + name = "classification_dataset_test", + srcs = ["classification_dataset_test.py"], + deps = [":classification_dataset"], +) diff --git a/mediapipe/model_maker/python/core/data/__init__.py b/mediapipe/model_maker/python/core/data/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/data/classification_dataset.py b/mediapipe/model_maker/python/core/data/classification_dataset.py new file mode 100644 index 000000000..af761d9ea --- /dev/null +++ b/mediapipe/model_maker/python/core/data/classification_dataset.py @@ -0,0 +1,51 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Common classification dataset library.""" + +from typing import Any, Tuple + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds + + +class ClassificationDataset(ds.Dataset): + """DataLoader for classification models.""" + + def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any): + super().__init__(dataset, size) + self._index_by_label = index_by_label + + @property + def num_classes(self: ds._DatasetT) -> int: + return len(self._index_by_label) + + @property + def index_by_label(self: ds._DatasetT) -> Any: + return self._index_by_label + + def split(self: ds._DatasetT, + fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]: + """Splits dataset into two sub-datasets with the given fraction. + + Primarily used for splitting the data set into training and testing sets. + + Args: + fraction: float, demonstrates the fraction of the first returned + subdataset in the original data. + + Returns: + The splitted two sub datasets. + """ + return self._split(fraction, self._index_by_label) diff --git a/mediapipe/model_maker/python/core/data/classification_dataset_test.py b/mediapipe/model_maker/python/core/data/classification_dataset_test.py new file mode 100644 index 000000000..0fd8575f4 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/classification_dataset_test.py @@ -0,0 +1,82 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple, TypeVar + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset + +_DatasetT = TypeVar( + '_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset') + + +class ClassificationDatasetTest(tf.test.TestCase): + + def test_split(self): + + class MagicClassificationDataset( + classification_dataset.ClassificationDataset): + """A mock classification dataset class for testing purpose. + + Attributes: + value: A value variable stored by the mock dataset class for testing. + """ + + def __init__(self, dataset: tf.data.Dataset, size: int, + index_by_label: Any, value: Any): + super().__init__( + dataset=dataset, size=size, index_by_label=index_by_label) + self.value = value + + def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]: + return self._split(fraction, self.index_by_label, self.value) + + # Some dummy inputs. + magic_value = 42 + num_classes = 2 + index_by_label = (False, True) + + # Create data loader from sample data. + ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) + data = MagicClassificationDataset( + dataset=ds, + size=len(ds), + index_by_label=index_by_label, + value=magic_value) + + # Train/Test data split. + fraction = .25 + train_data, test_data = data.split(fraction=fraction) + + # `split` should return instances of child DataLoader. 
+ self.assertIsInstance(train_data, MagicClassificationDataset) + self.assertIsInstance(test_data, MagicClassificationDataset) + + # Make sure number of entries are right. + self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data)) + self.assertLen(train_data, fraction * len(ds)) + self.assertLen(test_data, len(ds) - len(train_data)) + + # Make sure attributes propagated correctly. + self.assertEqual(train_data.num_classes, num_classes) + self.assertEqual(test_data.index_by_label, index_by_label) + self.assertEqual(train_data.value, magic_value) + self.assertEqual(test_data.value, magic_value) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/data_util.py b/mediapipe/model_maker/python/core/data/data_util.py new file mode 100644 index 000000000..8c6b9145f --- /dev/null +++ b/mediapipe/model_maker/python/core/data/data_util.py @@ -0,0 +1,35 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Data utility library.""" + +import cv2 +import numpy as np +import tensorflow as tf + + +def load_image(path: str) -> np.ndarray: + """Loads an image as an RGB numpy array. + + Args: + path: input image file absolute path. + + Returns: + An RGB image in numpy.ndarray. + """ + tf.compat.v1.logging.info('Loading RGB image %s', path) + # TODO Replace the OpenCV image load and conversion library by + # MediaPipe image utility library once it is ready. + image = cv2.imread(path) + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) diff --git a/mediapipe/model_maker/python/core/data/data_util_test.py b/mediapipe/model_maker/python/core/data/data_util_test.py new file mode 100644 index 000000000..56ac832c3 --- /dev/null +++ b/mediapipe/model_maker/python/core/data/data_util_test.py @@ -0,0 +1,44 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
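
As a quick usage sketch for the `data_util.load_image` helper above (the input file path is hypothetical), the function returns an RGB array even though OpenCV itself decodes images in BGR order:

```
import cv2
import numpy as np

from mediapipe.model_maker.python.core.data import data_util

rgb = data_util.load_image("/tmp/example.jpg")  # hypothetical input image
print(rgb.shape)  # (height, width, 3), channels in RGB order

# Equivalent to the two OpenCV calls inside the helper:
bgr = cv2.imread("/tmp/example.jpg")
assert np.array_equal(rgb, cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
```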
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +from absl import flags +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import data_util + +_WORKSPACE = "mediapipe" +_TEST_DATA_DIR = os.path.join( + _WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata') + +FLAGS = flags.FLAGS + + +class DataUtilTest(tf.test.TestCase): + + def test_load_rgb_image(self): + image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg') + image_data = data_util.load_image(image_path) + self.assertEqual(image_data.shape, (5184, 3456, 3)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py new file mode 100644 index 000000000..a92b05c0d --- /dev/null +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -0,0 +1,164 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common dataset for model training and evaluation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from typing import Callable, Optional, Tuple, TypeVar + +# Dependency imports +import tensorflow as tf + +_DatasetT = TypeVar('_DatasetT', bound='Dataset') + + +class Dataset(object): + """A generic dataset class for loading model training and evaluation dataset. + + For each ML task, such as image classification, text classification etc., a + subclass can be derived from this class to provide task-specific data loading + utilities. + """ + + def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None): + """Initializes Dataset class. + + To build dataset from raw data, consider using the task specific utilities, + e.g. from_folder(). + + Args: + tf_dataset: A tf.data.Dataset object that contains a potentially large set + of elements, where each element is a pair of (input_data, target). The + `input_data` means the raw input data, like an image, a text etc., while + the `target` means the ground truth of the raw input data, e.g. the + classification label of the image etc. + size: The size of the dataset. tf.data.Dataset donesn't support a function + to get the length directly since it's lazy-loaded and may be infinite. + """ + self._dataset = tf_dataset + self._size = size + + @property + def size(self) -> Optional[int]: + """Returns the size of the dataset. + + Note that this function may return None becuase the exact size of the + dataset isn't a necessary parameter to create an instance of this class, + and tf.data.Dataset donesn't support a function to get the length directly + since it's lazy-loaded and may be infinite. 
+ In most cases, however, when an instance of this class is created by helper + functions like 'from_folder', the size of the dataset will be preprocessed, + and this function can return an int representing the size of the dataset. + """ + return self._size + + def gen_tf_dataset(self, + batch_size: int = 1, + is_training: bool = False, + shuffle: bool = False, + preprocess: Optional[Callable[..., bool]] = None, + drop_remainder: bool = False) -> tf.data.Dataset: + """Generates a batched tf.data.Dataset for training/evaluation. + + Args: + batch_size: An integer, the returned dataset will be batched by this size. + is_training: A boolean, when True, the returned dataset will be optionally + shuffled and repeated as an endless dataset. + shuffle: A boolean, when True, the returned dataset will be shuffled to + create randomness during model training. + preprocess: A function taking three arguments in order, feature, label and + boolean is_training. + drop_remainder: boolean, whether the finaly batch drops remainder. + + Returns: + A TF dataset ready to be consumed by Keras model. + """ + dataset = self._dataset + + if preprocess: + preprocess = functools.partial(preprocess, is_training=is_training) + dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE) + + if is_training: + if shuffle: + # Shuffle size should be bigger than the batch_size. Otherwise it's only + # shuffling within the batch, which equals to not having shuffle. + buffer_size = 3 * batch_size + # But since we are doing shuffle before repeat, it doesn't make sense to + # shuffle more than total available entries. + # TODO: Investigate if shuffling before / after repeat + # dataset can get a better performance? + # Shuffle after repeat will give a more randomized dataset and mix the + # epoch boundary: https://www.tensorflow.org/guide/data + if self._size: + buffer_size = min(self._size, buffer_size) + dataset = dataset.shuffle(buffer_size=buffer_size) + + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + dataset = dataset.prefetch(tf.data.AUTOTUNE) + # TODO: Consider converting dataset to distributed dataset + # here. + return dataset + + def __len__(self): + """Returns the number of element of the dataset.""" + if self._size is not None: + return self._size + else: + return len(self._dataset) + + def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]: + """Splits dataset into two sub-datasets with the given fraction. + + Primarily used for splitting the data set into training and testing sets. + + Args: + fraction: A float value defines the fraction of the first returned + subdataset in the original data. + + Returns: + The splitted two sub datasets. + """ + return self._split(fraction) + + def _split(self: _DatasetT, fraction: float, + *args) -> Tuple[_DatasetT, _DatasetT]: + """Implementation for `split` method and returns sub-class instances. + + Child DataLoader classes, if requires additional constructor arguments, + should implement their own `split` method by calling `_split` with all + arguments to the constructor. + + Args: + fraction: A float value defines the fraction of the first returned + subdataset in the original data. + *args: additional arguments passed to the sub-class constructor. + + Returns: + The splitted two sub datasets. 
+ """ + assert (fraction > 0 and fraction < 1) + + dataset = self._dataset + + train_size = int(self._size * fraction) + trainset = self.__class__(dataset.take(train_size), train_size, *args) + + test_size = self._size - train_size + testset = self.__class__(dataset.skip(train_size), test_size, *args) + + return trainset, testset diff --git a/mediapipe/model_maker/python/core/data/dataset_test.py b/mediapipe/model_maker/python/core/data/dataset_test.py new file mode 100644 index 000000000..9adff127d --- /dev/null +++ b/mediapipe/model_maker/python/core/data/dataset_test.py @@ -0,0 +1,78 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds +from mediapipe.model_maker.python.core.utils import test_util + + +class DatasetTest(tf.test.TestCase): + + def test_split(self): + dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]) + data = ds.Dataset(dataset, 4) + train_data, test_data = data.split(0.5) + + self.assertLen(train_data, 2) + self.assertIsInstance(train_data, ds.Dataset) + self.assertIsInstance(test_data, ds.Dataset) + for i, elem in enumerate(train_data.gen_tf_dataset()): + self.assertTrue((elem.numpy() == np.array([i, 1])).all()) + + self.assertLen(test_data, 2) + for i, elem in enumerate(test_data.gen_tf_dataset()): + self.assertTrue((elem.numpy() == np.array([i, 0])).all()) + + def test_len(self): + size = 4 + dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]) + data = ds.Dataset(dataset, size) + self.assertLen(data, size) + + def test_gen_tf_dataset(self): + input_dim = 8 + data = test_util.create_dataset( + data_size=2, input_shape=[input_dim], num_classes=2) + + dataset = data.gen_tf_dataset() + self.assertLen(dataset, 2) + for (feature, label) in dataset: + self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all()) + self.assertTrue((tf.shape(label).numpy() == np.array([1])).all()) + + dataset2 = data.gen_tf_dataset(batch_size=2) + self.assertLen(dataset2, 1) + for (feature, label) in dataset2: + self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all()) + self.assertTrue((tf.shape(label).numpy() == np.array([2])).all()) + + dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True) + self.assertEqual(dataset3.cardinality(), 1) + for (feature, label) in dataset3.take(10): + self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all()) + self.assertTrue((tf.shape(label).numpy() == np.array([2])).all()) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/data/testdata/BUILD b/mediapipe/model_maker/python/core/data/testdata/BUILD new file mode 100644 index 000000000..54e562d41 --- /dev/null +++ 
b/mediapipe/model_maker/python/core/data/testdata/BUILD @@ -0,0 +1,30 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "//mediapipe/framework/tool:mediapipe_files.bzl", + "mediapipe_files", +) + +package( + default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"], + licenses = ["notice"], # Apache 2.0 +) + +mediapipe_files(srcs = ["test.jpg"]) + +filegroup( + name = "testdata", + srcs = ["test.jpg"], +) diff --git a/mediapipe/model_maker/python/core/hyperparameters.py b/mediapipe/model_maker/python/core/hyperparameters.py new file mode 100644 index 000000000..2a7a8678c --- /dev/null +++ b/mediapipe/model_maker/python/core/hyperparameters.py @@ -0,0 +1,68 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training models. Shared across tasks.""" + +import dataclasses +import tempfile + +from typing import Optional + + +# TODO: Integrate this class into ImageClassifier and other tasks. +@dataclasses.dataclass +class BaseHParams: + """Hyperparameters used for training models. + + A common set of hyperparameters shared by the training jobs of all model + maker tasks. + + Attributes: + learning_rate: The learning rate to use for gradient descent training. + batch_size: Batch size for training. + epochs: Number of training iterations over the dataset. + steps_per_epoch: An optional integer indicate the number of training steps + per epoch. If not set, the training pipeline calculates the default steps + per epoch as the training dataset size devided by batch size. + shuffle: True if the dataset is shuffled before training. + export_dir: The location of the model checkpoint files. + distribution_strategy: A string specifying which Distribution Strategy to + use. Accepted values are 'off', 'one_device', 'mirrored', + 'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case + insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to + use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy + documentation for more details: + https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy. + num_gpus: How many GPUs to use at each worker with the + DistributionStrategies API. The default is -1, which means utilize all + available GPUs. + tpu: The Cloud TPU to use for training. This should be either the name used + when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url. 
+ """ + + # Parameters for train configuration + learning_rate: float + batch_size: int + epochs: int + steps_per_epoch: Optional[int] = None + + # Dataset-related parameters + shuffle: bool = False + + # Parameters for model / checkpoint files + export_dir: str = tempfile.mkdtemp() + + # Parameters for hardware acceleration + distribution_strategy: str = 'off' + num_gpus: int = -1 # default value of -1 means use all available GPUs + tpu: str = '' diff --git a/mediapipe/model_maker/python/core/tasks/BUILD b/mediapipe/model_maker/python/core/tasks/BUILD new file mode 100644 index 000000000..124de621a --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/BUILD @@ -0,0 +1,59 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) + +py_library( + name = "custom_model", + srcs = ["custom_model.py"], + deps = [ + "//mediapipe/model_maker/python/core/data:dataset", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/core/utils:quantization", + ], +) + +py_test( + name = "custom_model_test", + srcs = ["custom_model_test.py"], + deps = [ + ":custom_model", + "//mediapipe/model_maker/python/core/utils:test_util", + ], +) + +py_library( + name = "classifier", + srcs = ["classifier.py"], + deps = [ + ":custom_model", + "//mediapipe/model_maker/python/core/data:dataset", + ], +) + +py_test( + name = "classifier_test", + srcs = ["classifier_test.py"], + deps = [ + ":classifier", + "//mediapipe/model_maker/python/core/utils:test_util", + ], +) diff --git a/mediapipe/model_maker/python/core/tasks/__init__.py b/mediapipe/model_maker/python/core/tasks/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py new file mode 100644 index 000000000..c327b7ea9 --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -0,0 +1,77 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom classifier.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from typing import Any, List + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.tasks import custom_model + + +class Classifier(custom_model.CustomModel): + """An abstract base class that represents a TensorFlow classifier.""" + + def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool, + full_train: bool): + """Initilizes a classifier with its specifications. + + Args: + model_spec: Specification for the model. + index_by_label: A list that map from index to label class name. + shuffle: Whether the dataset should be shuffled. + full_train: If true, train the model end-to-end including the backbone + and the classification layers on top. Otherwise, only train the top + classification layers. + """ + super(Classifier, self).__init__(model_spec, shuffle) + self._index_by_label = index_by_label + self._full_train = full_train + self._num_classes = len(index_by_label) + + def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any: + """Evaluates the classifier with the provided evaluation dataset. + + Args: + data: Evaluation dataset + batch_size: Number of samples per evaluation step. + + Returns: + The loss value and accuracy. + """ + ds = data.gen_tf_dataset( + batch_size, is_training=False, preprocess=self._preprocess) + return self._model.evaluate(ds) + + def export_labels(self, export_dir: str, label_filename: str = 'labels.txt'): + """Exports classification labels into a label file. + + Args: + export_dir: The directory to save exported files. + label_filename: File name to save labels model. The full export path is + {export_dir}/{label_filename}. + """ + if not tf.io.gfile.exists(export_dir): + tf.io.gfile.makedirs(export_dir) + + label_filepath = os.path.join(export_dir, label_filename) + tf.compat.v1.logging.info('Saving labels in %s', label_filepath) + with tf.io.gfile.GFile(label_filepath, 'w') as f: + f.write('\n'.join(self._index_by_label)) diff --git a/mediapipe/model_maker/python/core/tasks/classifier_test.py b/mediapipe/model_maker/python/core/tasks/classifier_test.py new file mode 100644 index 000000000..1484e8e86 --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/classifier_test.py @@ -0,0 +1,58 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import test_util + + +class MockClassifier(classifier.Classifier): + """A mock class with implementation of abstract methods for testing.""" + + def train(self, train_data, validation_data=None, **kwargs): + pass + + def evaluate(self, data, **kwargs): + pass + + +class ClassifierTest(tf.test.TestCase): + + def setUp(self): + super(ClassifierTest, self).setUp() + index_by_label = ['cat', 'dog'] + self.model = MockClassifier( + model_spec=None, + index_by_label=index_by_label, + shuffle=False, + full_train=False) + self.model.model = test_util.build_model(input_shape=[4], num_classes=2) + + def _check_nonempty_file(self, filepath): + self.assertTrue(os.path.isfile(filepath)) + self.assertGreater(os.path.getsize(filepath), 0) + + def test_export_labels(self): + export_path = os.path.join(self.get_temp_dir(), 'export/') + self.model.export_labels(export_dir=export_path) + self._check_nonempty_file(os.path.join(export_path, 'labels.txt')) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/tasks/custom_model.py b/mediapipe/model_maker/python/core/tasks/custom_model.py new file mode 100644 index 000000000..66d1494db --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/custom_model.py @@ -0,0 +1,83 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Interface to define a custom model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import os +from typing import Any, Callable, Optional + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization + + +class CustomModel(abc.ABC): + """The abstract base class that represents a custom TensorFlow model.""" + + def __init__(self, model_spec: Any, shuffle: bool): + """Initializes a custom model with model specs and other parameters. + + Args: + model_spec: Specification for the model. + shuffle: Whether the training data need be shuffled. + """ + self._model_spec = model_spec + self._shuffle = shuffle + self._preprocess = None + self._model = None + + @abc.abstractmethod + def evaluate(self, data: dataset.Dataset, **kwargs): + """Evaluates the model with the provided data.""" + return + + def summary(self): + """Prints a summary of the model.""" + self._model.summary() + + def export_tflite( + self, + export_dir: str, + tflite_filename: str = 'model.tflite', + quantization_config: Optional[quantization.QuantizationConfig] = None, + preprocess: Optional[Callable[..., bool]] = None): + """Converts the model to requested formats. 
+ + Args: + export_dir: The directory to save exported files. + tflite_filename: File name to save tflite model. The full export path is + {export_dir}/{tflite_filename}. + quantization_config: The configuration for model quantization. + preprocess: A callable to preprocess the representative dataset for + quantization. The callable takes three arguments in order: feature, + label, and is_training. + """ + if not tf.io.gfile.exists(export_dir): + tf.io.gfile.makedirs(export_dir) + + tflite_filepath = os.path.join(export_dir, tflite_filename) + # TODO: Populate metadata to the exported TFLite model. + model_util.export_tflite( + model=self._model, + tflite_filepath=tflite_filepath, + quantization_config=quantization_config, + preprocess=preprocess) + tf.compat.v1.logging.info( + 'TensorFlow Lite model exported successfully: %s' % tflite_filepath) diff --git a/mediapipe/model_maker/python/core/tasks/custom_model_test.py b/mediapipe/model_maker/python/core/tasks/custom_model_test.py new file mode 100644 index 000000000..ad77d4ecd --- /dev/null +++ b/mediapipe/model_maker/python/core/tasks/custom_model_test.py @@ -0,0 +1,56 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.tasks import custom_model +from mediapipe.model_maker.python.core.utils import test_util + + +class MockCustomModel(custom_model.CustomModel): + """A mock class with implementation of abstract methods for testing.""" + + def train(self, train_data, validation_data=None, **kwargs): + pass + + def evaluate(self, data, **kwargs): + pass + + +class CustomModelTest(tf.test.TestCase): + + def setUp(self): + super(CustomModelTest, self).setUp() + self._model = MockCustomModel(model_spec=None, shuffle=False) + self._model._model = test_util.build_model(input_shape=[4], num_classes=2) + + def _check_nonempty_file(self, filepath): + self.assertTrue(os.path.isfile(filepath)) + self.assertGreater(os.path.getsize(filepath), 0) + + def test_export_tflite(self): + export_path = os.path.join(self.get_temp_dir(), 'export/') + self._model.export_tflite(export_dir=export_path) + self._check_nonempty_file(os.path.join(export_path, 'model.tflite')) + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD new file mode 100644 index 000000000..a2ec52044 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -0,0 +1,79 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. + +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.py"], + deps = [ + ":model_util", + "//mediapipe/model_maker/python/core/data:dataset", + ], +) + +py_library( + name = "model_util", + srcs = ["model_util.py"], + deps = [ + ":quantization", + "//mediapipe/model_maker/python/core/data:dataset", + ], +) + +py_test( + name = "model_util_test", + srcs = ["model_util_test.py"], + deps = [ + ":model_util", + ":quantization", + ":test_util", + ], +) + +py_library( + name = "loss_functions", + srcs = ["loss_functions.py"], + srcs_version = "PY3", +) + +py_test( + name = "loss_functions_test", + srcs = ["loss_functions_test.py"], + deps = [":loss_functions"], +) + +py_library( + name = "quantization", + srcs = ["quantization.py"], + srcs_version = "PY3", + deps = ["//mediapipe/model_maker/python/core/data:dataset"], +) + +py_test( + name = "quantization_test", + srcs = ["quantization_test.py"], + deps = [ + ":quantization", + ":test_util", + ], +) diff --git a/mediapipe/model_maker/python/core/utils/__init__.py b/mediapipe/model_maker/python/core/utils/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/core/utils/loss_functions.py b/mediapipe/model_maker/python/core/utils/loss_functions.py new file mode 100644 index 000000000..5b0aa32bf --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/loss_functions.py @@ -0,0 +1,105 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
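Before the loss utilities that follow, a minimal sketch of how the `CustomModel` interface above might be exercised. The `_MyModel` subclass and the export path are illustrative only; assigning `_model` directly mirrors what `MockCustomModel` does in `custom_model_test.py` above.

```
from mediapipe.model_maker.python.core.tasks import custom_model
from mediapipe.model_maker.python.core.utils import test_util


class _MyModel(custom_model.CustomModel):
  """Toy subclass that only fills in the abstract evaluate()."""

  def evaluate(self, data, **kwargs):
    pass


my_model = _MyModel(model_spec=None, shuffle=False)
# The tests populate the protected Keras model the same way.
my_model._model = test_util.build_model(input_shape=[4], num_classes=2)
# Writes {export_dir}/model.tflite through model_util.export_tflite.
my_model.export_tflite(export_dir='/tmp/my_model_export')
```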
+"""Loss function utility library.""" + +from typing import Optional, Sequence + +import tensorflow as tf + + +class FocalLoss(tf.keras.losses.Loss): + """Implementation of focal loss (https://arxiv.org/pdf/1708.02002.pdf). + + This class computes the focal loss between labels and prediction. Focal loss + is a weighted loss function that modulates the standard cross-entropy loss + based on how well the neural network performs on a specific example of a + class. The labels should be provided in a `one_hot` vector representation. + There should be `#classes` floating point values per prediction. + The loss is reduced across all samples using 'sum_over_batch_size' reduction + (see https://www.tensorflow.org/api_docs/python/tf/keras/losses/Reduction). + + Example usage: + >>> y_true = [[0, 1, 0], [0, 0, 1]] + >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] + >>> gamma = 2 + >>> focal_loss = FocalLoss(gamma) + >>> focal_loss(y_true, y_pred).numpy() + 0.9326 + + >>> # Calling with 'sample_weight'. + >>> focal_loss(y_true, y_pred, sample_weight=tf.constant([0.3, 0.7])).numpy() + 0.6528 + + Usage with the `compile()` API: + ```python + model.compile(optimizer='sgd', loss=FocalLoss(gamma)) + ``` + + """ + + def __init__(self, gamma, class_weight: Optional[Sequence[float]] = None): + """Constructor. + + Args: + gamma: Focal loss gamma, as described in class docs. + class_weight: A weight to apply to the loss, one for each class. The + weight is applied for each input where the ground truth label matches. + """ + super().__init__() + # Used for clipping min/max values of probability values in y_pred to avoid + # NaNs and Infs in computation. + self._epsilon = 1e-7 + # This is a tunable "focusing parameter"; should be >= 0. + # When gamma = 0, the loss returned is the standard categorical + # cross-entropy loss. + self._gamma = gamma + self._class_weight = class_weight + # tf.keras.losses.Loss class implementation requires a Reduction specified + # in self.reduction. To use this reduction, we should use tensorflow's + # compute_weighted_loss function however it is only compatible with v1 of + # Tensorflow: https://www.tensorflow.org/api_docs/python/tf/compat/v1/losses/compute_weighted_loss?hl=en. pylint: disable=line-too-long + # So even though it is specified here, we don't use self.reduction in the + # loss function call. 
+ self.reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE + + def __call__(self, + y_true: tf.Tensor, + y_pred: tf.Tensor, + sample_weight: Optional[tf.Tensor] = None) -> tf.Tensor: + if self._class_weight: + class_weight = tf.convert_to_tensor(self._class_weight, dtype=tf.float32) + label = tf.argmax(y_true, axis=1) + loss_weight = tf.gather(class_weight, label) + else: + loss_weight = tf.ones(tf.shape(y_true)[0]) + y_true = tf.cast(y_true, y_pred.dtype) + y_pred = tf.clip_by_value(y_pred, self._epsilon, 1 - self._epsilon) + batch_size = tf.cast(tf.shape(y_pred)[0], y_pred.dtype) + if sample_weight is None: + sample_weight = tf.constant(1.0) + weight_shape = sample_weight.shape + weight_rank = weight_shape.ndims + y_pred_rank = y_pred.shape.ndims + if y_pred_rank - weight_rank == 1: + sample_weight = tf.expand_dims(sample_weight, [-1]) + elif weight_rank != 0: + raise ValueError(f'Unexpected sample_weights, should be either a scalar' + f'or a vector of batch_size:{batch_size.numpy()}') + ce = -tf.math.log(y_pred) + modulating_factor = tf.math.pow(1 - y_pred, self._gamma) + losses = y_true * modulating_factor * ce * sample_weight + losses = losses * loss_weight[:, tf.newaxis] + # By default, this function uses "sum_over_batch_size" reduction for the + # loss per batch. + return tf.reduce_sum(losses) / batch_size diff --git a/mediapipe/model_maker/python/core/utils/loss_functions_test.py b/mediapipe/model_maker/python/core/utils/loss_functions_test.py new file mode 100644 index 000000000..716c329ef --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/loss_functions_test.py @@ -0,0 +1,103 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
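A short usage sketch for the `FocalLoss` defined above. The labels follow the one-hot convention from the class docstring, and the concrete values are illustrative rather than taken from the tests.

```
import tensorflow as tf

from mediapipe.model_maker.python.core.utils import loss_functions

# One-hot labels and per-class probabilities, one row per example.
y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.05, 0.95, 0.0], [0.1, 0.8, 0.1]])

loss_fn = loss_functions.FocalLoss(gamma=2)
print(loss_fn(y_true, y_pred).numpy())  # Scalar loss, averaged over the batch.

# Optional per-class weighting: the weight of the true class is applied to
# each example before the batch reduction.
weighted_fn = loss_functions.FocalLoss(gamma=2, class_weight=[1.0, 3.0, 10.0])
print(weighted_fn(y_true, y_pred).numpy())
```

With `gamma=0` and no class weights the result reduces to standard categorical cross-entropy, which is what the tests below check against `tf.keras.losses.CategoricalCrossentropy`.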
+ +import math +from typing import Optional + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import loss_functions + + +class LossFunctionsTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict(testcase_name='no_sample_weight', sample_weight=None), + dict( + testcase_name='with_sample_weight', + sample_weight=tf.constant([0.2, 0.2, 0.3, 0.1, 0.2]))) + def test_focal_loss_gamma_0_is_cross_entropy( + self, sample_weight: Optional[tf.Tensor]): + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, + 0]]) + y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4], + [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) + + tf_cce = tf.keras.losses.CategoricalCrossentropy( + from_logits=False, + reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE) + focal_loss = loss_functions.FocalLoss(gamma=0) + self.assertAllClose( + tf_cce(y_true, y_pred, sample_weight=sample_weight), + focal_loss(y_true, y_pred, sample_weight=sample_weight), 1e-4) + + def test_focal_loss_with_sample_weight(self): + y_true = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, + 0]]) + y_pred = tf.constant([[0.7, 0.1, 0.2], [0.6, 0.3, 0.1], [0.1, 0.5, 0.4], + [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) + + focal_loss = loss_functions.FocalLoss(gamma=0) + + sample_weight = tf.constant([0.2, 0.2, 0.3, 0.1, 0.2]) + + self.assertGreater( + focal_loss(y_true=y_true, y_pred=y_pred), + focal_loss(y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)) + + @parameterized.named_parameters( + dict(testcase_name='gt_0.1', y_pred=tf.constant([0.1, 0.9])), + dict(testcase_name='gt_0.3', y_pred=tf.constant([0.3, 0.7])), + dict(testcase_name='gt_0.5', y_pred=tf.constant([0.5, 0.5])), + dict(testcase_name='gt_0.7', y_pred=tf.constant([0.7, 0.3])), + dict(testcase_name='gt_0.9', y_pred=tf.constant([0.9, 0.1])), + ) + def test_focal_loss_decreases_with_increasing_gamma(self, y_pred: tf.Tensor): + y_true = tf.constant([[1, 0]]) + + focal_loss_gamma_0 = loss_functions.FocalLoss(gamma=0) + loss_gamma_0 = focal_loss_gamma_0(y_true, y_pred) + focal_loss_gamma_0p5 = loss_functions.FocalLoss(gamma=0.5) + loss_gamma_0p5 = focal_loss_gamma_0p5(y_true, y_pred) + focal_loss_gamma_1 = loss_functions.FocalLoss(gamma=1) + loss_gamma_1 = focal_loss_gamma_1(y_true, y_pred) + focal_loss_gamma_2 = loss_functions.FocalLoss(gamma=2) + loss_gamma_2 = focal_loss_gamma_2(y_true, y_pred) + focal_loss_gamma_5 = loss_functions.FocalLoss(gamma=5) + loss_gamma_5 = focal_loss_gamma_5(y_true, y_pred) + + self.assertGreater(loss_gamma_0, loss_gamma_0p5) + self.assertGreater(loss_gamma_0p5, loss_gamma_1) + self.assertGreater(loss_gamma_1, loss_gamma_2) + self.assertGreater(loss_gamma_2, loss_gamma_5) + + @parameterized.named_parameters( + dict(testcase_name='index_0', true_class=0), + dict(testcase_name='index_1', true_class=1), + dict(testcase_name='index_2', true_class=2), + ) + def test_focal_loss_class_weight_is_applied(self, true_class: int): + class_weight = [1.0, 3.0, 10.0] + y_pred = tf.constant([[1.0, 1.0, 1.0]]) / 3.0 + y_true = tf.one_hot(true_class, depth=3)[tf.newaxis, :] + expected_loss = -math.log(1.0 / 3.0) * class_weight[true_class] + + loss_fn = loss_functions.FocalLoss(gamma=0, class_weight=class_weight) + loss = loss_fn(y_true, y_pred) + self.assertNear(loss, expected_loss, 1e-4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/model_util.py 
b/mediapipe/model_maker/python/core/utils/model_util.py new file mode 100644 index 000000000..e1228eb6e --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -0,0 +1,272 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for keras models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +from typing import Any, Callable, Dict, List, Optional, Text, Tuple, Union + +# Dependency imports + +import numpy as np +import tensorflow as tf + +# resources dependency +from mediapipe.model_maker.python.core.data import dataset +from mediapipe.model_maker.python.core.utils import quantization + +DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 +ESTIMITED_STEPS_PER_EPOCH = 1000 + + +def load_keras_model(model_path: str, + compile_on_load: bool = False) -> tf.keras.Model: + """Loads a tensorflow Keras model from file and returns the Keras model. + + Args: + model_path: Relative path to a directory containing model data, such as + /saved_model/. + compile_on_load: Whether the model should be compiled while loading. If + False, the model returned has to be compiled with the appropriate loss + function and custom metrics before running for inference on a test + dataset. + + Returns: + A tensorflow Keras model. + """ + # Extract the file path before mediapipe/ as the `base_dir`. By joining it + # with the `model_path` which defines the relative path under mediapipe/, it + # yields to the aboslution path of the model files directory. + cwd = os.path.dirname(__file__) + base_dir = cwd[:cwd.rfind('mediapipe')] + absolute_path = os.path.join(base_dir, model_path) + return tf.keras.models.load_model( + absolute_path, custom_objects={'tf': tf}, compile=compile_on_load) + + +def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, + batch_size: Optional[int] = None, + train_data: Optional[dataset.Dataset] = None) -> int: + """Gets the estimated training steps per epoch. + + 1. If `steps_per_epoch` is set, returns `steps_per_epoch` directly. + 2. Else if we can get the length of training data successfully, returns + `train_data_length // batch_size`. + + Args: + steps_per_epoch: int, training steps per epoch. + batch_size: int, batch size. + train_data: training data. + + Returns: + Estimated training steps per epoch. + + Raises: + ValueError: if both steps_per_epoch and train_data are not set. + """ + if steps_per_epoch is not None: + # steps_per_epoch is set by users manually. + return steps_per_epoch + else: + if train_data is None: + raise ValueError('Input train_data cannot be None.') + # Gets the steps by the length of the training data. + return len(train_data) // batch_size + + +def export_tflite( + model: tf.keras.Model, + tflite_filepath: str, + quantization_config: Optional[quantization.QuantizationConfig] = None, + supported_ops: Tuple[tf.lite.OpsSet, + ...] 
= (tf.lite.OpsSet.TFLITE_BUILTINS,), + preprocess: Optional[Callable[..., bool]] = None): + """Converts the model to tflite format and saves it. + + Args: + model: model to be converted to tflite. + tflite_filepath: File path to save tflite model. + quantization_config: Configuration for post-training quantization. + supported_ops: A list of supported ops in the converted TFLite file. + preprocess: A callable to preprocess the representative dataset for + quantization. The callable takes three arguments in order: feature, label, + and is_training. + """ + if tflite_filepath is None: + raise ValueError( + "TFLite filepath couldn't be None when exporting to tflite.") + + with tempfile.TemporaryDirectory() as temp_dir: + save_path = os.path.join(temp_dir, 'saved_model') + model.save(save_path, include_optimizer=False, save_format='tf') + converter = tf.lite.TFLiteConverter.from_saved_model(save_path) + + if quantization_config: + converter = quantization_config.set_converter_with_quantization( + converter, preprocess=preprocess) + + converter.target_spec.supported_ops = supported_ops + tflite_model = converter.convert() + + with tf.io.gfile.GFile(tflite_filepath, 'wb') as f: + f.write(tflite_model) + + +class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule): + """Applies a warmup schedule on a given learning rate decay schedule.""" + + def __init__(self, + initial_learning_rate: float, + decay_schedule_fn: Callable[[Any], Any], + warmup_steps: int, + name: Optional[str] = None): + """Initializes a new instance of the `WarmUp` class. + + Args: + initial_learning_rate: learning rate after the warmup. + decay_schedule_fn: A function maps step to learning rate. Will be applied + for values of step larger than 'warmup_steps'. + warmup_steps: Number of steps to do warmup for. + name: TF namescope under which to perform the learning rate calculation. + """ + super(WarmUp, self).__init__() + self.initial_learning_rate = initial_learning_rate + self.warmup_steps = warmup_steps + self.decay_schedule_fn = decay_schedule_fn + self.name = name + + def __call__(self, step: Union[int, tf.Tensor]) -> tf.Tensor: + with tf.name_scope(self.name or 'WarmUp') as name: + # Implements linear warmup. i.e., if global_step < warmup_steps, the + # learning rate will be `global_step/num_warmup_steps * init_lr`. + global_step_float = tf.cast(step, tf.float32) + warmup_steps_float = tf.cast(self.warmup_steps, tf.float32) + warmup_percent_done = global_step_float / warmup_steps_float + warmup_learning_rate = self.initial_learning_rate * warmup_percent_done + return tf.cond( + global_step_float < warmup_steps_float, + lambda: warmup_learning_rate, + lambda: self.decay_schedule_fn(step), + name=name) + + def get_config(self) -> Dict[Text, Any]: + return { + 'initial_learning_rate': self.initial_learning_rate, + 'decay_schedule_fn': self.decay_schedule_fn, + 'warmup_steps': self.warmup_steps, + 'name': self.name + } + + +class LiteRunner(object): + """A runner to do inference with the TFLite model.""" + + def __init__(self, tflite_filepath: str): + """Initializes Lite runner with tflite model file. + + Args: + tflite_filepath: File path to the TFLite model. 
+ """ + with tf.io.gfile.GFile(tflite_filepath, 'rb') as f: + tflite_model = f.read() + self.interpreter = tf.lite.Interpreter(model_content=tflite_model) + self.interpreter.allocate_tensors() + self.input_details = self.interpreter.get_input_details() + self.output_details = self.interpreter.get_output_details() + + def run( + self, input_tensors: Union[List[tf.Tensor], Dict[str, tf.Tensor]] + ) -> Union[List[tf.Tensor], tf.Tensor]: + """Runs inference with the TFLite model. + + Args: + input_tensors: List / Dict of the input tensors of the TFLite model. The + order should be the same as the keras model if it's a list. It also + accepts tensor directly if the model has only 1 input. + + Returns: + List of the output tensors for multi-output models, otherwise just + the output tensor. The order should be the same as the keras model. + """ + + if not isinstance(input_tensors, list) and not isinstance( + input_tensors, dict): + input_tensors = [input_tensors] + + interpreter = self.interpreter + + # Reshape inputs + for i, input_detail in enumerate(self.input_details): + input_tensor = _get_input_tensor( + input_tensors=input_tensors, + input_details=self.input_details, + index=i) + interpreter.resize_tensor_input( + input_index=input_detail['index'], tensor_size=input_tensor.shape) + interpreter.allocate_tensors() + + # Feed input to the interpreter + for i, input_detail in enumerate(self.input_details): + input_tensor = _get_input_tensor( + input_tensors=input_tensors, + input_details=self.input_details, + index=i) + if input_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): + # Quantize the input + scale, zero_point = input_detail['quantization'] + input_tensor = input_tensor / scale + zero_point + input_tensor = np.array(input_tensor, dtype=input_detail['dtype']) + interpreter.set_tensor(input_detail['index'], input_tensor) + + interpreter.invoke() + + output_tensors = [] + for output_detail in self.output_details: + output_tensor = interpreter.get_tensor(output_detail['index']) + if output_detail['quantization'] != (DEFAULT_SCALE, DEFAULT_ZERO_POINT): + # Dequantize the output + scale, zero_point = output_detail['quantization'] + output_tensor = output_tensor.astype(np.float32) + output_tensor = (output_tensor - zero_point) * scale + output_tensors.append(output_tensor) + + if len(output_tensors) == 1: + return output_tensors[0] + return output_tensors + + +def get_lite_runner(tflite_filepath: str) -> 'LiteRunner': + """Returns a `LiteRunner` from file path to TFLite model.""" + lite_runner = LiteRunner(tflite_filepath) + return lite_runner + + +def _get_input_tensor(input_tensors: Union[List[tf.Tensor], Dict[str, + tf.Tensor]], + input_details: Dict[str, Any], index: int) -> tf.Tensor: + """Returns input tensor in `input_tensors` that maps `input_detail[i]`.""" + if isinstance(input_tensors, dict): + # Gets the mapped input tensor. + input_detail = input_details + for input_tensor_name, input_tensor in input_tensors.items(): + if input_tensor_name in input_detail['name']: + return input_tensor + raise ValueError('Input tensors don\'t contains a tensor that mapped the ' + 'input detail %s' % str(input_detail)) + else: + return input_tensors[index] diff --git a/mediapipe/model_maker/python/core/utils/model_util_test.py b/mediapipe/model_maker/python/core/utils/model_util_test.py new file mode 100644 index 000000000..35b52eb75 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/model_util_test.py @@ -0,0 +1,142 @@ +# Copyright 2022 The MediaPipe Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.core.utils import test_util + + +class ModelUtilTest(tf.test.TestCase, parameterized.TestCase): + + def test_load_model(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + saved_model_path = os.path.join(self.get_temp_dir(), 'saved_model') + model.save(saved_model_path) + loaded_model = model_util.load_keras_model(saved_model_path) + + input_tensors = test_util.create_random_sample(size=[1, input_dim]) + model_output = model.predict_on_batch(input_tensors) + loaded_model_output = loaded_model.predict_on_batch(input_tensors) + self.assertTrue((model_output == loaded_model_output).all()) + + @parameterized.named_parameters( + dict( + testcase_name='input_only_steps_per_epoch', + steps_per_epoch=1000, + batch_size=None, + train_data=None, + expected_steps_per_epoch=1000), + dict( + testcase_name='input_steps_per_epoch_and_batch_size', + steps_per_epoch=1000, + batch_size=32, + train_data=None, + expected_steps_per_epoch=1000), + dict( + testcase_name='input_steps_per_epoch_batch_size_and_train_data', + steps_per_epoch=1000, + batch_size=32, + train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]), + expected_steps_per_epoch=1000), + dict( + testcase_name='input_batch_size_and_train_data', + steps_per_epoch=None, + batch_size=2, + train_data=tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], + [1, 0]]), + expected_steps_per_epoch=2)) + def test_get_steps_per_epoch(self, steps_per_epoch, batch_size, train_data, + expected_steps_per_epoch): + estimated_steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=steps_per_epoch, + batch_size=batch_size, + train_data=train_data) + self.assertEqual(estimated_steps_per_epoch, expected_steps_per_epoch) + + def test_get_steps_per_epoch_raise_value_error(self): + with self.assertRaises(ValueError): + model_util.get_steps_per_epoch( + steps_per_epoch=None, batch_size=16, train_data=None) + + def test_warmup(self): + init_lr = 0.1 + warmup_steps = 1000 + num_decay_steps = 100 + learning_rate_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate=init_lr, decay_steps=num_decay_steps) + warmup_object = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate_fn, + warmup_steps=1000, + name='test') + self.assertEqual( + warmup_object.get_config(), { + 'initial_learning_rate': init_lr, + 'decay_schedule_fn': learning_rate_fn, + 'warmup_steps': warmup_steps, + 'name': 'test' + }) + + def test_export_tflite(self): + input_dim = 4 + model = test_util.build_model(input_shape=[input_dim], num_classes=2) + tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite') + model_util.export_tflite(model, tflite_file) + 
test_util.test_tflite( + keras_model=model, tflite_file=tflite_file, size=[1, input_dim]) + + @parameterized.named_parameters( + dict( + testcase_name='dynamic_quantize', + config=quantization.QuantizationConfig.for_dynamic(), + model_size=1288), + dict( + testcase_name='int8_quantize', + config=quantization.QuantizationConfig.for_int8( + representative_data=test_util.create_dataset( + data_size=10, input_shape=[16], num_classes=3)), + model_size=1832), + dict( + testcase_name='float16_quantize', + config=quantization.QuantizationConfig.for_float16(), + model_size=1468)) + def test_export_tflite_quantized(self, config, model_size): + input_dim = 16 + num_classes = 2 + max_input_value = 5 + model = test_util.build_model( + input_shape=[input_dim], num_classes=num_classes) + tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite') + + model_util.export_tflite( + model=model, tflite_filepath=tflite_file, quantization_config=config) + self.assertTrue( + test_util.test_tflite( + keras_model=model, + tflite_file=tflite_file, + size=[1, input_dim], + high=max_input_value, + atol=1e-00)) + self.assertNear(os.path.getsize(tflite_file), model_size, 300) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/quantization.py b/mediapipe/model_maker/python/core/utils/quantization.py new file mode 100644 index 000000000..a1a38cc64 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/quantization.py @@ -0,0 +1,213 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Libraries for post-training quantization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import Any, Callable, List, Optional, Union + +# Dependency imports + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds + +DEFAULT_QUANTIZATION_STEPS = 500 + + +def _get_representative_dataset_generator(dataset: tf.data.Dataset, + num_steps: int) -> Callable[[], Any]: + """Gets a representative dataset generator for post-training quantization. + + The generator is to provide a small dataset to calibrate or estimate the + range, i.e, (min, max) of all floating-point arrays in the model for + quantization. Usually, this is a small subset of a few hundred samples + randomly chosen, in no particular order, from the training or evaluation + dataset. See tf.lite.RepresentativeDataset for more details. + + Args: + dataset: Input dataset for extracting representative sub dataset. + num_steps: The number of quantization steps which also reflects the size of + the representative dataset. + + Returns: + A representative dataset generator. 
+ """ + + def representative_dataset_gen(): + """Generates representative dataset for quantization.""" + for data, _ in dataset.take(num_steps): + yield [data] + + return representative_dataset_gen + + +class QuantizationConfig(object): + """Configuration for post-training quantization. + + Refer to + https://www.tensorflow.org/lite/performance/post_training_quantization + for different post-training quantization options. + """ + + def __init__( + self, + optimizations: Optional[Union[tf.lite.Optimize, + List[tf.lite.Optimize]]] = None, + representative_data: Optional[ds.Dataset] = None, + quantization_steps: Optional[int] = None, + inference_input_type: Optional[tf.dtypes.DType] = None, + inference_output_type: Optional[tf.dtypes.DType] = None, + supported_ops: Optional[Union[tf.lite.OpsSet, + List[tf.lite.OpsSet]]] = None, + supported_types: Optional[Union[tf.dtypes.DType, + List[tf.dtypes.DType]]] = None, + experimental_new_quantizer: bool = False, + ): + """Constructs QuantizationConfig. + + Args: + optimizations: A list of optimizations to apply when converting the model. + If not set, use `[Optimize.DEFAULT]` by default. + representative_data: A representative ds.Dataset for post-training + quantization. + quantization_steps: Number of post-training quantization calibration steps + to run (default to DEFAULT_QUANTIZATION_STEPS). + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays. Defaults to None. If set, must be + be `{tf.float32, tf.uint8, tf.int8}`. + inference_output_type: Target data type of real-number output arrays. + Allows for a different type for output arrays. Defaults to None. If set, + must be `{tf.float32, tf.uint8, tf.int8}`. + supported_ops: Set of OpsSet options supported by the device. Used to Set + converter.target_spec.supported_ops. + supported_types: List of types for constant values on the target device. + Supported values are types exported by lite.constants. Frequently, an + optimization choice is driven by the most compact (i.e. smallest) type + in this list (default [constants.FLOAT]). + experimental_new_quantizer: Whether to enable experimental new quantizer. + + Raises: + ValueError: if inference_input_type or inference_output_type are set but + not in {tf.float32, tf.uint8, tf.int8}. 
+ """ + if inference_input_type is not None and inference_input_type not in { + tf.float32, tf.uint8, tf.int8 + }: + raise ValueError('Unsupported inference_input_type %s' % + inference_input_type) + if inference_output_type is not None and inference_output_type not in { + tf.float32, tf.uint8, tf.int8 + }: + raise ValueError('Unsupported inference_output_type %s' % + inference_output_type) + + if optimizations is None: + optimizations = [tf.lite.Optimize.DEFAULT] + if not isinstance(optimizations, list): + optimizations = [optimizations] + self.optimizations = optimizations + + self.representative_data = representative_data + if self.representative_data is not None and quantization_steps is None: + quantization_steps = DEFAULT_QUANTIZATION_STEPS + self.quantization_steps = quantization_steps + + self.inference_input_type = inference_input_type + self.inference_output_type = inference_output_type + + if supported_ops is not None and not isinstance(supported_ops, list): + supported_ops = [supported_ops] + self.supported_ops = supported_ops + + if supported_types is not None and not isinstance(supported_types, list): + supported_types = [supported_types] + self.supported_types = supported_types + + self.experimental_new_quantizer = experimental_new_quantizer + + @classmethod + def for_dynamic(cls) -> 'QuantizationConfig': + """Creates configuration for dynamic range quantization.""" + return QuantizationConfig() + + @classmethod + def for_int8( + cls, + representative_data: ds.Dataset, + quantization_steps: int = DEFAULT_QUANTIZATION_STEPS, + inference_input_type: tf.dtypes.DType = tf.uint8, + inference_output_type: tf.dtypes.DType = tf.uint8, + supported_ops: tf.lite.OpsSet = tf.lite.OpsSet.TFLITE_BUILTINS_INT8 + ) -> 'QuantizationConfig': + """Creates configuration for full integer quantization. + + Args: + representative_data: Representative data used for post-training + quantization. + quantization_steps: Number of post-training quantization calibration steps + to run. + inference_input_type: Target data type of real-number input arrays. + inference_output_type: Target data type of real-number output arrays. + supported_ops: Set of `tf.lite.OpsSet` options, where each option + represents a set of operators supported by the target device. + + Returns: + QuantizationConfig. + """ + return QuantizationConfig( + representative_data=representative_data, + quantization_steps=quantization_steps, + inference_input_type=inference_input_type, + inference_output_type=inference_output_type, + supported_ops=supported_ops) + + @classmethod + def for_float16(cls) -> 'QuantizationConfig': + """Creates configuration for float16 quantization.""" + return QuantizationConfig(supported_types=[tf.float16]) + + def set_converter_with_quantization(self, converter: tf.lite.TFLiteConverter, + **kwargs: Any) -> tf.lite.TFLiteConverter: + """Sets input TFLite converter with quantization configurations. + + Args: + converter: input tf.lite.TFLiteConverter. + **kwargs: arguments used by ds.Dataset.gen_tf_dataset. + + Returns: + tf.lite.TFLiteConverter with quantization configurations. 
+ """ + converter.optimizations = self.optimizations + + if self.representative_data is not None: + tf_ds = self.representative_data.gen_tf_dataset( + batch_size=1, is_training=False, **kwargs) + converter.representative_dataset = tf.lite.RepresentativeDataset( + _get_representative_dataset_generator(tf_ds, self.quantization_steps)) + + if self.inference_input_type: + converter.inference_input_type = self.inference_input_type + if self.inference_output_type: + converter.inference_output_type = self.inference_output_type + if self.supported_ops: + converter.target_spec.supported_ops = self.supported_ops + if self.supported_types: + converter.target_spec.supported_types = self.supported_types + + if self.experimental_new_quantizer is not None: + converter.experimental_new_quantizer = self.experimental_new_quantizer + return converter diff --git a/mediapipe/model_maker/python/core/utils/quantization_test.py b/mediapipe/model_maker/python/core/utils/quantization_test.py new file mode 100644 index 000000000..9d27d34ac --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/quantization_test.py @@ -0,0 +1,108 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.core.utils import test_util + + +class QuantizationTest(tf.test.TestCase, parameterized.TestCase): + + def test_create_dynamic_quantization_config(self): + config = quantization.QuantizationConfig.for_dynamic() + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertIsNone(config.representative_data) + self.assertIsNone(config.inference_input_type) + self.assertIsNone(config.inference_output_type) + self.assertIsNone(config.supported_ops) + self.assertIsNone(config.supported_types) + self.assertFalse(config.experimental_new_quantizer) + + def test_create_int8_quantization_config(self): + representative_data = test_util.create_dataset( + data_size=10, input_shape=[4], num_classes=3) + config = quantization.QuantizationConfig.for_int8( + representative_data=representative_data) + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertEqual(config.inference_input_type, tf.uint8) + self.assertEqual(config.inference_output_type, tf.uint8) + self.assertEqual(config.supported_ops, + [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]) + self.assertFalse(config.experimental_new_quantizer) + + def test_set_converter_with_quantization_from_int8_config(self): + representative_data = test_util.create_dataset( + data_size=10, input_shape=[4], num_classes=3) + config = quantization.QuantizationConfig.for_int8( + representative_data=representative_data) + model = test_util.build_model(input_shape=[4], num_classes=3) + saved_model_dir = self.get_temp_dir() + model.save(saved_model_dir) + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + converter = 
config.set_converter_with_quantization(converter=converter) + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertEqual(config.inference_input_type, tf.uint8) + self.assertEqual(config.inference_output_type, tf.uint8) + self.assertEqual(config.supported_ops, + [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]) + tflite_model = converter.convert() + interpreter = tf.lite.Interpreter(model_content=tflite_model) + self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8) + self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8) + + def test_create_float16_quantization_config(self): + config = quantization.QuantizationConfig.for_float16() + self.assertEqual(config.optimizations, [tf.lite.Optimize.DEFAULT]) + self.assertIsNone(config.representative_data) + self.assertIsNone(config.inference_input_type) + self.assertIsNone(config.inference_output_type) + self.assertIsNone(config.supported_ops) + self.assertEqual(config.supported_types, [tf.float16]) + self.assertFalse(config.experimental_new_quantizer) + + def test_set_converter_with_quantization_from_float16_config(self): + config = quantization.QuantizationConfig.for_float16() + model = test_util.build_model(input_shape=[4], num_classes=3) + saved_model_dir = self.get_temp_dir() + model.save(saved_model_dir) + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) + converter = config.set_converter_with_quantization(converter=converter) + self.assertEqual(config.supported_types, [tf.float16]) + tflite_model = converter.convert() + interpreter = tf.lite.Interpreter(model_content=tflite_model) + # The input and output are expected to be set to float32 by default. + self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32) + self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32) + + @parameterized.named_parameters( + dict( + testcase_name='invalid_inference_input_type', + inference_input_type=tf.uint8, + inference_output_type=tf.int64), + dict( + testcase_name='invalid_inference_output_type', + inference_input_type=tf.int64, + inference_output_type=tf.float32)) + def test_create_quantization_config_failure(self, inference_input_type, + inference_output_type): + with self.assertRaises(ValueError): + _ = quantization.QuantizationConfig( + inference_input_type=inference_input_type, + inference_output_type=inference_output_type) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/core/utils/test_util.py b/mediapipe/model_maker/python/core/utils/test_util.py new file mode 100644 index 000000000..b402d3793 --- /dev/null +++ b/mediapipe/model_maker/python/core/utils/test_util.py @@ -0,0 +1,123 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
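Putting `QuantizationConfig` together with `model_util.export_tflite` above looks roughly like the following. The model and representative data come from the `test_util` helpers defined next, and the output path is illustrative.

```
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util

model = test_util.build_model(input_shape=[16], num_classes=3)

# Full-integer quantization needs a small representative dataset for
# calibration; dynamic-range and float16 configs need no data at all.
representative_data = test_util.create_dataset(
    data_size=10, input_shape=[16], num_classes=3)
int8_config = quantization.QuantizationConfig.for_int8(
    representative_data=representative_data)

model_util.export_tflite(
    model=model,
    tflite_filepath='/tmp/model_int8.tflite',
    quantization_config=int8_config)

# The same call with QuantizationConfig.for_dynamic() or .for_float16()
# produces dynamic-range or float16 quantized models instead.
```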
+"""Test utilities for model maker.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import List, Union + +# Dependency imports + +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import dataset as ds +from mediapipe.model_maker.python.core.utils import model_util + + +def create_dataset(data_size: int, + input_shape: List[int], + num_classes: int, + max_input_value: int = 1000) -> ds.Dataset: + """Creates and returns a simple `Dataset` object for test.""" + features = tf.random.uniform( + shape=[data_size] + input_shape, + minval=0, + maxval=max_input_value, + dtype=tf.float32) + + labels = tf.random.uniform( + shape=[data_size], minval=0, maxval=num_classes, dtype=tf.int32) + + tf_dataset = tf.data.Dataset.from_tensor_slices((features, labels)) + dataset = ds.Dataset(tf_dataset, data_size) + return dataset + + +def create_random_sample(size: Union[int, List[int]], + low: float = 0, + high: float = 1) -> np.ndarray: + """Creates and returns a random sample with floating point values. + + Args: + size: Size of the output multi-dimensional array. + low: Lower boundary of the output values. + high: Higher boundary of the output values. + + Returns: + 1D array if the size is scalar. Otherwise, N-D array whose dimension equals + input size. + """ + np.random.seed(0) + return np.random.uniform(low=low, high=high, size=size).astype(np.float32) + + +def build_model(input_shape: List[int], num_classes: int) -> tf.keras.Model: + """Builds a simple Keras model for test.""" + inputs = tf.keras.layers.Input(shape=input_shape) + if len(input_shape) == 3: # Image inputs. + outputs = tf.keras.layers.GlobalAveragePooling2D()(inputs) + outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(outputs) + elif len(input_shape) == 1: # Text inputs. + outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(inputs) + else: + raise ValueError("Model inputs should be 2D tensor or 4D tensor.") + + model = tf.keras.Model(inputs=inputs, outputs=outputs) + return model + + +def is_same_output(tflite_file: str, + keras_model: tf.keras.Model, + input_tensors: Union[List[tf.Tensor], tf.Tensor], + atol: float = 1e-04) -> bool: + """Returns if the output of TFLite model and keras model are identical.""" + # Gets output from lite model. + lite_runner = model_util.get_lite_runner(tflite_file) + lite_output = lite_runner.run(input_tensors) + + # Gets output from keras model. + keras_output = keras_model.predict_on_batch(input_tensors) + + return np.allclose(lite_output, keras_output, atol=atol) + + +def test_tflite(keras_model: tf.keras.Model, + tflite_file: str, + size: Union[int, List[int]], + high: float = 1, + atol: float = 1e-04) -> bool: + """Verifies if the output of TFLite model and TF Keras model are identical. + + Args: + keras_model: Input TensorFlow Keras model. + tflite_file: Input TFLite model file. + size: Size of the input tesnor. + high: Higher boundary of the values in input tensors. + atol: Absolute tolerance of the difference between the outputs of Keras + model and TFLite model. + + Returns: + True if the output of TFLite model and TF Keras model are identical. + Otherwise, False. 
+ """ + random_input = create_random_sample(size=size, high=high) + random_input = tf.convert_to_tensor(random_input) + + return is_same_output( + tflite_file=tflite_file, + keras_model=keras_model, + input_tensors=random_input, + atol=atol) diff --git a/mediapipe/model_maker/python/vision/BUILD b/mediapipe/model_maker/python/vision/BUILD new file mode 100644 index 000000000..10aef8c33 --- /dev/null +++ b/mediapipe/model_maker/python/vision/BUILD @@ -0,0 +1,19 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +licenses(["notice"]) diff --git a/mediapipe/model_maker/python/vision/__init__.py b/mediapipe/model_maker/python/vision/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/vision/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/vision/core/BUILD b/mediapipe/model_maker/python/vision/core/BUILD new file mode 100644 index 000000000..0b15a0276 --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/BUILD @@ -0,0 +1,33 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python strict test compatibility macro. 
+ +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "image_preprocessing", + srcs = ["image_preprocessing.py"], +) + +py_test( + name = "image_preprocessing_test", + srcs = ["image_preprocessing_test.py"], + deps = [":image_preprocessing"], +) diff --git a/mediapipe/model_maker/python/vision/core/__init__.py b/mediapipe/model_maker/python/vision/core/__init__.py new file mode 100644 index 000000000..7ca2f9216 --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/mediapipe/model_maker/python/vision/core/image_preprocessing.py b/mediapipe/model_maker/python/vision/core/image_preprocessing.py new file mode 100644 index 000000000..104ccd9ca --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/image_preprocessing.py @@ -0,0 +1,224 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ImageNet preprocessing.""" + +import tensorflow as tf + +IMAGE_SIZE = 224 +CROP_PADDING = 32 + + +class Preprocessor(object): + """Preprocessor for image classification.""" + + def __init__(self, + input_shape, + num_classes, + mean_rgb, + stddev_rgb, + use_augmentation=False): + self.input_shape = input_shape + self.num_classes = num_classes + self.mean_rgb = mean_rgb + self.stddev_rgb = stddev_rgb + self.use_augmentation = use_augmentation + + def __call__(self, image, label, is_training=True): + if self.use_augmentation: + return self._preprocess_with_augmentation(image, label, is_training) + return self._preprocess_without_augmentation(image, label) + + def _preprocess_with_augmentation(self, image, label, is_training): + """Image preprocessing method with data augmentation.""" + image_size = self.input_shape[0] + if is_training: + image = preprocess_for_train(image, image_size) + else: + image = preprocess_for_eval(image, image_size) + + image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) + + label = tf.one_hot(label, depth=self.num_classes) + return image, label + + # TODO: Changes to preprocess to support batch input. 
+ def _preprocess_without_augmentation(self, image, label): + """Image preprocessing method without data augmentation.""" + image = tf.cast(image, tf.float32) + + image -= tf.constant(self.mean_rgb, shape=[1, 1, 3], dtype=image.dtype) + image /= tf.constant(self.stddev_rgb, shape=[1, 1, 3], dtype=image.dtype) + + image = tf.compat.v1.image.resize(image, self.input_shape) + label = tf.one_hot(label, depth=self.num_classes) + return image, label + + +def _distorted_bounding_box_crop(image, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=100): + """Generates cropped_image using one of the bboxes randomly distorted. + + See `tf.image.sample_distorted_bounding_box` for more documentation. + + Args: + image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of + shape [height, width, channels]. + bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` where + each coordinate is [0, 1) and the coordinates are arranged as `[ymin, + xmin, ymax, xmax]`. If num_boxes is 0 then use the whole image. + min_object_covered: An optional `float`. Defaults to `0.1`. The cropped area + of the image must contain at least this fraction of any bounding box + supplied. + aspect_ratio_range: An optional list of `float`s. The cropped area of the + image must have an aspect ratio = width / height within this range. + area_range: An optional list of `float`s. The cropped area of the image must + contain a fraction of the supplied image within in this range. + max_attempts: An optional `int`. Number of attempts at generating a cropped + region of the image of the specified constraints. After `max_attempts` + failures, return the entire image. + + Returns: + A cropped image `Tensor` + """ + with tf.name_scope('distorted_bounding_box_crop'): + shape = tf.shape(image) + sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( + shape, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True) + bbox_begin, bbox_size, _ = sample_distorted_bounding_box + + # Crop the image to the specified bounding box. + offset_y, offset_x, _ = tf.unstack(bbox_begin) + target_height, target_width, _ = tf.unstack(bbox_size) + image = tf.image.crop_to_bounding_box(image, offset_y, offset_x, + target_height, target_width) + + return image + + +def _at_least_x_are_equal(a, b, x): + """At least `x` of `a` and `b` `Tensors` are equal.""" + match = tf.equal(a, b) + match = tf.cast(match, tf.int32) + return tf.greater_equal(tf.reduce_sum(match), x) + + +def _resize_image(image, image_size, method=None): + if method is not None: + tf.compat.v1.logging.info('Use customized resize method {}'.format(method)) + return tf.compat.v1.image.resize([image], [image_size, image_size], + method)[0] + tf.compat.v1.logging.info('Use default resize_bicubic.') + return tf.compat.v1.image.resize_bicubic([image], [image_size, image_size])[0] + + +def _decode_and_random_crop(original_image, image_size, resize_method=None): + """Makes a random crop of image_size.""" + bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) + image = _distorted_bounding_box_crop( + original_image, + bbox, + min_object_covered=0.1, + aspect_ratio_range=(3. / 4, 4. 
/ 3.), + area_range=(0.08, 1.0), + max_attempts=10) + original_shape = tf.shape(original_image) + bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) + + image = tf.cond(bad, + lambda: _decode_and_center_crop(original_image, image_size), + lambda: _resize_image(image, image_size, resize_method)) + + return image + + +def _decode_and_center_crop(image, image_size, resize_method=None): + """Crops to center of image with padding then scales image_size.""" + shape = tf.shape(image) + image_height = shape[0] + image_width = shape[1] + + padded_center_crop_size = tf.cast( + ((image_size / (image_size + CROP_PADDING)) * + tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) + + offset_height = ((image_height - padded_center_crop_size) + 1) // 2 + offset_width = ((image_width - padded_center_crop_size) + 1) // 2 + image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, + padded_center_crop_size, + padded_center_crop_size) + image = _resize_image(image, image_size, resize_method) + return image + + +def _flip(image): + """Random horizontal image flip.""" + image = tf.image.random_flip_left_right(image) + return image + + +def preprocess_for_train( + image: tf.Tensor, + image_size: int = IMAGE_SIZE, + resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor: + """Preprocesses the given image for evaluation. + + Args: + image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of + shape [height, width, channels]. + image_size: image size. + resize_method: resize method. If none, use bicubic. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_random_crop(image, image_size, resize_method) + image = _flip(image) + image = tf.reshape(image, [image_size, image_size, 3]) + + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + + return image + + +def preprocess_for_eval( + image: tf.Tensor, + image_size: int = IMAGE_SIZE, + resize_method: str = tf.image.ResizeMethod.BILINEAR) -> tf.Tensor: + """Preprocesses the given image for evaluation. + + Args: + image: 4-D Tensor of shape [batch, height, width, channels] or 3-D Tensor of + shape [height, width, channels]. + image_size: image size. + resize_method: if None, use bicubic. + + Returns: + A preprocessed image `Tensor`. + """ + image = _decode_and_center_crop(image, image_size, resize_method) + image = tf.reshape(image, [image_size, image_size, 3]) + image = tf.image.convert_image_dtype(image, dtype=tf.float32) + return image diff --git a/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py b/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py new file mode 100644 index 000000000..0594b4376 --- /dev/null +++ b/mediapipe/model_maker/python/vision/core/image_preprocessing_test.py @@ -0,0 +1,80 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
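A minimal eager-mode sketch of the `Preprocessor` defined in `image_preprocessing.py` above. The random input tensor and the single-value `mean_rgb`/`stddev_rgb` follow the pattern used by the test below and are illustrative.

```
import tensorflow as tf

from mediapipe.model_maker.python.vision.core import image_preprocessing

preprocessor = image_preprocessing.Preprocessor(
    input_shape=[224, 224],
    num_classes=2,
    mean_rgb=[0.0],
    stddev_rgb=[255.0],
    use_augmentation=True)

image = tf.random.uniform([256, 256, 3], maxval=255, dtype=tf.float32)

# Training path: random crop plus horizontal flip, then normalization.
train_image, train_label = preprocessor(image, label=0, is_training=True)

# Evaluation path: padded center crop, then normalization.
eval_image, eval_label = preprocessor(image, label=1, is_training=False)
```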
+ +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.vision.core import image_preprocessing + + +def _get_preprocessed_image(preprocessor, is_training=False): + image_placeholder = tf.compat.v1.placeholder(tf.uint8, [24, 24, 3]) + label_placeholder = tf.compat.v1.placeholder(tf.int32, [1]) + image_tensor, _ = preprocessor(image_placeholder, label_placeholder, + is_training) + + with tf.compat.v1.Session() as sess: + input_image = np.arange(24 * 24 * 3, dtype=np.uint8).reshape([24, 24, 3]) + image = sess.run( + image_tensor, + feed_dict={ + image_placeholder: input_image, + label_placeholder: [0] + }) + return image + + +class PreprocessorTest(tf.test.TestCase): + + def test_preprocess_without_augmentation(self): + preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2], + num_classes=2, + mean_rgb=[0.0], + stddev_rgb=[255.0], + use_augmentation=False) + actual_image = np.array([[[0., 0.00392157, 0.00784314], + [0.14117648, 0.14509805, 0.14901961]], + [[0.37647063, 0.3803922, 0.38431376], + [0.5176471, 0.52156866, 0.5254902]]]) + + image = _get_preprocessed_image(preprocessor) + self.assertTrue(np.allclose(image, actual_image, atol=1e-05)) + + def test_preprocess_with_augmentation(self): + image_preprocessing.CROP_PADDING = 1 + preprocessor = image_preprocessing.Preprocessor(input_shape=[2, 2], + num_classes=2, + mean_rgb=[0.0], + stddev_rgb=[255.0], + use_augmentation=True) + # Tests validation image. + actual_eval_image = np.array([[[0.17254902, 0.1764706, 0.18039216], + [0.26666668, 0.27058825, 0.27450982]], + [[0.42352945, 0.427451, 0.43137258], + [0.5176471, 0.52156866, 0.5254902]]]) + + image = _get_preprocessed_image(preprocessor, is_training=False) + self.assertTrue(np.allclose(image, actual_eval_image, atol=1e-05)) + + # Tests training image. + image1 = _get_preprocessed_image(preprocessor, is_training=True) + image2 = _get_preprocessed_image(preprocessor, is_training=True) + self.assertFalse(np.allclose(image1, image2, atol=1e-05)) + self.assertEqual(image1.shape, (2, 2, 3)) + self.assertEqual(image2.shape, (2, 2, 3)) + + +if __name__ == '__main__': + tf.compat.v1.disable_eager_execution() + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/BUILD b/mediapipe/model_maker/python/vision/image_classifier/BUILD new file mode 100644 index 000000000..a2268059f --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/BUILD @@ -0,0 +1,111 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict library and test compatibility macro. +# Placeholder for internal Python library rule. 
+ +licenses(["notice"]) + +package( + default_visibility = ["//mediapipe:__subpackages__"], +) + +py_library( + name = "image_classifier_import", + srcs = ["__init__.py"], + deps = [ + ":dataset", + ":hyperparameters", + ":image_classifier", + ":model_spec", + ], +) + +py_library( + name = "model_spec", + srcs = ["model_spec.py"], +) + +py_test( + name = "model_spec_test", + srcs = ["model_spec_test.py"], + deps = [":model_spec"], +) + +py_library( + name = "dataset", + srcs = ["dataset.py"], + deps = ["//mediapipe/model_maker/python/core/data:classification_dataset"], +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + deps = [":dataset"], +) + +py_library( + name = "hyperparameters", + srcs = ["hyperparameters.py"], +) + +py_library( + name = "train_image_classifier_lib", + srcs = ["train_image_classifier_lib.py"], + deps = [ + ":hyperparameters", + "//mediapipe/model_maker/python/core/utils:model_util", + ], +) + +py_library( + name = "image_classifier", + srcs = ["image_classifier.py"], + deps = [ + ":hyperparameters", + ":model_spec", + ":train_image_classifier_lib", + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/core/tasks:classifier", + "//mediapipe/model_maker/python/core/utils:model_util", + "//mediapipe/model_maker/python/core/utils:quantization", + "//mediapipe/model_maker/python/vision/core:image_preprocessing", + ], +) + +py_library( + name = "image_classifier_test_lib", + testonly = 1, + srcs = ["image_classifier_test.py"], + deps = [":image_classifier_import"], +) + +py_test( + name = "image_classifier_test", + srcs = ["image_classifier_test.py"], + shard_count = 2, + tags = ["requires-net:external"], + deps = [ + ":image_classifier_test_lib", + ], +) + +py_binary( + name = "image_classifier_demo", + srcs = ["image_classifier_demo.py"], + deps = [ + ":image_classifier_import", + "//mediapipe/model_maker/python/core/utils:quantization", + ], +) diff --git a/mediapipe/model_maker/python/vision/image_classifier/__init__.py b/mediapipe/model_maker/python/vision/image_classifier/__init__.py new file mode 100644 index 000000000..3ba6b0764 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/__init__.py @@ -0,0 +1,25 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe Model Maker Python Public API For Image Classifier.""" + +from mediapipe.model_maker.python.vision.image_classifier import dataset +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters +from mediapipe.model_maker.python.vision.image_classifier import image_classifier +from mediapipe.model_maker.python.vision.image_classifier import model_spec + +ImageClassifier = image_classifier.ImageClassifier +HParams = hyperparameters.HParams +Dataset = dataset.Dataset +ModelSpec = model_spec.ModelSpec +SupportedModels = model_spec.SupportedModels diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset.py b/mediapipe/model_maker/python/vision/image_classifier/dataset.py new file mode 100644 index 000000000..4ae8dcfdd --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset.py @@ -0,0 +1,109 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image classifier dataset library.""" + +import os +import random + +from typing import List, Optional +import tensorflow as tf +import tensorflow_datasets as tfds + +from mediapipe.model_maker.python.core.data import classification_dataset + + +def _load_image(path: str) -> tf.Tensor: + """Loads a jpeg/png image and returns an image tensor.""" + image_raw = tf.io.read_file(path) + image_tensor = tf.cond( + tf.io.is_jpeg(image_raw), + lambda: tf.io.decode_jpeg(image_raw, channels=3), + lambda: tf.io.decode_png(image_raw, channels=3)) + return image_tensor + + +def _create_data( + name: str, data: tf.data.Dataset, info: tfds.core.DatasetInfo, + label_names: List[str] +) -> Optional[classification_dataset.ClassificationDataset]: + """Creates a Dataset object from tfds data.""" + if name not in data: + return None + data = data[name] + data = data.map(lambda a: (a['image'], a['label'])) + size = info.splits[name].num_examples + return Dataset(data, size, label_names) + + +class Dataset(classification_dataset.ClassificationDataset): + """Dataset library for image classifier.""" + + @classmethod + def from_folder( + cls, + dirname: str, + shuffle: bool = True) -> classification_dataset.ClassificationDataset: + """Loads images and labels from the given directory. + + Assume the image data of the same label are in the same subdirectory. + + Args: + dirname: Name of the directory containing the data files. + shuffle: boolean, if true, random shuffle data. + + Returns: + Dataset containing images and labels and other related info. + Raises: + ValueError: if the input data directory is empty. + """ + data_root = os.path.abspath(dirname) + + # Assumes the image data of the same label are in the same subdirectory, + # gets image path and label names. + all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*')) + all_image_size = len(all_image_paths) + if all_image_size == 0: + raise ValueError('Image size is zero') + + if shuffle: + # Random shuffle data. 
+ random.shuffle(all_image_paths) + + label_names = sorted( + name for name in os.listdir(data_root) + if os.path.isdir(os.path.join(data_root, name))) + all_label_size = len(label_names) + index_by_label = dict( + (name, index) for index, name in enumerate(label_names)) + all_image_labels = [ + index_by_label[os.path.basename(os.path.dirname(path))] + for path in all_image_paths + ] + + path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) + + image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE) + + # Load label + label_ds = tf.data.Dataset.from_tensor_slices( + tf.cast(all_image_labels, tf.int64)) + + # Create a dataset if (image, label) pairs + image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) + + tf.compat.v1.logging.info( + 'Load image with size: %d, num_label: %d, labels: %s.', all_image_size, + all_label_size, ', '.join(label_names)) + return Dataset( + dataset=image_label_ds, size=all_image_size, index_by_label=label_names) diff --git a/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py new file mode 100644 index 000000000..6a0b696f9 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/dataset_test.py @@ -0,0 +1,108 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
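
To make the on-disk layout expected by `Dataset.from_folder` concrete, a hedged usage sketch (the directory path is a placeholder, not part of this patch):

from mediapipe.model_maker.python.vision.image_classifier import dataset

# Illustration only: from_folder expects one subdirectory per class label, e.g.
#   /tmp/flowers/daisy/001.jpg
#   /tmp/flowers/tulips/002.jpg
data = dataset.Dataset.from_folder('/tmp/flowers')
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
print(data.num_classes, data.index_by_label)
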
+ +import os +import random +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.vision.image_classifier import dataset + + +def _fill_image(rgb, image_size): + r, g, b = rgb + return np.broadcast_to( + np.array([[[r, g, b]]], dtype=np.uint8), + shape=(image_size, image_size, 3)) + + +def _write_filled_jpeg_file(path, rgb, image_size): + tf.keras.preprocessing.image.save_img(path, _fill_image(rgb, image_size), + 'channels_last', 'jpeg') + + +class DatasetTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + self.image_path = os.path.join(self.get_temp_dir(), 'random_image_dir') + if os.path.exists(self.image_path): + return + os.mkdir(self.image_path) + for class_name in ('daisy', 'tulips'): + class_subdir = os.path.join(self.image_path, class_name) + os.mkdir(class_subdir) + _write_filled_jpeg_file( + os.path.join(class_subdir, '0.jpeg'), + [random.uniform(0, 255) for _ in range(3)], 224) + + def test_split(self): + ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]]) + data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg']) + train_data, test_data = data.split(fraction=0.5) + + self.assertLen(train_data, 2) + for i, elem in enumerate(train_data._dataset): + self.assertTrue((elem.numpy() == np.array([i, 1])).all()) + self.assertEqual(train_data.num_classes, 2) + self.assertEqual(train_data.index_by_label, ['pos', 'neg']) + + self.assertLen(test_data, 2) + for i, elem in enumerate(test_data._dataset): + self.assertTrue((elem.numpy() == np.array([i, 0])).all()) + self.assertEqual(test_data.num_classes, 2) + self.assertEqual(test_data.index_by_label, ['pos', 'neg']) + + def test_from_folder(self): + data = dataset.Dataset.from_folder(dirname=self.image_path) + + self.assertLen(data, 2) + self.assertEqual(data.num_classes, 2) + self.assertEqual(data.index_by_label, ['daisy', 'tulips']) + for image, label in data.gen_tf_dataset(): + self.assertTrue(label.numpy() == 1 or label.numpy() == 0) + if label.numpy() == 0: + raw_image_tensor = dataset._load_image( + os.path.join(self.image_path, 'daisy', '0.jpeg')) + else: + raw_image_tensor = dataset._load_image( + os.path.join(self.image_path, 'tulips', '0.jpeg')) + self.assertTrue((image.numpy() == raw_image_tensor.numpy()).all()) + + def test_from_tfds(self): + # TODO: Remove this once tfds download error is fixed. 
+ self.skipTest('Temporarily skip the unittest due to tfds download error.') + train_data, validation_data, test_data = ( + dataset.Dataset.from_tfds('beans')) + self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset) + self.assertLen(train_data, 1034) + self.assertEqual(train_data.num_classes, 3) + self.assertEqual(train_data.index_by_label, + ['angular_leaf_spot', 'bean_rust', 'healthy']) + + self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset) + self.assertLen(validation_data, 133) + self.assertEqual(validation_data.num_classes, 3) + self.assertEqual(validation_data.index_by_label, + ['angular_leaf_spot', 'bean_rust', 'healthy']) + + self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset) + self.assertLen(test_data, 128) + self.assertEqual(test_data.num_classes, 3) + self.assertEqual(test_data.index_by_label, + ['angular_leaf_spot', 'bean_rust', 'healthy']) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py b/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py new file mode 100644 index 000000000..6df18579a --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/hyperparameters.py @@ -0,0 +1,74 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hyperparameters for training image classification models.""" + +import dataclasses +import tempfile +from typing import Optional + + +# TODO: Expose other hyperparameters, e.g. data augmentation +# hyperparameters if requested. +@dataclasses.dataclass +class HParams: + """The hyperparameters for training image classifiers. + + The hyperparameters include: + # Parameters about training data. + do_fine_tuning: If true, the base module is trained together with the + classification layer on top. + shuffle: A boolean controlling if shuffle the dataset. Default to false. + + # Parameters about training configuration + train_epochs: Training will do this many iterations over the dataset. + batch_size: Each training step samples a batch of this many images. + learning_rate: The learning rate to use for gradient descent training. + dropout_rate: The fraction of the input units to drop, used in dropout + layer. + l1_regularizer: A regularizer that applies a L1 regularization penalty. + l2_regularizer: A regularizer that applies a L2 regularization penalty. + label_smoothing: Amount of label smoothing to apply. See tf.keras.losses for + more details. + do_data_augmentation: A boolean controlling whether the training dataset is + augmented by randomly distorting input images, including random cropping, + flipping, etc. See utils.image_preprocessing documentation for details. + steps_per_epoch: An optional integer indicate the number of training steps + per epoch. If not set, the training pipeline calculates the default steps + per epoch as the training dataset size devided by batch size. 
+ decay_samples: Number of training samples used to calculate the decay steps + and create the training optimizer. + warmup_steps: Number of warmup steps for a linear increasing warmup schedule + on learning rate. Used to set up warmup schedule by model_util.WarmUp. + + # Parameters about the saved checkpoint + model_dir: The location of model checkpoint files and exported model files. + """ + # Parameters about training data + do_fine_tuning: bool = False + shuffle: bool = False + # Parameters about training configuration + train_epochs: int = 5 + batch_size: int = 32 + learning_rate: float = 0.005 + dropout_rate: float = 0.2 + l1_regularizer: float = 0.0 + l2_regularizer: float = 0.0001 + label_smoothing: float = 0.1 + do_data_augmentation: bool = True + steps_per_epoch: Optional[int] = None + decay_samples: int = 10000 * 256 + warmup_epochs: int = 2 + + # Parameters about the saved checkpoint + model_dir: str = tempfile.mkdtemp() diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py new file mode 100644 index 000000000..dd8929a71 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier.py @@ -0,0 +1,172 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""APIs to train image classifier model.""" + +from typing import Any, List, Optional + +import tensorflow as tf +import tensorflow_hub as hub + +from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds +from mediapipe.model_maker.python.core.tasks import classifier +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision.core import image_preprocessing +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp +from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms +from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib + + +class ImageClassifier(classifier.Classifier): + """ImageClassifier for building image classification model.""" + + def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any], + hparams: hp.HParams): + """Initializes ImageClassifier class. + + Args: + model_spec: Specification for the model. + index_by_label: A list that maps from index to label class name. + hparams: The hyperparameters for training image classifier. 
+ """ + super().__init__( + model_spec=model_spec, + index_by_label=index_by_label, + shuffle=hparams.shuffle, + full_train=hparams.do_fine_tuning) + self._hparams = hparams + self._preprocess = image_preprocessing.Preprocessor( + input_shape=self._model_spec.input_image_shape, + num_classes=self._num_classes, + mean_rgb=self._model_spec.mean_rgb, + stddev_rgb=self._model_spec.stddev_rgb, + use_augmentation=hparams.do_data_augmentation) + self._history = None # Training history returned from `keras_model.fit`. + + @classmethod + def create( + cls, + model_spec: ms.SupportedModels, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + hparams: Optional[hp.HParams] = None, + ) -> 'ImageClassifier': + """Creates and trains an image classifier. + + Loads data and trains the model based on data for image classification. + + Args: + model_spec: Specification for the model. + train_data: Training data. + validation_data: Validation data. + hparams: Hyperparameters for training image classifier. + + Returns: + An instance based on ImageClassifier. + """ + if hparams is None: + hparams = hp.HParams() + + spec = ms.SupportedModels.get(model_spec) + image_classifier = cls( + model_spec=spec, + index_by_label=train_data.index_by_label, + hparams=hparams) + + image_classifier._create_model() + + tf.compat.v1.logging.info('Training the models...') + image_classifier._train( + train_data=train_data, validation_data=validation_data) + + return image_classifier + + def _train(self, train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset): + """Trains the model with input train_data. + + The training results are recorded by a self._history object returned by + tf.keras.Model.fit(). + + Args: + train_data: Training data. + validation_data: Validation data. + """ + + tf.compat.v1.logging.info('Training the models...') + hparams = self._hparams + if len(train_data) < hparams.batch_size: + raise ValueError('The size of the train_data (%d) couldn\'t be smaller ' + 'than batch_size (%d). To solve this problem, set ' + 'the batch_size smaller or increase the size of the ' + 'train_data.' % (len(train_data), hparams.batch_size)) + + train_dataset = train_data.gen_tf_dataset( + batch_size=hparams.batch_size, + is_training=True, + shuffle=self._shuffle, + preprocess=self._preprocess) + hparams.steps_per_epoch = model_util.get_steps_per_epoch( + steps_per_epoch=hparams.steps_per_epoch, + batch_size=hparams.batch_size, + train_data=train_data) + train_dataset = train_dataset.take(count=hparams.steps_per_epoch) + + validation_dataset = validation_data.gen_tf_dataset( + batch_size=hparams.batch_size, + is_training=False, + preprocess=self._preprocess) + + # Train the model. 
+ self._history = train_image_classifier_lib.train_model( + model=self._model, + hparams=hparams, + train_ds=train_dataset, + validation_ds=validation_dataset) + + def _create_model(self): + """Creates the classifier model from TFHub pretrained models.""" + module_layer = hub.KerasLayer( + handle=self._model_spec.uri, trainable=self._hparams.do_fine_tuning) + + image_size = self._model_spec.input_image_shape + + self._model = tf.keras.Sequential([ + tf.keras.Input(shape=(image_size[0], image_size[1], 3)), module_layer, + tf.keras.layers.Dropout(rate=self._hparams.dropout_rate), + tf.keras.layers.Dense( + units=self._num_classes, + activation='softmax', + kernel_regularizer=tf.keras.regularizers.l1_l2( + l1=self._hparams.l1_regularizer, + l2=self._hparams.l2_regularizer)) + ]) + print(self._model.summary()) + + def export_model( + self, + model_name: str = 'model.tflite', + quantization_config: Optional[quantization.QuantizationConfig] = None): + """Converts the model to the requested formats and exports to a file. + + Args: + model_name: File name to save tflite model. The full export path is + {export_dir}/{tflite_filename}. + quantization_config: The configuration for model quantization. + """ + super().export_tflite( + self._hparams.model_dir, + model_name, + quantization_config, + preprocess=self._preprocess) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py new file mode 100644 index 000000000..5832ea53a --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_demo.py @@ -0,0 +1,106 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Demo for making an image classifier model by MediaPipe Model Maker.""" + +import os + +# Dependency imports + +from absl import app +from absl import flags +from absl import logging +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import quantization +from mediapipe.model_maker.python.vision import image_classifier + +FLAGS = flags.FLAGS + + +def define_flags() -> None: + """Define flags for the image classifier model maker demo.""" + flags.DEFINE_string('export_dir', None, + 'The directory to save exported files.') + flags.DEFINE_string( + 'input_data_dir', None, + """The directory with input training data. 
If the training data is not + specified, the pipeline will download a default training dataset.""") + flags.DEFINE_enum_class('spec', + image_classifier.SupportedModels.EFFICIENTNET_LITE0, + image_classifier.SupportedModels, + 'The image classifier to run.') + flags.DEFINE_enum('quantization', None, ['dynamic', 'int8', 'float16'], + 'The quantization method to use when exporting the model.') + flags.mark_flag_as_required('export_dir') + + +def download_demo_data() -> str: + """Downloads demo data, and returns directory path.""" + data_dir = tf.keras.utils.get_file( + fname='flower_photos.tgz', + origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', + extract=True) + return os.path.join(os.path.dirname(data_dir), 'flower_photos') # folder name + + +def run(data_dir: str, export_dir: str, + model_spec: image_classifier.SupportedModels, + quantization_option: str) -> None: + """Runs demo.""" + data = image_classifier.Dataset.from_folder(data_dir) + train_data, rest_data = data.split(0.8) + validation_data, test_data = rest_data.split(0.5) + + model = image_classifier.ImageClassifier.create( + model_spec=model_spec, + train_data=train_data, + validation_data=validation_data, + hparams=image_classifier.HParams(model_dir=export_dir)) + + _, acc = model.evaluate(test_data) + print('Test accuracy: %f' % acc) + + if quantization_option is None: + quantization_config = None + elif quantization_option == 'dynamic': + quantization_config = quantization.QuantizationConfig.for_dynamic() + elif quantization_option == 'int8': + quantization_config = quantization.QuantizationConfig.for_int8(train_data) + elif quantization_option == 'float16': + quantization_config = quantization.QuantizationConfig.for_float16() + else: + raise ValueError(f'Quantization: {quantization} is not recognized') + + model.export_model(quantization_config=quantization_config) + model.export_labels(export_dir) + + +def main(_) -> None: + logging.set_verbosity(logging.INFO) + + if FLAGS.input_data_dir is None: + data_dir = download_demo_data() + else: + data_dir = FLAGS.input_data_dir + + export_dir = os.path.expanduser(FLAGS.export_dir) + run(data_dir=data_dir, + export_dir=export_dir, + model_spec=FLAGS.spec, + quantization_option=FLAGS.quantization) + + +if __name__ == '__main__': + define_flags() + app.run(main) diff --git a/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py new file mode 100644 index 000000000..8ed6de7ad --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/image_classifier_test.py @@ -0,0 +1,107 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from mediapipe.model_maker.python.vision import image_classifier + + +def _fill_image(rgb, image_size): + r, g, b = rgb + return np.broadcast_to( + np.array([[[r, g, b]]], dtype=np.uint8), + shape=(image_size, image_size, 3)) + + +class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase): + IMAGE_SIZE = 24 + IMAGES_PER_CLASS = 2 + CMY_NAMES_AND_RGB_VALUES = (('cyan', (0, 255, 255)), + ('magenta', (255, 0, 255)), ('yellow', (255, 255, + 0))) + + def _gen(self): + for i, (_, rgb) in enumerate(self.CMY_NAMES_AND_RGB_VALUES): + for _ in range(self.IMAGES_PER_CLASS): + yield (_fill_image(rgb, self.IMAGE_SIZE), i) + + def _gen_cmy_data(self): + ds = tf.data.Dataset.from_generator( + self._gen, (tf.uint8, tf.int64), (tf.TensorShape( + [self.IMAGE_SIZE, self.IMAGE_SIZE, 3]), tf.TensorShape([]))) + data = image_classifier.Dataset(ds, self.IMAGES_PER_CLASS * 3, + ['cyan', 'magenta', 'yellow']) + return data + + def setUp(self): + super(ImageClassifierTest, self).setUp() + all_data = self._gen_cmy_data() + # Splits data, 90% data for training, 10% for testing + self.train_data, self.test_data = all_data.split(0.9) + + @parameterized.named_parameters( + dict( + testcase_name='mobilenet_v2', + model_spec=image_classifier.SupportedModels.MOBILENET_V2, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite0', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite2', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + dict( + testcase_name='efficientnet_lite4', + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4, + hparams=image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True)), + ) + def test_create_and_train_model(self, + model_spec: image_classifier.SupportedModels, + hparams: image_classifier.HParams): + model = image_classifier.ImageClassifier.create( + model_spec=model_spec, + train_data=self.train_data, + hparams=hparams, + validation_data=self.test_data) + self._test_accuracy(model) + + def test_efficientnetlite0_model_with_model_maker_retraining_lib(self): + hparams = image_classifier.HParams( + train_epochs=1, batch_size=1, shuffle=True) + model = image_classifier.ImageClassifier.create( + model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0, + train_data=self.train_data, + hparams=hparams, + validation_data=self.test_data) + self._test_accuracy(model) + + def _test_accuracy(self, model, threshold=0.0): + _, accuracy = model.evaluate(self.test_data) + self.assertGreaterEqual(accuracy, threshold) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py new file mode 100644 index 000000000..ef44f86e6 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py @@ -0,0 +1,84 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image classifier model specification.""" + +import enum +import functools +from typing import List, Optional + + +class ModelSpec(object): + """Specification of image classifier model.""" + + mean_rgb = [0.0] + stddev_rgb = [255.0] + + def __init__(self, + uri: str, + input_image_shape: Optional[List[int]] = None, + name: str = ''): + """Initializes a new instance of the `ImageModelSpec` class. + + Args: + uri: str, URI to the pretrained model. + input_image_shape: list of int, input image shape. Default: [224, 224]. + name: str, model spec name. + """ + self.uri = uri + self.name = name + + if input_image_shape is None: + input_image_shape = [224, 224] + self.input_image_shape = input_image_shape + + +mobilenet_v2_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', + name='mobilenet_v2') + +efficientnet_lite0_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', + name='efficientnet_lite0') + +efficientnet_lite2_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', + input_image_shape=[260, 260], + name='efficientnet_lite2') + +efficientnet_lite4_spec = functools.partial( + ModelSpec, + uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', + input_image_shape=[300, 300], + name='efficientnet_lite4') + + +# TODO: Document the exposed models. +@enum.unique +class SupportedModels(enum.Enum): + """Image classifier model supported by model maker.""" + MOBILENET_V2 = mobilenet_v2_spec + EFFICIENTNET_LITE0 = efficientnet_lite0_spec + EFFICIENTNET_LITE2 = efficientnet_lite2_spec + EFFICIENTNET_LITE4 = efficientnet_lite4_spec + + @classmethod + def get(cls, spec: 'SupportedModels') -> 'ModelSpec': + """Gets model spec from the input enum and initializes it.""" + if spec not in cls: + raise TypeError('Unsupported image classifier spec: {}'.format(spec)) + + return spec.value() diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py new file mode 100644 index 000000000..63f360ab9 --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec_test.py @@ -0,0 +1,75 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
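
Before the spec tests below, a short sketch of how a `SupportedModels` entry resolves to a concrete `ModelSpec` (expected values taken from the definitions above; illustration only, not part of this patch):

from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms

# SupportedModels stores functools.partial factories; get() instantiates one.
spec = ms.SupportedModels.get(ms.SupportedModels.EFFICIENTNET_LITE2)
assert isinstance(spec, ms.ModelSpec)
assert spec.name == 'efficientnet_lite2'
assert spec.input_image_shape == [260, 260]  # lite2 overrides the 224x224 default
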
+ +import os + +from typing import Callable, List +from absl.testing import parameterized +import tensorflow as tf + +from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms + + +class ModelSpecTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='mobilenet_v2_spec_test', + model_spec=ms.mobilenet_v2_spec, + expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4', + expected_name='mobilenet_v2', + expected_input_image_shape=[224, 224]), + dict( + testcase_name='efficientnet_lite0_spec_test', + model_spec=ms.efficientnet_lite0_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2', + expected_name='efficientnet_lite0', + expected_input_image_shape=[224, 224]), + dict( + testcase_name='efficientnet_lite2_spec_test', + model_spec=ms.efficientnet_lite2_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2', + expected_name='efficientnet_lite2', + expected_input_image_shape=[260, 260]), + dict( + testcase_name='efficientnet_lite4_spec_test', + model_spec=ms.efficientnet_lite4_spec, + expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2', + expected_name='efficientnet_lite4', + expected_input_image_shape=[300, 300]), + ) + def test_predefiend_spec(self, model_spec: Callable[..., ms.ModelSpec], + expected_uri: str, expected_name: str, + expected_input_image_shape: List[int]): + model_spec_obj = model_spec() + self.assertIsInstance(model_spec_obj, ms.ModelSpec) + self.assertEqual(model_spec_obj.uri, expected_uri) + self.assertEqual(model_spec_obj.name, expected_name) + self.assertEqual(model_spec_obj.input_image_shape, + expected_input_image_shape) + + def test_create_spec(self): + custom_model_spec = ms.ModelSpec( + uri='https://custom_model', + input_image_shape=[128, 128], + name='custom_model') + self.assertEqual(custom_model_spec.uri, 'https://custom_model') + self.assertEqual(custom_model_spec.name, 'custom_model') + self.assertEqual(custom_model_spec.input_image_shape, [128, 128]) + + +if __name__ == '__main__': + # Load compressed models from tensorflow_hub + os.environ['TFHUB_MODEL_LOAD_FORMAT'] = 'COMPRESSED' + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py new file mode 100644 index 000000000..704d71a5a --- /dev/null +++ b/mediapipe/model_maker/python/vision/image_classifier/train_image_classifier_lib.py @@ -0,0 +1,103 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
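
Before the training library itself, a condensed end-to-end sketch of how the pieces introduced above compose, modeled on the demo's `run()` flow (the data and export paths are placeholders; illustration only, not part of this patch):

from mediapipe.model_maker.python.vision import image_classifier

data = image_classifier.Dataset.from_folder('/tmp/flowers')  # placeholder path
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

model = image_classifier.ImageClassifier.create(
    model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
    train_data=train_data,
    validation_data=validation_data,
    hparams=image_classifier.HParams(train_epochs=5, model_dir='/tmp/exported_model'))

_, acc = model.evaluate(test_data)
model.export_model()  # writes model.tflite under hparams.model_dir
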
+"""Library to train model.""" + +import os +from typing import List + +import tensorflow as tf + +from mediapipe.model_maker.python.core.utils import model_util +from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp + + +def _create_optimizer(init_lr: float, decay_steps: int, + warmup_steps: int) -> tf.keras.optimizers.Optimizer: + """Creates an optimizer with learning rate schedule. + + Uses Keras CosineDecay schedule for the learning rate by default. + + Args: + init_lr: Initial learning rate. + decay_steps: Number of steps to decay over. + warmup_steps: Number of steps to do warmup for. + + Returns: + A tf.keras.optimizers.Optimizer for model training. + """ + learning_rate_fn = tf.keras.experimental.CosineDecay( + initial_learning_rate=init_lr, decay_steps=decay_steps, alpha=0.0) + if warmup_steps: + learning_rate_fn = model_util.WarmUp( + initial_learning_rate=init_lr, + decay_schedule_fn=learning_rate_fn, + warmup_steps=warmup_steps) + optimizer = tf.keras.optimizers.RMSprop( + learning_rate=learning_rate_fn, rho=0.9, momentum=0.9, epsilon=0.001) + + return optimizer + + +def _get_default_callbacks(model_dir: str) -> List[tf.keras.callbacks.Callback]: + """Gets default callbacks.""" + summary_dir = os.path.join(model_dir, 'summaries') + summary_callback = tf.keras.callbacks.TensorBoard(summary_dir) + # Save checkpoint every 20 epochs. + + checkpoint_path = os.path.join(model_dir, 'checkpoint') + checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( + checkpoint_path, save_weights_only=True, period=20) + return [summary_callback, checkpoint_callback] + + +def train_model(model: tf.keras.Model, hparams: hp.HParams, + train_ds: tf.data.Dataset, + validation_ds: tf.data.Dataset) -> tf.keras.callbacks.History: + """Trains model with the input data and hyperparameters. + + Args: + model: Input tf.keras.Model. + hparams: Hyperparameters for training image classifier. + train_ds: tf.data.Dataset, training data to be fed in tf.keras.Model.fit(). + validation_ds: tf.data.Dataset, validation data to be fed in + tf.keras.Model.fit(). + + Returns: + The tf.keras.callbacks.History object returned by tf.keras.Model.fit(). + """ + + # Learning rate is linear to batch size. + learning_rate = hparams.learning_rate * hparams.batch_size / 256 + + # Get decay steps. + total_training_steps = hparams.steps_per_epoch * hparams.train_epochs + default_decay_steps = hparams.decay_samples // hparams.batch_size + decay_steps = max(total_training_steps, default_decay_steps) + + warmup_steps = hparams.warmup_epochs * hparams.steps_per_epoch + optimizer = _create_optimizer( + init_lr=learning_rate, decay_steps=decay_steps, warmup_steps=warmup_steps) + + loss = tf.keras.losses.CategoricalCrossentropy( + label_smoothing=hparams.label_smoothing) + model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy']) + callbacks = _get_default_callbacks(hparams.model_dir) + + # Train the model. 
+ return model.fit( + x=train_ds, + epochs=hparams.train_epochs, + steps_per_epoch=hparams.steps_per_epoch, + validation_data=validation_ds, + callbacks=callbacks) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt new file mode 100644 index 000000000..389ee484a --- /dev/null +++ b/mediapipe/model_maker/requirements.txt @@ -0,0 +1,6 @@ +absl-py +numpy +opencv-contrib-python +tensorflow +tensorflow-datasets +tensorflow-hub diff --git a/mediapipe/objc/MPPCameraInputSource.m b/mediapipe/objc/MPPCameraInputSource.m index b9718680c..73b3549c4 100644 --- a/mediapipe/objc/MPPCameraInputSource.m +++ b/mediapipe/objc/MPPCameraInputSource.m @@ -244,7 +244,7 @@ if ([_session canAddOutput:_depthDataOutput]) { [_session addOutput:_depthDataOutput]; - AVCaptureConnection* connection = + AVCaptureConnection* __unused connection = [_depthDataOutput connectionWithMediaType:AVMediaTypeDepthData]; // Set this when we have a handler. @@ -327,7 +327,6 @@ if (depthData.depthDataType != kCVPixelFormatType_DepthFloat32) { depthData = [depthData depthDataByConvertingToDepthDataType:kCVPixelFormatType_DepthFloat32]; } - CVPixelBufferRef depthBuffer = depthData.depthDataMap; [self.delegate processDepthData:depthData timestamp:timestamp fromSource:self]; } diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 67d71720e..080cca20f 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -134,12 +134,12 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, if (format == mediapipe::ImageFormat::SRGBA) { // Swap R and B channels. const uint8_t permuteMap[4] = {2, 1, 0, 3}; - vImage_Error vError = vImagePermuteChannels_ARGB8888( - &vSource, &vDestination, permuteMap, kvImageNoFlags); + vImage_Error __unused vError = + vImagePermuteChannels_ARGB8888(&vSource, &vDestination, permuteMap, kvImageNoFlags); _GTMDevAssert(vError == kvImageNoError, @"vImagePermuteChannels failed: %zd", vError); } else { // Convert grayscale back to BGRA - vImage_Error vError = vImageGrayToBGRA(&vSource, &vDestination); + vImage_Error __unused vError = vImageGrayToBGRA(&vSource, &vDestination); _GTMDevAssert(vError == kvImageNoError, @"vImageGrayToBGRA failed: %zd", vError); } diff --git a/mediapipe/objc/MPPGraphTestBase.mm b/mediapipe/objc/MPPGraphTestBase.mm index ddd15f736..eb4ea0535 100644 --- a/mediapipe/objc/MPPGraphTestBase.mm +++ b/mediapipe/objc/MPPGraphTestBase.mm @@ -32,10 +32,11 @@ static UIImage* UIImageWithPixelBuffer(CVPixelBufferRef pixelBuffer) { static void EnsureOutputDirFor(NSString *outputFile) { NSFileManager *fileManager = [NSFileManager defaultManager]; NSError *error = nil; - BOOL result = [fileManager createDirectoryAtPath:[outputFile stringByDeletingLastPathComponent] - withIntermediateDirectories:YES - attributes:nil - error:&error]; + BOOL __unused result = + [fileManager createDirectoryAtPath:[outputFile stringByDeletingLastPathComponent] + withIntermediateDirectories:YES + attributes:nil + error:&error]; // TODO: Log the error for clarity. The file-write will fail later // but it would be nice to see this error. However, 'error' is still testing // false and result is true even on an unwritable path-- not sure what's up. @@ -89,17 +90,10 @@ static void EnsureOutputDirFor(NSString *outputFile) { __block CVPixelBufferRef output; graph.delegate = self; - // The XCTAssert macros contain references to self, which causes a retain cycle, - // since the block retains self and self retains the block. 
The cycle is broken - // at the end of this method, with _pixelBufferOutputBlock = nil, but Clang does - // not realize that and outputs a warning. WEAKIFY and STRONGIFY, though not - // strictly necessary, are used here to avoid the warning. - WEAKIFY(self); if (!_pixelBufferOutputBlock) { XCTestExpectation* outputReceived = [self expectationWithDescription:@"output received"]; _pixelBufferOutputBlock = ^(MPPGraph* outputGraph, CVPixelBufferRef outputBuffer, const std::string& outputStreamName) { - STRONGIFY(self); XCTAssertEqualObjects(outputGraph, graph); XCTAssertEqual(outputStreamName, outputStream); CFRetain(outputBuffer); diff --git a/mediapipe/objc/MPPGraphTests.mm b/mediapipe/objc/MPPGraphTests.mm index c3cf48047..a68d4528d 100644 --- a/mediapipe/objc/MPPGraphTests.mm +++ b/mediapipe/objc/MPPGraphTests.mm @@ -287,9 +287,7 @@ REGISTER_CALCULATOR(ErrorCalculator); CFHolder inputBuffer; absl::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); XCTAssert(status.ok()); - CVPixelBufferRef outputBuffer = [self runGraph:_graph - withPixelBuffer:*inputBuffer - packetType:MPPPacketTypePixelBuffer]; + [self runGraph:_graph withPixelBuffer:*inputBuffer packetType:MPPPacketTypePixelBuffer]; __weak MPPGraph* weakGraph = _graph; _graph = nil; XCTAssertNil(weakGraph); diff --git a/mediapipe/python/image_test.py b/mediapipe/python/image_test.py index a21dbdadd..117d20974 100644 --- a/mediapipe/python/image_test.py +++ b/mediapipe/python/image_test.py @@ -187,16 +187,5 @@ class ImageTest(absltest.TestCase): gc.collect() self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count) - def test_image_create_from_file(self): - image_path = os.path.join( - resources.GetRunfilesDir(), - 'mediapipe/tasks/testdata/vision/cat.jpg') - loaded_image = Image.create_from_file(image_path) - self.assertEqual(loaded_image.width, 600) - self.assertEqual(loaded_image.height, 400) - self.assertEqual(loaded_image.channels, 3) - self.assertEqual(loaded_image.image_format, ImageFormat.SRGB) - - if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/cc/audio/audio_classifier/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/BUILD index 20ccf68f0..ac238bfda 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/BUILD @@ -35,9 +35,10 @@ cc_library( "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", @@ -64,8 +65,9 @@ cc_library( "//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:running_mode", - "//mediapipe/tasks/cc/components:classifier_options", - 
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc index 9a8075f77..702d802c5 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.cc @@ -24,8 +24,9 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -37,6 +38,8 @@ namespace audio_classifier { namespace { +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioTag[] = "AUDIO"; constexpr char kClassificationResultStreamName[] = "classification_result_out"; @@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode == core::RunningMode::AUDIO_STREAM); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h index bd8bd5e0c..200cffb8c 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier.h @@ -23,8 +23,8 @@ limitations under the License. #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" namespace mediapipe { @@ -40,7 +40,7 @@ struct AudioClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. 
- components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The running mode of the audio classifier. Default to the audio clips mode. // Audio classifier has two running modes: @@ -59,8 +59,9 @@ struct AudioClassifierOptions { // The user-defined result callback for processing audio stream data. // The result callback should only be specified when the running mode is set // to RunningMode::AUDIO_STREAM. - std::function)> result_callback = - nullptr; + std::function)> + result_callback = nullptr; }; // Performs audio classification on audio clips or audio stream. @@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi { // framed audio clip. // TODO: Use `sample_rate` in AudioClassifierOptions by default // and makes `audio_sample_rate` optional. - absl::StatusOr Classify(mediapipe::Matrix audio_clip, - double audio_sample_rate); + absl::StatusOr Classify( + mediapipe::Matrix audio_clip, double audio_sample_rate); // Sends audio data (a block in a continuous audio stream) to perform audio // classification. Only use this method when the AudioClassifier is created diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc index 810fb2da5..12f8ce31a 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_graph.cc @@ -31,9 +31,9 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" @@ -53,6 +53,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAudioTag[] = "AUDIO"; @@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." 
+ "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Time aggregation is only needed for performing audio classification on diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 4e874b520..591d5e4f7 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" namespace mediapipe { @@ -49,6 +49,7 @@ namespace { using ::absl::StatusOr; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::testing::HasSubstr; using ::testing::Optional; @@ -183,7 +184,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { EXPECT_THAT( audio_classifier_or.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(audio_classifier_or.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD index 033bb51ac..bfe37ec01 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto index 63b4b3293..16aa86aeb 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.proto @@ -18,7 +18,7 @@ syntax = "proto2"; package mediapipe.tasks.audio.audio_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; message AudioClassifierGraphOptions { @@ -31,7 +31,7 @@ message AudioClassifierGraphOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. 
- optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; // The default sample rate of the input audio. Must be set when the // AudioClassifier is configured to process audio stream data. diff --git a/mediapipe/tasks/cc/common.h b/mediapipe/tasks/cc/common.h index 62656b7b3..1295177df 100644 --- a/mediapipe/tasks/cc/common.h +++ b/mediapipe/tasks/cc/common.h @@ -65,6 +65,8 @@ enum class MediaPipeTasksStatus { kFileReadError, // I/O error when mmap-ing file. kFileMmapError, + // ZIP I/O error when unpacking the zip file. + kFileZipError, // TensorFlow Lite metadata error codes. diff --git a/mediapipe/tasks/cc/components/BUILD b/mediapipe/tasks/cc/components/BUILD index 4de32ce9b..e4905546a 100644 --- a/mediapipe/tasks/cc/components/BUILD +++ b/mediapipe/tasks/cc/components/BUILD @@ -58,65 +58,6 @@ cc_library( # TODO: Enable this test -cc_library( - name = "classifier_options", - srcs = ["classifier_options.cc"], - hdrs = ["classifier_options.h"], - deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"], -) - -mediapipe_proto_library( - name = "classification_postprocessing_options_proto", - srcs = ["classification_postprocessing_options.proto"], - deps = [ - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", - ], -) - -cc_library( - name = "classification_postprocessing", - srcs = ["classification_postprocessing.cc"], - hdrs = ["classification_postprocessing.h"], - deps = [ - ":classification_postprocessing_options_cc_proto", - "//mediapipe/calculators/core:split_vector_calculator", - "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/tensor:tensors_dequantization_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:packet", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", - "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", - "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", - "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", - "//mediapipe/tasks/cc/components/utils:source_or_node_output", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/metadata:metadata_extractor", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "//mediapipe/util:label_map_cc_proto", - "//mediapipe/util:label_map_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - 
"@org_tensorflow//tensorflow/lite/schema:schema_fbs", - ], - alwayslink = 1, -) - cc_library( name = "embedder_options", srcs = ["embedder_options.cc"], @@ -151,3 +92,29 @@ cc_library( ], alwayslink = 1, ) + +# TODO: Investigate rewriting the build rule to only link +# the Bert Preprocessor if it's needed. +cc_library( + name = "text_preprocessing_graph", + srcs = ["text_preprocessing_graph.cc"], + hdrs = ["text_preprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/tensor:bert_preprocessor_calculator", + "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator", + "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto", + "//mediapipe/calculators/tensor:text_to_tensor_calculator", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 13ca6b496..7d01e4dfe 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,8 +37,8 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers:category_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], alwayslink = 1, @@ -128,7 +128,7 @@ cc_library( "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index b2848bc3f..e1f69e607 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,15 +25,15 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions; -using ::mediapipe::tasks::ClassificationResult; -using ::mediapipe::tasks::Classifications; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; 
+using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into a single ClassificationResult that has // 3 dimensions: (classification head, classification timestamp, classification diff --git a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc index b688cda91..10eb962dd 100644 --- a/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/end_loop_calculator.cc @@ -17,12 +17,13 @@ limitations under the License. #include <vector> -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" // Specialized EndLoopCalculator for Tasks specific types. namespace mediapipe::tasks { -typedef EndLoopCalculator<std::vector<ClassificationResult>> +typedef EndLoopCalculator< + std::vector<components::containers::proto::ClassificationResult>> EndLoopClassificationResultCalculator; REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.h b/mediapipe/tasks/cc/components/classification_postprocessing.h deleted file mode 100644 index eb638bd60..000000000 --- a/mediapipe/tasks/cc/components/classification_postprocessing.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ - -#include "absl/status/status.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" -#include "mediapipe/tasks/cc/core/model_resources.h" - -namespace mediapipe { -namespace tasks { -namespace components { - -// Configures a ClassificationPostprocessing subgraph using the provided model -// resources and ClassifierOptions. -// - Accepts CPU input tensors. -// -// Example usage: -// -// auto& postprocessing = -// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); -// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( -// model_resources, -// classifier_options, -// &preprocessing.GetOptions<ClassificationPostprocessingOptions>())); -// -// The resulting ClassificationPostprocessing subgraph has the following I/O: -// Inputs: -// TENSORS - std::vector<Tensor> -// The output tensors of an InferenceCalculator. -// TIMESTAMPS - std::vector<Timestamp> @Optional -// The collection of timestamps that a single ClassificationResult should -// aggregate. This is mostly useful for classifiers working on time series, -// e.g. audio or video classification. -// Outputs: -// CLASSIFICATION_RESULT - ClassificationResult -// The output aggregated classification results.
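For orientation, the graph that replaces this deleted header (added later in this diff) keeps the same stream interface: TENSORS (std::vector<Tensor>) and optional TIMESTAMPS in, one aggregated CLASSIFICATION_RESULT out; only the registered name and option types change. The sketch below restates that wiring under those assumptions; the `inference` node, `model_resources`, and `classifier_options` stand in for whatever the enclosing task graph already has (see audio_classifier_graph.cc earlier in this diff for the real usage).

```
// Sketch only: mirrors the usage in audio_classifier_graph.cc above.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"

namespace example {

absl::Status AttachPostprocessing(
    const mediapipe::tasks::core::ModelResources& model_resources,
    const mediapipe::tasks::components::processors::proto::ClassifierOptions&
        classifier_options,
    mediapipe::api2::builder::GenericNode& inference,
    mediapipe::api2::builder::Graph& graph) {
  auto& postprocessing = graph.AddNode(
      "mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
  // Populate the node options from the model metadata and classifier options.
  MP_RETURN_IF_ERROR(
      mediapipe::tasks::components::processors::
          ConfigureClassificationPostprocessingGraph(
              model_resources, classifier_options,
              &postprocessing.GetOptions<
                  mediapipe::tasks::components::processors::proto::
                      ClassificationPostprocessingGraphOptions>()));
  // Raw model output tensors in, aggregated ClassificationResult out.
  inference.Out("TENSORS") >> postprocessing.In("TENSORS");
  return absl::OkStatus();
}

}  // namespace example
```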
-absl::Status ConfigureClassificationPostprocessing( - const tasks::core::ModelResources& model_resources, - const tasks::components::proto::ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options); - -} // namespace components -} // namespace tasks -} // namespace mediapipe - -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index 701f84824..af51d0c37 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -12,21 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") - package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], +cc_library( + name = "rect", + hdrs = ["rect.h"], ) -mediapipe_proto_library( - name = "classifications_proto", - srcs = ["classifications.proto"], +cc_library( + name = "gesture_recognition_result", + hdrs = ["gesture_recognition_result.h"], deps = [ - ":category_proto", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", ], ) diff --git a/mediapipe/tasks/cc/components/containers/gesture_recognition_result.h b/mediapipe/tasks/cc/components/containers/gesture_recognition_result.h new file mode 100644 index 000000000..4e2e8d775 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/gesture_recognition_result.h @@ -0,0 +1,46 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_ +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace containers { + +// The gesture recognition result from GestureRecognizer, where each vector +// element represents a single hand detected in the image. +struct GestureRecognitionResult { + // Recognized hand gestures with sorted order such that the winning label is + // the first item in the list. + std::vector<mediapipe::ClassificationList> gestures; + // Classification of handedness. + std::vector<mediapipe::ClassificationList> handedness; + // Detected hand landmarks in normalized image coordinates. + std::vector<mediapipe::NormalizedLandmarkList> hand_landmarks; + // Detected hand landmarks in world coordinates.
+ std::vector<mediapipe::LandmarkList> hand_world_landmarks; +}; + +} // namespace containers +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_GESTURE_RECOGNITION_RESULT_H_ diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 9c6402e64..633b5b369 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) +mediapipe_proto_library( + name = "category_proto", + srcs = ["category.proto"], +) + +mediapipe_proto_library( + name = "classifications_proto", + srcs = ["classifications.proto"], + deps = [ + ":category_proto", + ], +) + +mediapipe_proto_library( + name = "embeddings_proto", + srcs = ["embeddings.proto"], +) + mediapipe_proto_library( name = "landmarks_detection_result_proto", srcs = [ @@ -29,8 +47,3 @@ mediapipe_proto_library( "//mediapipe/framework/formats:rect_proto", ], ) - -mediapipe_proto_library( - name = "embeddings_proto", - srcs = ["embeddings.proto"], -) diff --git a/mediapipe/tasks/cc/components/containers/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto similarity index 88% rename from mediapipe/tasks/cc/components/containers/category.proto rename to mediapipe/tasks/cc/components/containers/proto/category.proto index 47f38b75a..2ba760e99 100644 --- a/mediapipe/tasks/cc/components/containers/category.proto +++ b/mediapipe/tasks/cc/components/containers/proto/category.proto @@ -15,7 +15,10 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; + +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "CategoryProto"; // A single classification result. message Category { diff --git a/mediapipe/tasks/cc/components/containers/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto similarity index 86% rename from mediapipe/tasks/cc/components/containers/classifications.proto rename to mediapipe/tasks/cc/components/containers/proto/classifications.proto index 469c67fc9..712607fa6 100644 --- a/mediapipe/tasks/cc/components/containers/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -15,9 +15,12 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks; +package mediapipe.tasks.components.containers.proto; -import "mediapipe/tasks/cc/components/containers/category.proto"; +import "mediapipe/tasks/cc/components/containers/proto/category.proto"; + +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "ClassificationsProto"; // List of predicted categories with an optional timestamp.
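The four vectors in GestureRecognitionResult above are parallel, one entry per detected hand. A short consumer sketch may help; it is illustrative only, assumes non-empty classification lists, and uses only the standard accessors from mediapipe/framework/formats/classification.proto and landmark.proto.

```
#include <cstddef>
#include <iostream>

#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"

// Illustrative only: prints the winning gesture per detected hand.
void PrintTopGestures(
    const mediapipe::tasks::components::containers::GestureRecognitionResult&
        result) {
  for (std::size_t i = 0; i < result.gestures.size(); ++i) {
    // The winning gesture label is the first classification in the list.
    const auto& top_gesture = result.gestures[i].classification(0);
    const auto& handedness = result.handedness[i].classification(0);
    std::cout << handedness.label() << " hand: " << top_gesture.label()
              << " (score " << top_gesture.score() << "), "
              << result.hand_landmarks[i].landmark_size()
              << " normalized landmarks\n";
  }
}
```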
message ClassificationEntry { diff --git a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto index d57b08b53..39811e6c0 100644 --- a/mediapipe/tasks/cc/components/containers/proto/embeddings.proto +++ b/mediapipe/tasks/cc/components/containers/proto/embeddings.proto @@ -17,6 +17,9 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "EmbeddingsProto"; + // Defines a dense floating-point embedding. message FloatEmbedding { repeated float values = 1 [packed = true]; diff --git a/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto b/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto index 9be6ce47a..ac44f9b58 100644 --- a/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto +++ b/mediapipe/tasks/cc/components/containers/proto/landmarks_detection_result.proto @@ -21,6 +21,9 @@ import "mediapipe/framework/formats/classification.proto"; import "mediapipe/framework/formats/landmark.proto"; import "mediapipe/framework/formats/rect.proto"; +option java_package = "com.google.mediapipe.tasks.components.containers.proto"; +option java_outer_classname = "LandmarksDetectionResultProto"; + message LandmarksDetectionResult { optional mediapipe.NormalizedLandmarkList landmarks = 1; optional mediapipe.ClassificationList classifications = 2; diff --git a/mediapipe/tasks/cc/components/containers/rect.h b/mediapipe/tasks/cc/components/containers/rect.h new file mode 100644 index 000000000..3f5432cf2 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/rect.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ + +namespace mediapipe::tasks::components::containers { + +// Defines a rectangle, used e.g. as part of detection results or as input +// region-of-interest. +// +// The coordinates are normalized wrt the image dimensions, i.e. generally in +// [0,1] but they may exceed these bounds if describing a region overlapping the +// image. The origin is on the top-left corner of the image. +struct Rect { + float left; + float top; + float right; + float bottom; +}; + +} // namespace mediapipe::tasks::components::containers +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_ diff --git a/mediapipe/tasks/cc/components/processors/BUILD b/mediapipe/tasks/cc/components/processors/BUILD new file mode 100644 index 000000000..62f04dcb7 --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/BUILD @@ -0,0 +1,64 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "classifier_options", + srcs = ["classifier_options.cc"], + hdrs = ["classifier_options.h"], + deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"], +) + +cc_library( + name = "classification_postprocessing_graph", + srcs = ["classification_postprocessing_graph.cc"], + hdrs = ["classification_postprocessing_graph.h"], + deps = [ + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator_cc_proto", + "//mediapipe/calculators/tensor:tensors_dequantization_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/utils:source_or_node_output", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:label_map_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite/schema:schema_fbs", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc similarity index 90% rename from mediapipe/tasks/cc/components/classification_postprocessing.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 871476e8f..b4fbf9669 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include @@ -37,9 +37,9 @@ limitations under the License. #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" @@ -51,6 +51,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { @@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; @@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; // Performs sanity checks on provided ClassifierOptions. -absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) { +absl::Status SanityCheckClassifierOptions( + const proto::ClassifierOptions& options) { if (options.max_results() == 0) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, @@ -121,15 +123,17 @@ absl::StatusOr GetClassificationHeadsProperties( const auto* tensor = primary_subgraph->tensors()->Get(primary_subgraph->outputs()->Get(i)); if (tensor->type() != tflite::TensorType_FLOAT32 && - tensor->type() != tflite::TensorType_UINT8) { + tensor->type() != tflite::TensorType_UINT8 && + tensor->type() != tflite::TensorType_BOOL) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrFormat("Expected output tensor at index %d to have type " - "UINT8 or FLOAT32, found %s instead.", + "UINT8 or FLOAT32 or BOOL, found %s instead.", i, tflite::EnumNameTensorType(tensor->type())), MediaPipeTasksStatus::kInvalidOutputTensorTypeError); } - if (tensor->type() == tflite::TensorType_UINT8) { + if (tensor->type() == tflite::TensorType_UINT8 || + tensor->type() == tflite::TensorType_BOOL) { num_quantized_tensors++; } } @@ -203,7 +207,7 @@ absl::StatusOr GetScoreThreshold( // Gets the category allowlist or denylist (if any) as a set of indices. 
absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( - const ClassifierOptions& options, const LabelItems& label_items) { + const proto::ClassifierOptions& options, const LabelItems& label_items) { absl::flat_hash_set category_indices; // Exit early if no denylist/allowlist. if (options.category_denylist_size() == 0 && @@ -239,7 +243,7 @@ absl::StatusOr> GetAllowOrDenyCategoryIndicesIfAny( absl::Status ConfigureScoreCalibrationIfAny( const ModelMetadataExtractor& metadata_extractor, int tensor_index, - ClassificationPostprocessingOptions* options) { + proto::ClassificationPostprocessingGraphOptions* options) { const auto* tensor_metadata = metadata_extractor.GetOutputTensorMetadata(tensor_index); if (tensor_metadata == nullptr) { @@ -280,10 +284,24 @@ absl::Status ConfigureScoreCalibrationIfAny( return absl::OkStatus(); } +void ConfigureClassificationAggregationCalculator( + const ModelMetadataExtractor& metadata_extractor, + ClassificationAggregationCalculatorOptions* options) { + auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); + if (output_tensors_metadata == nullptr) { + return; + } + for (const auto& metadata : *output_tensors_metadata) { + options->add_head_names(metadata->name()->str()); + } +} + +} // namespace + // Fills in the TensorsToClassificationCalculatorOptions based on the // classifier options and the (optional) output tensor metadata. absl::Status ConfigureTensorsToClassificationCalculator( - const ClassifierOptions& options, + const proto::ClassifierOptions& options, const ModelMetadataExtractor& metadata_extractor, int tensor_index, TensorsToClassificationCalculatorOptions* calculator_options) { const auto* tensor_metadata = @@ -331,24 +349,10 @@ absl::Status ConfigureTensorsToClassificationCalculator( return absl::OkStatus(); } -void ConfigureClassificationAggregationCalculator( - const ModelMetadataExtractor& metadata_extractor, - ClassificationAggregationCalculatorOptions* options) { - auto* output_tensors_metadata = metadata_extractor.GetOutputTensorMetadata(); - if (output_tensors_metadata == nullptr) { - return; - } - for (const auto& metadata : *output_tensors_metadata) { - options->add_head_names(metadata->name()->str()); - } -} - -} // namespace - -absl::Status ConfigureClassificationPostprocessing( +absl::Status ConfigureClassificationPostprocessingGraph( const ModelResources& model_resources, - const ClassifierOptions& classifier_options, - ClassificationPostprocessingOptions* options) { + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options) { MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options)); ASSIGN_OR_RETURN(const auto heads_properties, GetClassificationHeadsProperties(model_resources)); @@ -366,8 +370,8 @@ absl::Status ConfigureClassificationPostprocessing( return absl::OkStatus(); } -// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts -// raw tensors into ClassificationResult objects. +// A "ClassificationPostprocessingGraph" converts raw tensors into +// ClassificationResult objects. // - Accepts CPU input tensors. // // Inputs: @@ -381,10 +385,10 @@ absl::Status ConfigureClassificationPostprocessing( // CLASSIFICATION_RESULT - ClassificationResult // The output aggregated classification results. // -// The recommended way of using this subgraph is through the GraphBuilder API -// using the 'ConfigureClassificationPostprocessing()' function. See header file -// for more details. 
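Note that ConfigureTensorsToClassificationCalculator is moved out of the anonymous namespace here and declared in the new public header further down, so other task graphs can reuse it per output head. A hedged sketch of that reuse follows; the wrapper name is hypothetical, and it assumes ModelResources exposes its metadata extractor the way the surrounding configure functions use it.

```
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"

// Hypothetical wrapper around the newly exported helper; illustrative only.
absl::Status ConfigureFirstHead(
    const mediapipe::tasks::core::ModelResources& model_resources,
    const mediapipe::tasks::components::processors::proto::ClassifierOptions&
        classifier_options,
    mediapipe::TensorsToClassificationCalculatorOptions* calculator_options) {
  // Fill in the calculator options for output tensor #0 from the model
  // metadata (label map, score threshold, allowlist/denylist).
  return mediapipe::tasks::components::processors::
      ConfigureTensorsToClassificationCalculator(
          classifier_options, *model_resources.GetMetadataExtractor(),
          /*tensor_index=*/0, calculator_options);
}
```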
-class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { +// The recommended way of using this graph is through the GraphBuilder API +// using the 'ConfigureClassificationPostprocessingGraph()' function. See header +// file for more details. +class ClassificationPostprocessingGraph : public mediapipe::Subgraph { public: absl::StatusOr GetConfig( mediapipe::SubgraphContext* sc) override { @@ -392,7 +396,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { ASSIGN_OR_RETURN( auto classification_result_out, BuildClassificationPostprocessing( - sc->Options(), + sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); classification_result_out >> @@ -401,19 +405,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { } private: - // Adds an on-device classification postprocessing subgraph into the provided - // builder::Graph instance. The classification postprocessing subgraph takes + // Adds an on-device classification postprocessing graph into the provided + // builder::Graph instance. The classification postprocessing graph takes // tensors (std::vector) as input and returns one output // stream containing the output classification results (ClassificationResult). // - // options: the on-device ClassificationPostprocessingOptions. + // options: the on-device ClassificationPostprocessingGraphOptions. // tensors_in: (std::vector>) tensors to postprocess. // timestamps_in: (std::vector) optional collection of // timestamps that a single ClassificationResult should aggregate. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr> BuildClassificationPostprocessing( - const ClassificationPostprocessingOptions& options, + const proto::ClassificationPostprocessingGraphOptions& options, Source> tensors_in, Source> timestamps_in, Graph& graph) { const int num_heads = options.tensors_to_classifications_options_size(); @@ -504,9 +508,14 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { kClassificationResultTag)]; } }; -REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph); +// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly. +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::components::processors::ClassificationPostprocessingGraph); // NOLINT +// clang-format on + +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h new file mode 100644 index 000000000..be166982d --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -0,0 +1,74 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ + +#include "absl/status/status.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace processors { + +// Configures a ClassificationPostprocessingGraph using the provided model +// resources and ClassifierOptions. +// - Accepts CPU input tensors. +// +// Example usage: +// +// auto& postprocessing = +// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph"); +// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( +// model_resources, +// classifier_options, +// &preprocessing.GetOptions())); +// +// The resulting ClassificationPostprocessingGraph has the following I/O: +// Inputs: +// TENSORS - std::vector +// The output tensors of an InferenceCalculator. +// TIMESTAMPS - std::vector @Optional +// The collection of timestamps that a single ClassificationResult should +// aggregate. This is mostly useful for classifiers working on time series, +// e.g. audio or video classification. +// Outputs: +// CLASSIFICATION_RESULT - ClassificationResult +// The output aggregated classification results. +absl::Status ConfigureClassificationPostprocessingGraph( + const tasks::core::ModelResources& model_resources, + const proto::ClassifierOptions& classifier_options, + proto::ClassificationPostprocessingGraphOptions* options); + +// Utility function to fill in the TensorsToClassificationCalculatorOptions +// based on the classifier options and the (optional) output tensor metadata. +// This is meant to be used by other graphs that may also rely on this +// calculator. +absl::Status ConfigureTensorsToClassificationCalculator( + const proto::ClassifierOptions& options, + const metadata::ModelMetadataExtractor& metadata_extractor, + int tensor_index, + mediapipe::TensorsToClassificationCalculatorOptions* calculator_options); + +} // namespace processors +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_ diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc similarity index 88% rename from mediapipe/tasks/cc/components/classification_postprocessing_test.cc rename to mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index 67223050f..bb03e2530 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" #include #include @@ -42,9 +42,9 @@ limitations under the License. #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/util/label_map.pb.h" @@ -53,6 +53,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace components { +namespace processors { namespace { using ::mediapipe::api2::Input; @@ -60,7 +61,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::file::JoinPath; -using ::mediapipe::tasks::components::proto::ClassifierOptions; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; using ::testing::HasSubstr; using ::testing::proto::Approximately; @@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(0); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); @@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); options_in.add_category_denylist("bar"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); @@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, 
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("foo"); - ClassificationPostprocessingOptions options_out; - auto status = ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out); + proto::ClassificationPostprocessingGraphOptions options_out; + auto status = ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT( @@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_max_results(3); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.set_score_threshold(0.5); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); EXPECT_THAT(options_out, Approximately(EqualsProto( R"pb(score_calibration_options: [] @@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. 
EXPECT_EQ( @@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_allowlist("tench"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) @@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; options_in.add_category_denylist("background"); - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Clear label map and compare the rest of the options. options_out.mutable_tensors_to_classifications_options(0) @@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { auto model_resources, CreateModelResourcesForModel( kQuantizedImageClassifierWithDummyScoreCalibration)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label map size and two first elements. EXPECT_EQ( @@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata)); - ClassifierOptions options_in; + proto::ClassifierOptions options_in; - ClassificationPostprocessingOptions options_out; - MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, - options_in, &options_out)); + proto::ClassificationPostprocessingGraphOptions options_out; + MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph( + *model_resources, options_in, &options_out)); // Check label maps sizes and first two elements. EXPECT_EQ( options_out.tensors_to_classifications_options(0).label_items_size(), @@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { class PostprocessingTest : public tflite_shims::testing::Test { protected: absl::StatusOr BuildGraph( - absl::string_view model_name, const ClassifierOptions& options, + absl::string_view model_name, const proto::ClassifierOptions& options, bool connect_timestamps = false) { ASSIGN_OR_RETURN(auto model_resources, CreateModelResourcesForModel(model_name)); Graph graph; auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( + "mediapipe.tasks.components.processors." 
+ "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( *model_resources, options, - &postprocessing.GetOptions())); + &postprocessing + .GetOptions())); graph[Input>(kTensorsTag)].SetName(kTensorsName) >> postprocessing.In(kTensorsTag); if (connect_timestamps) { @@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test { TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); options.set_score_threshold(0.5); MP_ASSERT_OK_AND_ASSIGN( @@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { TEST_F(PostprocessingTest, SucceedsWithMetadata) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); @@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(3); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, @@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithTimestamps) { // Build graph. - ClassifierOptions options; + proto::ClassifierOptions options; options.set_max_results(2); MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, @@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { } } // namespace +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.cc b/mediapipe/tasks/cc/components/processors/classifier_options.cc similarity index 81% rename from mediapipe/tasks/cc/components/classifier_options.cc rename to mediapipe/tasks/cc/components/processors/classifier_options.cc index c54db5f88..349bb569d 100644 --- a/mediapipe/tasks/cc/components/classifier_options.cc +++ b/mediapipe/tasks/cc/components/processors/classifier_options.cc @@ -13,17 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "mediapipe/tasks/cc/components/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* options) { - tasks::components::proto::ClassifierOptions options_proto; + proto::ClassifierOptions options_proto; options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_max_results(options->max_results); options_proto.set_score_threshold(options->score_threshold); @@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( return options_proto; } +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/components/classifier_options.h b/mediapipe/tasks/cc/components/processors/classifier_options.h similarity index 83% rename from mediapipe/tasks/cc/components/classifier_options.h rename to mediapipe/tasks/cc/components/processors/classifier_options.h index e15bf5e69..189b42e60 100644 --- a/mediapipe/tasks/cc/components/classifier_options.h +++ b/mediapipe/tasks/cc/components/processors/classifier_options.h @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ -#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" namespace mediapipe { namespace tasks { namespace components { +namespace processors { // Classifier options for MediaPipe C++ classification Tasks. struct ClassifierOptions { @@ -49,11 +50,12 @@ struct ClassifierOptions { }; // Converts a ClassifierOptions to a ClassifierOptionsProto. -tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( +proto::ClassifierOptions ConvertClassifierOptionsToProto( ClassifierOptions* classifier_options); +} // namespace processors } // namespace components } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_ diff --git a/mediapipe/tasks/cc/components/processors/proto/BUILD b/mediapipe/tasks/cc/components/processors/proto/BUILD new file mode 100644 index 000000000..d7cbe47ff --- /dev/null +++ b/mediapipe/tasks/cc/components/processors/proto/BUILD @@ -0,0 +1,36 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "classifier_options_proto", + srcs = ["classifier_options.proto"], +) + +mediapipe_proto_library( + name = "classification_postprocessing_graph_options_proto", + srcs = ["classification_postprocessing_graph_options.proto"], + deps = [ + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", + "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto similarity index 91% rename from mediapipe/tasks/cc/components/classification_postprocessing_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto index 9b67e2f75..1de788eab 100644 --- a/mediapipe/tasks/cc/components/classification_postprocessing_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.proto @@ -15,16 +15,16 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components; +package mediapipe.tasks.components.processors.proto; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; -message ClassificationPostprocessingOptions { +message ClassificationPostprocessingGraphOptions { extend mediapipe.CalculatorOptions { - optional ClassificationPostprocessingOptions ext = 460416950; + optional ClassificationPostprocessingGraphOptions ext = 460416950; } // Optional mapping between output tensor index and corresponding score diff --git a/mediapipe/tasks/cc/components/proto/classifier_options.proto b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto similarity index 90% rename from mediapipe/tasks/cc/components/proto/classifier_options.proto rename to mediapipe/tasks/cc/components/processors/proto/classifier_options.proto index ea1491bb8..12ece7249 100644 --- a/mediapipe/tasks/cc/components/proto/classifier_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/classifier_options.proto @@ -15,7 +15,10 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.components.proto; +package mediapipe.tasks.components.processors.proto; + +option java_package = "com.google.mediapipe.tasks.components.processors.proto"; +option java_outer_classname = "ClassifierOptionsProto"; // Shared options used by all classification tasks. 
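The relocated ConvertClassifierOptionsToProto above maps the plain C++ ClassifierOptions struct onto this proto under its new components.processors.proto package. A brief sketch, restricted to the fields visible in the conversion function (the values are illustrative):

```
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"

// Illustrative only: build the struct, then convert it to the proto form that
// task graph options now embed.
mediapipe::tasks::components::processors::proto::ClassifierOptions
MakeExampleClassifierOptionsProto() {
  mediapipe::tasks::components::processors::ClassifierOptions options;
  options.max_results = 3;         // keep only the top-3 categories
  options.score_threshold = 0.5f;  // drop low-confidence results
  options.display_names_locale = "en";
  return mediapipe::tasks::components::processors::
      ConvertClassifierOptionsToProto(&options);
}
```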
message ClassifierOptions { diff --git a/mediapipe/tasks/cc/components/proto/BUILD b/mediapipe/tasks/cc/components/proto/BUILD index 8c4dcdad9..c11d6f95a 100644 --- a/mediapipe/tasks/cc/components/proto/BUILD +++ b/mediapipe/tasks/cc/components/proto/BUILD @@ -23,11 +23,6 @@ mediapipe_proto_library( srcs = ["segmenter_options.proto"], ) -mediapipe_proto_library( - name = "classifier_options_proto", - srcs = ["classifier_options.proto"], -) - mediapipe_proto_library( name = "embedder_options_proto", srcs = ["embedder_options.proto"], diff --git a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto b/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto index c0c207543..926e3d7fb 100644 --- a/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.proto @@ -31,6 +31,8 @@ message TextPreprocessingGraphOptions { BERT_PREPROCESSOR = 1; // Used for the RegexPreprocessorCalculator. REGEX_PREPROCESSOR = 2; + // Used for the TextToTensorCalculator. + STRING_PREPROCESSOR = 3; } optional PreprocessorType preprocessor_type = 1; diff --git a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc b/mediapipe/tasks/cc/components/text_preprocessing_graph.cc index 2c4c1b866..6aad8fdd5 100644 --- a/mediapipe/tasks/cc/components/text_preprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/text_preprocessing_graph.cc @@ -65,6 +65,8 @@ absl::StatusOr GetCalculatorNameFromPreprocessorType( return "BertPreprocessorCalculator"; case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: return "RegexPreprocessorCalculator"; + case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: + return "TextToTensorCalculator"; } } @@ -91,11 +93,7 @@ GetPreprocessorType(const ModelResources& model_resources) { MediaPipeTasksStatus::kInvalidInputTensorTypeError); } if (all_string_tensors) { - // TODO: Support a TextToTensor calculator for string tensors. 
- return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - "String tensors are not supported yet", - MediaPipeTasksStatus::kInvalidInputTensorTypeError); + return TextPreprocessingGraphOptions::STRING_PREPROCESSOR; } // Otherwise, all tensors should have type int32 @@ -185,10 +183,19 @@ absl::Status ConfigureTextPreprocessingSubgraph( TextPreprocessingGraphOptions::PreprocessorType preprocessor_type, GetPreprocessorType(model_resources)); options.set_preprocessor_type(preprocessor_type); - ASSIGN_OR_RETURN( - int max_seq_len, - GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); - options.set_max_seq_len(max_seq_len); + switch (preprocessor_type) { + case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: + case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { + break; + } + case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: + case TextPreprocessingGraphOptions::REGEX_PREPROCESSOR: { + ASSIGN_OR_RETURN( + int max_seq_len, + GetMaxSeqLen(*(*model_resources.GetTfLiteModel()->subgraphs())[0])); + options.set_max_seq_len(max_seq_len); + } + } return absl::OkStatus(); } @@ -236,7 +243,8 @@ class TextPreprocessingSubgraph : public mediapipe::Subgraph { GetCalculatorNameFromPreprocessorType(options.preprocessor_type())); auto& text_preprocessor = graph.AddNode(preprocessor_name); switch (options.preprocessor_type()) { - case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: { + case TextPreprocessingGraphOptions::UNSPECIFIED_PREPROCESSOR: + case TextPreprocessingGraphOptions::STRING_PREPROCESSOR: { break; } case TextPreprocessingGraphOptions::BERT_PREPROCESSOR: { diff --git a/mediapipe/tasks/cc/components/utils/BUILD b/mediapipe/tasks/cc/components/utils/BUILD index 0ec7ac945..d16e2fbc4 100644 --- a/mediapipe/tasks/cc/components/utils/BUILD +++ b/mediapipe/tasks/cc/components/utils/BUILD @@ -42,3 +42,16 @@ cc_test( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", ], ) + +cc_library( + name = "gate", + hdrs = ["gate.h"], + deps = [ + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + ], +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/components/utils/gate.h b/mediapipe/tasks/cc/components/utils/gate.h new file mode 100644 index 000000000..139205fc5 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/gate.h @@ -0,0 +1,160 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ + +#include + +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { + +// Utility class that simplifies allowing (gating) multiple streams. 
+class AllowGate {
+ public:
+  AllowGate(api2::builder::Source<bool> allow, api2::builder::Graph& graph)
+      : node_(AddSourceGate(allow, graph)) {}
+  AllowGate(api2::builder::SideSource<bool> allow, api2::builder::Graph& graph)
+      : node_(AddSideSourceGate(allow, graph)) {}
+
+  // Move-only
+  AllowGate(AllowGate&& allow_gate) = default;
+  AllowGate& operator=(AllowGate&& allow_gate) = default;
+
+  template <typename T>
+  api2::builder::Source<T> Allow(api2::builder::Source<T> source) {
+    source >> node_.In(index_);
+    return node_.Out(index_++).Cast<T>();
+  }
+
+ private:
+  template <typename T>
+  static api2::builder::GenericNode& AddSourceGate(
+      T allow, api2::builder::Graph& graph) {
+    auto& gate_node = graph.AddNode("GateCalculator");
+    allow >> gate_node.In("ALLOW");
+    return gate_node;
+  }
+
+  template <typename T>
+  static api2::builder::GenericNode& AddSideSourceGate(
+      T allow, api2::builder::Graph& graph) {
+    auto& gate_node = graph.AddNode("GateCalculator");
+    allow >> gate_node.SideIn("ALLOW");
+    return gate_node;
+  }
+
+  api2::builder::GenericNode& node_;
+  int index_ = 0;
+};
+
+// Utility class that simplifies disallowing (gating) multiple streams.
+class DisallowGate {
+ public:
+  DisallowGate(api2::builder::Source<bool> disallow,
+               api2::builder::Graph& graph)
+      : node_(AddSourceGate(disallow, graph)) {}
+  DisallowGate(api2::builder::SideSource<bool> disallow,
+               api2::builder::Graph& graph)
+      : node_(AddSideSourceGate(disallow, graph)) {}
+
+  // Move-only
+  DisallowGate(DisallowGate&& disallow_gate) = default;
+  DisallowGate& operator=(DisallowGate&& disallow_gate) = default;
+
+  template <typename T>
+  api2::builder::Source<T> Disallow(api2::builder::Source<T> source) {
+    source >> node_.In(index_);
+    return node_.Out(index_++).Cast<T>();
+  }
+
+ private:
+  template <typename T>
+  static api2::builder::GenericNode& AddSourceGate(
+      T disallow, api2::builder::Graph& graph) {
+    auto& gate_node = graph.AddNode("GateCalculator");
+    auto& gate_node_opts =
+        gate_node.GetOptions<mediapipe::GateCalculatorOptions>();
+    // This is expected to be the most common configuration for MediaPipe Tasks
+    // graphs. It is intentionally hard coded to catch and verify any other use
+    // case (which should help to work out a common approach and a recommended
+    // way of blocking streams).
+    gate_node_opts.set_empty_packets_as_allow(true);
+    disallow >> gate_node.In("DISALLOW");
+    return gate_node;
+  }
+
+  template <typename T>
+  static api2::builder::GenericNode& AddSideSourceGate(
+      T disallow, api2::builder::Graph& graph) {
+    auto& gate_node = graph.AddNode("GateCalculator");
+    auto& gate_node_opts =
+        gate_node.GetOptions<mediapipe::GateCalculatorOptions>();
+    gate_node_opts.set_empty_packets_as_allow(true);
+    disallow >> gate_node.SideIn("DISALLOW");
+    return gate_node;
+  }
+
+  api2::builder::GenericNode& node_;
+  int index_ = 0;
+};
+
+// Updates the graph to drop the @value stream packet if the corresponding
+// @condition stream packet holds true.
+template <typename T>
+api2::builder::Source<T> DisallowIf(api2::builder::Source<T> value,
+                                    api2::builder::Source<bool> condition,
+                                    api2::builder::Graph& graph) {
+  return DisallowGate(condition, graph).Disallow(value);
+}
+
+// Updates the graph to drop the @value stream packet if the corresponding
+// @condition side packet holds true.
+template <typename T>
+api2::builder::Source<T> DisallowIf(api2::builder::Source<T> value,
+                                    api2::builder::SideSource<bool> condition,
+                                    api2::builder::Graph& graph) {
+  return DisallowGate(condition, graph).Disallow(value);
+}
+
+// Updates the graph to pass through the @value stream packet if the
+// corresponding @allow stream packet holds true.
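+//
+// Usage sketch (stream names here are illustrative; the pattern mirrors the
+// tests added in gate_test.cc below):
+//
+//   api2::builder::Source<int> value = graph.In("VALUE").Cast<int>();
+//   api2::builder::Source<bool> allow = graph.In("ALLOW").Cast<bool>();
+//   AllowIf(value, allow, graph).SetName("gated_stream");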
+template +api2::builder::Source AllowIf(api2::builder::Source value, + api2::builder::Source allow, + api2::builder::Graph& graph) { + return AllowGate(allow, graph).Allow(value); +} + +// Updates graph to pass through @value stream packet if corresponding +// @allow side stream packet holds true. +template +api2::builder::Source AllowIf(api2::builder::Source value, + api2::builder::SideSource allow, + api2::builder::Graph& graph) { + return AllowGate(allow, graph).Allow(value); +} + +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_ diff --git a/mediapipe/tasks/cc/components/utils/gate_test.cc b/mediapipe/tasks/cc/components/utils/gate_test.cc new file mode 100644 index 000000000..7fdca48e7 --- /dev/null +++ b/mediapipe/tasks/cc/components/utils/gate_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/utils/gate.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_graph.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { +namespace tasks { +namespace components { +namespace utils { +namespace { + +using ::mediapipe::api2::builder::SideSource; +using ::mediapipe::api2::builder::Source; + +TEST(DisallowGate, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source condition = graph.In("CONDITION").Cast(); + Source value1 = graph.In("VALUE_1").Cast(); + Source value2 = graph.In("VALUE_2").Cast(); + Source value3 = graph.In("VALUE_3").Cast(); + + DisallowGate gate(condition, graph); + gate.Disallow(value1).SetName("gated_stream1"); + gate.Disallow(value2).SetName("gated_stream2"); + gate.Disallow(value3).SetName("gated_stream3"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "__stream_2" + input_stream: "__stream_3" + input_stream: "DISALLOW:__stream_0" + output_stream: "gated_stream1" + output_stream: "gated_stream2" + output_stream: "gated_stream3" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE_1:__stream_1" + input_stream: "VALUE_2:__stream_2" + input_stream: "VALUE_3:__stream_3" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DisallowIf, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + Source condition = graph.In("CONDITION").Cast(); + + auto gated_stream = DisallowIf(value, condition, graph); + 
gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "DISALLOW:__stream_0" + output_stream: "gated_stream" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(DisallowIf, VerifyConfigWithSideCondition) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + SideSource condition = graph.SideIn("CONDITION").Cast(); + + auto gated_stream = DisallowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_0" + output_stream: "gated_stream" + input_side_packet: "DISALLOW:__side_packet_1" + options { + [mediapipe.GateCalculatorOptions.ext] { + empty_packets_as_allow: true + } + } + } + input_stream: "VALUE:__stream_0" + input_side_packet: "CONDITION:__side_packet_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowGate, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source condition = graph.In("CONDITION").Cast(); + Source value1 = graph.In("VALUE_1").Cast(); + Source value2 = graph.In("VALUE_2").Cast(); + Source value3 = graph.In("VALUE_3").Cast(); + + AllowGate gate(condition, graph); + gate.Allow(value1).SetName("gated_stream1"); + gate.Allow(value2).SetName("gated_stream2"); + gate.Allow(value3).SetName("gated_stream3"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "__stream_2" + input_stream: "__stream_3" + input_stream: "ALLOW:__stream_0" + output_stream: "gated_stream1" + output_stream: "gated_stream2" + output_stream: "gated_stream3" + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE_1:__stream_1" + input_stream: "VALUE_2:__stream_2" + input_stream: "VALUE_3:__stream_3" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowIf, VerifyConfig) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + Source condition = graph.In("CONDITION").Cast(); + + auto gated_stream = AllowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_1" + input_stream: "ALLOW:__stream_0" + output_stream: "gated_stream" + } + input_stream: "CONDITION:__stream_0" + input_stream: "VALUE:__stream_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +TEST(AllowIf, VerifyConfigWithSideConition) { + mediapipe::api2::builder::Graph graph; + + Source value = graph.In("VALUE").Cast(); + SideSource condition = graph.SideIn("CONDITION").Cast(); + + auto gated_stream = AllowIf(value, condition, graph); + gated_stream.SetName("gated_stream"); + + EXPECT_THAT(graph.GetConfig(), + testing::EqualsProto( + 
mediapipe::ParseTextProtoOrDie(R"pb( + node { + calculator: "GateCalculator" + input_stream: "__stream_0" + output_stream: "gated_stream" + input_side_packet: "ALLOW:__side_packet_1" + } + input_stream: "VALUE:__stream_0" + input_side_packet: "CONDITION:__side_packet_1" + )pb"))); + + CalculatorGraph calcualtor_graph; + MP_EXPECT_OK(calcualtor_graph.Initialize(graph.GetConfig())); +} + +} // namespace +} // namespace utils +} // namespace components +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 38030c525..291dd29fe 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -23,6 +23,7 @@ cc_library( srcs = ["base_options.cc"], hdrs = ["base_options.h"], deps = [ + ":mediapipe_builtin_op_resolver", "//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", @@ -50,6 +51,21 @@ cc_library( ], ) +cc_library( + name = "mediapipe_builtin_op_resolver", + srcs = ["mediapipe_builtin_op_resolver.cc"], + hdrs = ["mediapipe_builtin_op_resolver.h"], + deps = [ + "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", + "//mediapipe/util/tflite/operations:max_pool_argmax", + "//mediapipe/util/tflite/operations:max_unpooling", + "//mediapipe/util/tflite/operations:transform_landmarks", + "//mediapipe/util/tflite/operations:transform_tensor_bilinear", + "//mediapipe/util/tflite/operations:transpose_conv_bias", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) + # TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator # supports TFLite-in-GMSCore. cc_library( @@ -57,6 +73,7 @@ cc_library( srcs = ["model_task_graph.cc"], hdrs = ["model_task_graph.h"], deps = [ + ":model_asset_bundle_resources", ":model_resources", ":model_resources_cache", ":model_resources_calculator", @@ -122,6 +139,7 @@ cc_test_with_tflite( "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", ], deps = [ + ":utils", "//mediapipe/framework/api2:packet", "//mediapipe/framework/port:gtest_main", "//mediapipe/tasks/cc:common", @@ -146,6 +164,7 @@ cc_library_with_tflite( "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", ], deps = [ + ":model_asset_bundle_resources", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/api2:packet", "//mediapipe/tasks/cc:common", @@ -298,3 +317,41 @@ cc_library( "@flatbuffers//:runtime_cc", ], ) + +cc_library( + name = "model_asset_bundle_resources", + srcs = ["model_asset_bundle_resources.cc"], + hdrs = ["model_asset_bundle_resources.h"], + deps = [ + ":external_file_handler", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/util:resource_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "model_asset_bundle_resources_test", + srcs = ["model_asset_bundle_resources_test.cc"], + data = [ + "//mediapipe/tasks/testdata/core:test_models", + ], + deps = [ + ":model_asset_bundle_resources", + ":model_resources", + ":utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + 
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "@org_tensorflow//tensorflow/lite/c:common", + ], +) diff --git a/mediapipe/tasks/cc/core/base_options.cc b/mediapipe/tasks/cc/core/base_options.cc index ec85ea753..0e3bb7401 100644 --- a/mediapipe/tasks/cc/core/base_options.cc +++ b/mediapipe/tasks/cc/core/base_options.cc @@ -52,7 +52,7 @@ proto::BaseOptions ConvertBaseOptionsToProto(BaseOptions* base_options) { } switch (base_options->delegate) { case BaseOptions::Delegate::CPU: - base_options_proto.mutable_acceleration()->mutable_xnnpack(); + base_options_proto.mutable_acceleration()->mutable_tflite(); break; case BaseOptions::Delegate::GPU: base_options_proto.mutable_acceleration()->mutable_gpu(); diff --git a/mediapipe/tasks/cc/core/base_options.h b/mediapipe/tasks/cc/core/base_options.h index 67a03385b..4717ea50e 100644 --- a/mediapipe/tasks/cc/core/base_options.h +++ b/mediapipe/tasks/cc/core/base_options.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" @@ -63,7 +64,7 @@ struct BaseOptions { // A non-default OpResolver to support custom Ops or specify a subset of // built-in Ops. std::unique_ptr op_resolver = - absl::make_unique(); + absl::make_unique(); }; // Converts a BaseOptions to a BaseOptionsProto. diff --git a/mediapipe/tasks/cc/core/external_file_handler.cc b/mediapipe/tasks/cc/core/external_file_handler.cc index 8a219bb80..33dfeca0b 100644 --- a/mediapipe/tasks/cc/core/external_file_handler.cc +++ b/mediapipe/tasks/cc/core/external_file_handler.cc @@ -92,13 +92,26 @@ absl::Status ExternalFileHandler::MapExternalFile() { #else if (!external_file_.file_content().empty()) { return absl::OkStatus(); + } else if (external_file_.has_file_pointer_meta()) { + if (external_file_.file_pointer_meta().pointer() == 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "Need to set the file pointer in external_file.file_pointer_meta."); + } + if (external_file_.file_pointer_meta().length() <= 0) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "The length of the file in external_file.file_pointer_meta should be " + "positive."); + } + return absl::OkStatus(); } if (external_file_.file_name().empty() && !external_file_.has_file_descriptor_meta()) { return CreateStatusWithPayload( StatusCode::kInvalidArgument, - "ExternalFile must specify at least one of 'file_content', 'file_name' " - "or 'file_descriptor_meta'.", + "ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.", MediaPipeTasksStatus::kInvalidArgumentError); } // Obtain file descriptor, offset and size. 
@@ -196,6 +209,11 @@ absl::Status ExternalFileHandler::MapExternalFile() { absl::string_view ExternalFileHandler::GetFileContent() { if (!external_file_.file_content().empty()) { return external_file_.file_content(); + } else if (external_file_.has_file_pointer_meta()) { + void* ptr = + reinterpret_cast(external_file_.file_pointer_meta().pointer()); + return absl::string_view(static_cast(ptr), + external_file_.file_pointer_meta().length()); } else { return absl::string_view(static_cast(buffer_) + buffer_offset_ - buffer_aligned_offset_, diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc similarity index 87% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index cd3b5690f..62898a005 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h" @@ -21,14 +21,11 @@ limitations under the License. #include "mediapipe/util/tflite/operations/transform_landmarks.h" #include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h" #include "mediapipe/util/tflite/operations/transpose_conv_bias.h" -#include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() - : BuiltinOpResolver() { +namespace core { +MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("MaxPoolingWithArgmax2D", mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); AddCustom("MaxUnpooling2D", @@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver() mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), /*version=*/2); } - -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h similarity index 65% rename from mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h rename to mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h index a0538a674..a7c28aa71 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h @@ -13,25 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ +#define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ #include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace tasks { -namespace vision { - -class SelfieSegmentationModelOpResolver +namespace core { +class MediaPipeBuiltinOpResolver : public tflite::ops::builtin::BuiltinOpResolver { public: - SelfieSegmentationModelOpResolver(); - SelfieSegmentationModelOpResolver( - const SelfieSegmentationModelOpResolver& r) = delete; + MediaPipeBuiltinOpResolver(); + MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete; }; -} // namespace vision +} // namespace core } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ +#endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_ diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc new file mode 100644 index 000000000..5867be49b --- /dev/null +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc @@ -0,0 +1,107 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/util/resource_util.h" + +namespace mediapipe { +namespace tasks { +namespace core { + +namespace { +using ::absl::StatusCode; +} // namespace + +ModelAssetBundleResources::ModelAssetBundleResources( + const std::string& tag, + std::unique_ptr model_asset_bundle_file) + : tag_(tag), model_asset_bundle_file_(std::move(model_asset_bundle_file)) {} + +/* static */ +absl::StatusOr> +ModelAssetBundleResources::Create( + const std::string& tag, + std::unique_ptr model_asset_bundle_file) { + if (model_asset_bundle_file == nullptr) { + return CreateStatusWithPayload( + StatusCode::kInvalidArgument, + "The model asset bundle file proto cannot be nullptr.", + MediaPipeTasksStatus::kInvalidArgumentError); + } + auto model_bundle_resources = absl::WrapUnique( + new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file))); + MP_RETURN_IF_ERROR( + model_bundle_resources->ExtractModelFilesFromExternalFileProto()); + return model_bundle_resources; +} + +absl::Status +ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() { + if (model_asset_bundle_file_->has_file_name()) { + // If the model asset bundle file name is a relative path, searches the file + // in a platform-specific location and returns the absolute path on success. + ASSIGN_OR_RETURN( + std::string path_to_resource, + mediapipe::PathToResourceAsFile(model_asset_bundle_file_->file_name())); + model_asset_bundle_file_->set_file_name(path_to_resource); + } + ASSIGN_OR_RETURN(model_asset_bundle_file_handler_, + ExternalFileHandler::CreateFromExternalFile( + model_asset_bundle_file_.get())); + const char* buffer_data = + model_asset_bundle_file_handler_->GetFileContent().data(); + size_t buffer_size = + model_asset_bundle_file_handler_->GetFileContent().size(); + return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, + &model_files_); +} + +absl::StatusOr ModelAssetBundleResources::GetModelFile( + const std::string& filename) const { + auto it = model_files_.find(filename); + if (it == model_files_.end()) { + auto model_files = ListModelFiles(); + std::string all_model_files = + absl::StrJoin(model_files.begin(), model_files.end(), ", "); + + return CreateStatusWithPayload( + StatusCode::kNotFound, + absl::StrFormat("No model file with name: %s. All model files in the " + "model asset bundle are: %s.", + filename, all_model_files), + MediaPipeTasksStatus::kFileNotFoundError); + } + return it->second; +} + +std::vector ModelAssetBundleResources::ListModelFiles() const { + std::vector model_names; + for (const auto& [model_name, _] : model_files_) { + model_names.push_back(model_name); + } + return model_names; +} + +} // namespace core +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h new file mode 100644 index 000000000..61474d3ad --- /dev/null +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h @@ -0,0 +1,92 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
+#define MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/tasks/cc/core/external_file_handler.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace core {
+
+// The mediapipe task model asset bundle resources class.
+// A ModelAssetBundleResources object, created from an external file proto,
+// contains model asset bundle related resources and the method to extract the
+// tflite models or model asset bundles for the mediapipe sub-tasks. As the
+// resources are owned by the ModelAssetBundleResources object,
+// callers must keep ModelAssetBundleResources alive while using any of the
+// resources.
+class ModelAssetBundleResources {
+ public:
+  // Takes ownership of the provided ExternalFile proto and creates
+  // ModelAssetBundleResources from the proto. A non-empty tag
+  // must be set if the ModelAssetBundleResources will be used through
+  // ModelResourcesCacheService.
+  static absl::StatusOr<std::unique_ptr<ModelAssetBundleResources>> Create(
+      const std::string& tag,
+      std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
+
+  // ModelAssetBundleResources is neither copyable nor movable.
+  ModelAssetBundleResources(const ModelAssetBundleResources&) = delete;
+  ModelAssetBundleResources& operator=(const ModelAssetBundleResources&) =
+      delete;
+
+  // Returns the model asset bundle resources tag.
+  std::string GetTag() const { return tag_; }
+
+  // Gets the contents of the model file (either tflite model file or model
+  // bundle file) with the provided name. An error is returned if there is no
+  // such model file.
+  absl::StatusOr<absl::string_view> GetModelFile(
+      const std::string& filename) const;
+
+  // Lists all the model file names in the model asset bundle.
+  std::vector<std::string> ListModelFiles() const;
+
+ private:
+  // Constructor.
+  ModelAssetBundleResources(
+      const std::string& tag,
+      std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
+
+  // Extracts the model files (either tflite model file or model bundle file)
+  // from the external file proto.
+  absl::Status ExtractModelFilesFromExternalFileProto();
+
+  // The model asset bundle resources tag.
+  const std::string tag_;
+
+  // The model asset bundle file.
+  std::unique_ptr<proto::ExternalFile> model_asset_bundle_file_;
+
+  // The ExternalFileHandler for the model asset bundle.
+  std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
+
+  // The model files bundled in the model asset bundle, as a map with the
+  // filename (corresponding to a basename, e.g. "hand_detector.tflite") as key
+  // and a pointer to the file contents as value. Each model file can be either
+  // a TFLite model file or a model bundle file for a sub-task.
+ absl::flat_hash_map model_files_; +}; + +} // namespace core +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_CORE_MODEL_ASSET_BUNDLE_RESOURCES_H_ diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc new file mode 100644 index 000000000..bcf88713c --- /dev/null +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc @@ -0,0 +1,229 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" + +namespace mediapipe { +namespace tasks { +namespace core { +namespace { + +constexpr char kTestModelResourcesTag[] = "test_model_asset_resources"; + +constexpr char kTestModelBundleResourcesTag[] = + "test_model_asset_bundle_resources"; + +// Models files in dummy_gesture_recognizer.task: +// gesture_recognizer.task +// dummy_gesture_recognizer.tflite +// dummy_hand_landmarker.task +// dummy_hand_detector.tflite +// dummy_hand_landmarker.tflite +constexpr char kTestModelBundlePath[] = + "mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task"; + +constexpr char kInvalidTestModelBundlePath[] = + "mediapipe/tasks/testdata/core/i_do_not_exist.task"; + +} // namespace + +TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) { + auto model_file = std::make_unique(); + model_file->set_file_content(LoadBinaryContent(kTestModelBundlePath)); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") + .status()); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + .status()); +} + +TEST(ModelAssetBundleResourcesTest, CreateFromFile) { + auto model_file = std::make_unique(); + model_file->set_file_name(kTestModelBundlePath); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") + .status()); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + .status()); +} + +TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) { + const int 
model_file_descriptor = open(kTestModelBundlePath, O_RDONLY); + auto model_file = std::make_unique(); + model_file->mutable_file_descriptor_meta()->set_fd(model_file_descriptor); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") + .status()); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + .status()); +} + +TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) { + auto file_content = LoadBinaryContent(kTestModelBundlePath); + auto model_file = std::make_unique(); + metadata::SetExternalFile(file_content, model_file.get()); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") + .status()); + MP_EXPECT_OK( + model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + .status()); +} + +TEST(ModelAssetBundleResourcesTest, CreateFromInvalidFile) { + auto model_file = std::make_unique(); + model_file->set_file_name(kInvalidTestModelBundlePath); + auto status_or_model_bundle_resources = ModelAssetBundleResources::Create( + kTestModelBundleResourcesTag, std::move(model_file)); + + EXPECT_EQ(status_or_model_bundle_resources.status().code(), + absl::StatusCode::kNotFound); + EXPECT_THAT(status_or_model_bundle_resources.status().message(), + testing::HasSubstr("Unable to open file")); + EXPECT_THAT(status_or_model_bundle_resources.status().GetPayload( + kMediaPipeTasksPayload), + testing::Optional(absl::Cord( + absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError)))); +} + +TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) { + // Creates top-level model asset bundle resources. + auto model_file = std::make_unique(); + model_file->set_file_name(kTestModelBundlePath); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + auto status_or_model_bundle_file = + model_bundle_resources->GetModelFile("dummy_hand_landmarker.task"); + MP_EXPECT_OK(status_or_model_bundle_file.status()); + + // Creates sub-task model asset bundle resources. + auto hand_landmaker_model_file = std::make_unique(); + metadata::SetExternalFile(status_or_model_bundle_file.value(), + hand_landmaker_model_file.get()); + MP_ASSERT_OK_AND_ASSIGN( + auto hand_landmaker_model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(hand_landmaker_model_file))); + MP_EXPECT_OK(hand_landmaker_model_bundle_resources + ->GetModelFile("dummy_hand_detector.tflite") + .status()); + MP_EXPECT_OK(hand_landmaker_model_bundle_resources + ->GetModelFile("dummy_hand_landmarker.tflite") + .status()); +} + +TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) { + // Creates top-level model asset bundle resources. + auto model_file = std::make_unique(); + model_file->set_file_name(kTestModelBundlePath); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + auto status_or_model_bundle_file = + model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite"); + MP_EXPECT_OK(status_or_model_bundle_file.status()); + + // Verify tflite model works. 
+ auto hand_detector_model_file = std::make_unique(); + metadata::SetExternalFile(status_or_model_bundle_file.value(), + hand_detector_model_file.get()); + MP_ASSERT_OK_AND_ASSIGN( + auto hand_detector_model_resources, + ModelResources::Create(kTestModelResourcesTag, + std::move(hand_detector_model_file))); + Packet model_packet = hand_detector_model_resources->GetModelPacket(); + ASSERT_FALSE(model_packet.IsEmpty()); + MP_ASSERT_OK(model_packet.ValidateAsType()); + EXPECT_TRUE(model_packet.Get()->initialized()); +} + +TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) { + // Creates top-level model asset bundle resources. + auto model_file = std::make_unique(); + model_file->set_file_name(kTestModelBundlePath); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + auto status = model_bundle_resources->GetModelFile("not_found.task").status(); + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); + EXPECT_THAT(status.message(), + testing::HasSubstr( + "No model file with name: not_found.task. All model files in " + "the model asset bundle are: ")); + EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), + testing::Optional(absl::Cord( + absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError)))); +} + +TEST(ModelAssetBundleResourcesTest, ListModelFiles) { + // Creates top-level model asset bundle resources. + auto model_file = std::make_unique(); + model_file->set_file_name(kTestModelBundlePath); + MP_ASSERT_OK_AND_ASSIGN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, + std::move(model_file))); + auto model_files = model_bundle_resources->ListModelFiles(); + std::vector expected_model_files = { + "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"}; + std::sort(model_files.begin(), model_files.end()); + EXPECT_THAT(expected_model_files, testing::ElementsAreArray(model_files)); +} + +} // namespace core +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/core/model_resources_cache.cc b/mediapipe/tasks/cc/core/model_resources_cache.cc index 216962bcf..affcb6dea 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.cc +++ b/mediapipe/tasks/cc/core/model_resources_cache.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/strings/substitute.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -39,12 +40,16 @@ ModelResourcesCache::ModelResourcesCache( graph_op_resolver_packet_ = api2::PacketAdopting(std::move(graph_op_resolver)); } -}; +} bool ModelResourcesCache::Exists(const std::string& tag) const { return model_resources_collection_.contains(tag); } +bool ModelResourcesCache::ModelAssetBundleExists(const std::string& tag) const { + return model_asset_bundle_resources_collection_.contains(tag); +} + absl::Status ModelResourcesCache::AddModelResources( std::unique_ptr model_resources) { if (model_resources == nullptr) { @@ -94,6 +99,62 @@ absl::StatusOr ModelResourcesCache::GetModelResources( return model_resources_collection_.at(tag).get(); } +absl::Status ModelResourcesCache::AddModelAssetBundleResources( + std::unique_ptr model_asset_bundle_resources) { + if (model_asset_bundle_resources == nullptr) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ModelAssetBundleResources object is null.", + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + const std::string& tag = model_asset_bundle_resources->GetTag(); + if (tag.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ModelAssetBundleResources must have a non-empty tag.", + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + if (ModelAssetBundleExists(tag)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute( + "ModelAssetBundleResources with tag \"$0\" already exists.", tag), + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + model_asset_bundle_resources_collection_.emplace( + tag, std::move(model_asset_bundle_resources)); + return absl::OkStatus(); +} + +absl::Status ModelResourcesCache::AddModelAssetBundleResourcesCollection( + std::vector>& + model_asset_bundle_resources_collection) { + for (auto& model_bundle_resources : model_asset_bundle_resources_collection) { + MP_RETURN_IF_ERROR( + AddModelAssetBundleResources(std::move(model_bundle_resources))); + } + return absl::OkStatus(); +} + +absl::StatusOr +ModelResourcesCache::GetModelAssetBundleResources( + const std::string& tag) const { + if (tag.empty()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "ModelAssetBundleResources must be retrieved with a non-empty tag.", + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + if (!ModelAssetBundleExists(tag)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::Substitute( + "ModelAssetBundleResources with tag \"$0\" does not exist.", tag), + MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError); + } + return model_asset_bundle_resources_collection_.at(tag).get(); +} + absl::StatusOr> ModelResourcesCache::GetGraphOpResolverPacket() const { if (graph_op_resolver_packet_.IsEmpty()) { diff --git a/mediapipe/tasks/cc/core/model_resources_cache.h b/mediapipe/tasks/cc/core/model_resources_cache.h index 044ef36b7..32909f93d 100644 --- a/mediapipe/tasks/cc/core/model_resources_cache.h +++ b/mediapipe/tasks/cc/core/model_resources_cache.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -46,6 +47,10 @@ class ModelResourcesCache { // Returns whether the tag exists in the model resources cache. bool Exists(const std::string& tag) const; + // Returns whether the tag of the model asset bundle exists in the model + // resources cache. + bool ModelAssetBundleExists(const std::string& tag) const; + // Adds a ModelResources object into the cache. // The tag of the ModelResources must be unique; the ownership of the // ModelResources will be transferred into the cache. @@ -62,6 +67,23 @@ class ModelResourcesCache { absl::StatusOr GetModelResources( const std::string& tag) const; + // Adds a ModelAssetBundleResources object into the cache. + // The tag of the ModelAssetBundleResources must be unique; the ownership of + // the ModelAssetBundleResources will be transferred into the cache. + absl::Status AddModelAssetBundleResources( + std::unique_ptr model_asset_bundle_resources); + + // Adds a collection of the ModelAssetBundleResources objects into the cache. + // The tag of the each ModelAssetBundleResources must be unique; the ownership + // of every ModelAssetBundleResources will be transferred into the cache. + absl::Status AddModelAssetBundleResourcesCollection( + std::vector>& + model_asset_bundle_resources_collection); + + // Retrieves a const ModelAssetBundleResources pointer by the unique tag. + absl::StatusOr GetModelAssetBundleResources( + const std::string& tag) const; + // Retrieves the graph op resolver packet. absl::StatusOr> GetGraphOpResolverPacket() const; @@ -73,6 +95,11 @@ class ModelResourcesCache { // A collection of ModelResources objects for the models in the graph. absl::flat_hash_map> model_resources_collection_; + + // A collection of ModelAssetBundleResources objects for the model bundles in + // the graph. + absl::flat_hash_map> + model_asset_bundle_resources_collection_; }; // Global service for mediapipe task model resources cache. diff --git a/mediapipe/tasks/cc/core/model_resources_test.cc b/mediapipe/tasks/cc/core/model_resources_test.cc index 0b13c6daa..de480c5a4 100644 --- a/mediapipe/tasks/cc/core/model_resources_test.cc +++ b/mediapipe/tasks/cc/core/model_resources_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -88,16 +89,6 @@ constexpr char kCorruptedModelPath[] = "mediapipe/tasks/testdata/core/" "corrupted_mobilenet_v1_0.25_224_1_default_1.tflite"; -std::string LoadBinaryContent(const char* filename) { - std::ifstream input_file(filename, std::ios::binary | std::ios::ate); - // Find buffer size from input file, and load the buffer. 
- size_t buffer_size = input_file.tellg(); - std::string buffer(buffer_size, '\0'); - input_file.seekg(0, std::ios::beg); - input_file.read(const_cast(buffer.c_str()), buffer_size); - return buffer; -} - void AssertStatusHasMediaPipeTasksStatusCode( absl::Status status, MediaPipeTasksStatus mediapipe_tasks_code) { EXPECT_THAT( diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index c6bc8f69b..66434483b 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" @@ -70,6 +71,17 @@ std::string CreateModelResourcesTag(const CalculatorGraphConfig::Node& node) { node_type); } +std::string CreateModelAssetBundleResourcesTag( + const CalculatorGraphConfig::Node& node) { + std::vector names = absl::StrSplit(node.name(), "__"); + std::string node_type = node.calculator(); + std::replace(node_type.begin(), node_type.end(), '.', '_'); + absl::AsciiStrToLower(&node_type); + return absl::StrFormat("%s_%s_model_asset_bundle_resources", + names.back().empty() ? "unnamed" : names.back(), + node_type); +} + } // namespace // Defines the mediapipe task inference unit as a MediaPipe subgraph that @@ -122,6 +134,9 @@ class InferenceSubgraph : public Subgraph { case Acceleration::kGpu: delegate.mutable_gpu()->CopyFrom(acceleration.gpu()); break; + case Acceleration::kTflite: + delegate.mutable_tflite()->CopyFrom(acceleration.tflite()); + break; case Acceleration::DELEGATE_NOT_SET: // Deafult inference calculator setting. break; @@ -141,21 +156,24 @@ absl::StatusOr ModelTaskGraph::GetConfig( } absl::StatusOr ModelTaskGraph::CreateModelResources( - SubgraphContext* sc, std::unique_ptr external_file) { + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix) { auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); if (!model_resources_cache_service.IsAvailable()) { - ASSIGN_OR_RETURN(local_model_resources_, + ASSIGN_OR_RETURN(auto local_model_resource, ModelResources::Create("", std::move(external_file))); LOG(WARNING) << "A local ModelResources object is created. 
Please consider using " "ModelResourcesCacheService to cache the created ModelResources " "object in the CalculatorGraph."; - return local_model_resources_.get(); + local_model_resources_.push_back(std::move(local_model_resource)); + return local_model_resources_.back().get(); } ASSIGN_OR_RETURN( auto op_resolver_packet, model_resources_cache_service.GetObject().GetGraphOpResolverPacket()); - const std::string tag = CreateModelResourcesTag(sc->OriginalNode()); + const std::string tag = + absl::StrCat(CreateModelResourcesTag(sc->OriginalNode()), tag_suffix); ASSIGN_OR_RETURN(auto model_resources, ModelResources::Create(tag, std::move(external_file), op_resolver_packet)); @@ -165,6 +183,41 @@ absl::StatusOr ModelTaskGraph::CreateModelResources( return model_resources_cache_service.GetObject().GetModelResources(tag); } +absl::StatusOr +ModelTaskGraph::CreateModelAssetBundleResources( + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix) { + auto model_resources_cache_service = sc->Service(kModelResourcesCacheService); + bool has_file_pointer_meta = external_file->has_file_pointer_meta(); + // if external file is set by file pointer, no need to add the model asset + // bundle resources into the model resources service since the memory is + // not owned by this model asset bundle resources. + if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) { + ASSIGN_OR_RETURN( + auto local_model_asset_bundle_resource, + ModelAssetBundleResources::Create("", std::move(external_file))); + if (!has_file_pointer_meta) { + LOG(WARNING) + << "A local ModelResources object is created. Please consider using " + "ModelResourcesCacheService to cache the created ModelResources " + "object in the CalculatorGraph."; + } + local_model_asset_bundle_resources_.push_back( + std::move(local_model_asset_bundle_resource)); + return local_model_asset_bundle_resources_.back().get(); + } + const std::string tag = absl::StrCat( + CreateModelAssetBundleResourcesTag(sc->OriginalNode()), tag_suffix); + ASSIGN_OR_RETURN( + auto model_bundle_resources, + ModelAssetBundleResources::Create(tag, std::move(external_file))); + MP_RETURN_IF_ERROR( + model_resources_cache_service.GetObject().AddModelAssetBundleResources( + std::move(model_bundle_resources))); + return model_resources_cache_service.GetObject().GetModelAssetBundleResources( + tag); +} + GenericNode& ModelTaskGraph::AddInference( const ModelResources& model_resources, const proto::Acceleration& acceleration, Graph& graph) const { @@ -177,9 +230,9 @@ GenericNode& ModelTaskGraph::AddInference( ->CopyFrom(acceleration); // When the model resources tag is available, the ModelResourcesCalculator // will retrieve the cached model resources from the graph service by tag. - // Otherwise, provides the exteranal file and asks the + // Otherwise, provides the external file and asks the // ModelResourcesCalculator to create a local model resources in its - // Calcualtor::Open(). + // Calculator::Open(). if (!model_resources.GetTag().empty()) { inference_subgraph_opts.set_model_resources_tag(model_resources.GetTag()); } else { diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 36016cb89..50dcc903b 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include #include +#include #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -27,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/acceleration.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" @@ -74,9 +76,48 @@ class ModelTaskGraph : public Subgraph { // construction stage. Note that the external file contents will be moved // into the model resources object on creation. The returned model resources // pointer will provide graph authors with the access to the metadata - // extractor and the tflite model. + // extractor and the tflite model. When the model resources graph service is + // available, a tag is generated internally asscoiated with the created model + // resource. If more than one model resources are created in a graph, the + // model resources graph service add the tag_suffix to support multiple + // resources. absl::StatusOr CreateModelResources( - SubgraphContext* sc, std::unique_ptr external_file); + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix = ""); + + // If the model resources graph service is available, creates a model asset + // bundle resources object from the subgraph context, and caches the created + // model asset bundle resources into the model resources graph service on + // success. Otherwise, creates a local model asset bundle resources object + // that can only be used in the graph construction stage. The returned model + // resources pointer will provide graph authors with the access to extracted + // model files. + template + absl::StatusOr + CreateModelAssetBundleResources(SubgraphContext* sc) { + auto external_file = std::make_unique(); + external_file->Swap(sc->MutableOptions() + ->mutable_base_options() + ->mutable_model_asset()); + return CreateModelAssetBundleResources(sc, std::move(external_file)); + } + + // If the model resources graph service is available, creates a model asset + // bundle resources object from the subgraph context, and caches the created + // model asset bundle resources into the model resources graph service on + // success. Otherwise, creates a local model asset bundle resources object + // that can only be used in the graph construction stage. Note that the + // external file contents will be moved into the model asset bundle resources + // object on creation. The returned model asset bundle resources pointer will + // provide graph authors with the access to extracted model files. When the + // model resources graph service is available, a tag is generated internally + // asscoiated with the created model asset bundle resource. If more than one + // model asset bundle resources are created in a graph, the model resources + // graph service add the tag_suffix to support multiple resources. + absl::StatusOr + CreateModelAssetBundleResources( + SubgraphContext* sc, std::unique_ptr external_file, + const std::string tag_suffix = ""); // Inserts a mediapipe task inference subgraph into the provided // GraphBuilder. 
The returned node provides the following interfaces to the @@ -94,7 +135,10 @@ class ModelTaskGraph : public Subgraph { api2::builder::Graph& graph) const; private: - std::unique_ptr local_model_resources_; + std::vector> local_model_resources_; + + std::vector> + local_model_asset_bundle_resources_; }; } // namespace core diff --git a/mediapipe/tasks/cc/core/proto/acceleration.proto b/mediapipe/tasks/cc/core/proto/acceleration.proto index a0575a5d5..bdfaff4d2 100644 --- a/mediapipe/tasks/cc/core/proto/acceleration.proto +++ b/mediapipe/tasks/cc/core/proto/acceleration.proto @@ -32,5 +32,6 @@ message Acceleration { oneof delegate { mediapipe.InferenceCalculatorOptions.Delegate.Xnnpack xnnpack = 1; mediapipe.InferenceCalculatorOptions.Delegate.Gpu gpu = 2; + mediapipe.InferenceCalculatorOptions.Delegate.TfLite tflite = 4; } } diff --git a/mediapipe/tasks/cc/core/proto/external_file.proto b/mediapipe/tasks/cc/core/proto/external_file.proto index af4a11697..3147a2224 100644 --- a/mediapipe/tasks/cc/core/proto/external_file.proto +++ b/mediapipe/tasks/cc/core/proto/external_file.proto @@ -26,10 +26,11 @@ option java_outer_classname = "ExternalFileProto"; // (1) file contents loaded in `file_content`. // (2) file path in `file_name`. // (3) file descriptor through `file_descriptor_meta` as returned by open(2). +// (4) file pointer and length in memory through `file_pointer_meta`. // // If more than one field of these fields is provided, they are used in this // precedence order. -// Next id: 4 +// Next id: 5 message ExternalFile { // The file contents as a byte array. optional bytes file_content = 1; @@ -40,6 +41,13 @@ message ExternalFile { // The file descriptor to a file opened with open(2), with optional additional // offset and length information. optional FileDescriptorMeta file_descriptor_meta = 3; + + // The pointer that points to the location of a file in memory. Use the util + // method, `SetExternalFile` in [1], to configure `file_pointer_meta` from a + // `std::string_view` object. + // + // [1]: mediapipe/tasks/cc/metadata/utils/zip_utils.h + optional FilePointerMeta file_pointer_meta = 4; } // A proto defining file descriptor metadata for mapping file into memory using @@ -62,3 +70,14 @@ message FileDescriptorMeta { // offset of a given asset obtained from AssetFileDescriptor#getStartOffset(). optional int64 offset = 3; } + +// A pointer to the location of a file in memory. Make sure the file memory +// that it points to is located on the same machine and that it outlives this +// FilePointerMeta object. +message FilePointerMeta { + // Memory address of the file in decimal. + optional uint64 pointer = 1; + + // File length.
+ optional int64 length = 2; +} diff --git a/mediapipe/tasks/cc/metadata/BUILD b/mediapipe/tasks/cc/metadata/BUILD index c3555e4a0..ef32dd184 100644 --- a/mediapipe/tasks/cc/metadata/BUILD +++ b/mediapipe/tasks/cc/metadata/BUILD @@ -19,8 +19,9 @@ cc_library( deps = [ "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/metadata/utils:zip_readonly_mem_file", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -29,7 +30,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@flatbuffers//:runtime_cc", "@org_tensorflow//tensorflow/lite/schema:schema_fbs", - "@zlib//:zlib_minizip", ], ) diff --git a/mediapipe/tasks/cc/metadata/metadata_extractor.cc b/mediapipe/tasks/cc/metadata/metadata_extractor.cc index fcec49083..4bc3e8ba0 100644 --- a/mediapipe/tasks/cc/metadata/metadata_extractor.cc +++ b/mediapipe/tasks/cc/metadata/metadata_extractor.cc @@ -17,16 +17,16 @@ limitations under the License. #include +#include "absl/cleanup/cleanup.h" #include "absl/memory/memory.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" -#include "contrib/minizip/ioapi.h" -#include "contrib/minizip/unzip.h" #include "flatbuffers/flatbuffers.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -53,72 +53,6 @@ const T* GetItemFromVector( } return src_vector->Get(index); } - -// Wrapper function around calls to unzip to avoid repeating conversion logic -// from error code to Status. -absl::Status UnzipErrorToStatus(int error) { - if (error != UNZ_OK) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to read associated file in zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - return absl::OkStatus(); -} - -// Stores a file name, position in zip buffer and size. -struct ZipFileInfo { - std::string name; - ZPOS64_T position; - ZPOS64_T size; -}; - -// Returns the ZipFileInfo corresponding to the current file in the provided -// unzFile object. -absl::StatusOr GetCurrentZipFileInfo(const unzFile& zf) { - // Open file in raw mode, as data is expected to be uncompressed. - int method; - MP_RETURN_IF_ERROR(UnzipErrorToStatus( - unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1))); - if (method != Z_NO_COMPRESSION) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Expected uncompressed zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - - // Get file info a first time to get filename size. - unz_file_info64 file_info; - MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( - zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0, - /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, - /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); - - // Second call to get file name. 
- auto file_name_size = file_info.size_filename; - char* c_file_name = (char*)malloc(file_name_size); - MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( - zf, &file_info, c_file_name, file_name_size, - /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, - /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); - std::string file_name = std::string(c_file_name, file_name_size); - free(c_file_name); - - // Get position in file. - auto position = unzGetCurrentFileZStreamPos64(zf); - if (position == 0) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to read file in zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - - // Close file and return. - MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf))); - - ZipFileInfo result{}; - result.name = file_name; - result.position = position; - result.size = file_info.uncompressed_size; - return result; -} } // namespace /* static */ @@ -238,47 +172,15 @@ absl::Status ModelMetadataExtractor::InitFromModelBuffer( absl::Status ModelMetadataExtractor::ExtractAssociatedFiles( const char* buffer_data, size_t buffer_size) { - // Create in-memory read-only zip file. - ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size); - // Open zip. - unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def()); - if (zf == nullptr) { + auto status = + ExtractFilesfromZipFile(buffer_data, buffer_size, &associated_files_); + if (!status.ok() && + absl::StrContains(status.message(), "Unable to open zip archive.")) { // It's OK if it fails: this means there are no associated files with this // model. return absl::OkStatus(); } - // Get number of files. - unz_global_info global_info; - if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to get zip archive info.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - - // Browse through files in archive. - if (global_info.number_entry > 0) { - int error = unzGoToFirstFile(zf); - while (error == UNZ_OK) { - ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf)); - // Store result in map. - associated_files_[zip_file_info.name] = absl::string_view( - buffer_data + zip_file_info.position, zip_file_info.size); - error = unzGoToNextFile(zf); - } - if (error != UNZ_END_OF_LIST_OF_FILE) { - return CreateStatusWithPayload( - StatusCode::kUnknown, - "Unable to read associated file in zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - } - // Close zip. 
- if (unzClose(zf) != UNZ_OK) { - return CreateStatusWithPayload( - StatusCode::kUnknown, "Unable to close zip archive.", - MediaPipeTasksStatus::kMetadataAssociatedFileZipError); - } - return absl::OkStatus(); + return status; } absl::StatusOr ModelMetadataExtractor::GetAssociatedFile( diff --git a/mediapipe/tasks/cc/metadata/utils/BUILD b/mediapipe/tasks/cc/metadata/utils/BUILD index b595eb10f..881b88962 100644 --- a/mediapipe/tasks/cc/metadata/utils/BUILD +++ b/mediapipe/tasks/cc/metadata/utils/BUILD @@ -24,3 +24,20 @@ cc_library( "@zlib//:zlib_minizip", ], ) + +cc_library( + name = "zip_utils", + srcs = ["zip_utils.cc"], + hdrs = ["zip_utils.h"], + deps = [ + ":zip_readonly_mem_file", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@zlib//:zlib_minizip", + ], +) diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.cc b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc new file mode 100644 index 000000000..2c09e1961 --- /dev/null +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.cc @@ -0,0 +1,181 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" + +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "contrib/minizip/ioapi.h" +#include "contrib/minizip/unzip.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_readonly_mem_file.h" + +namespace mediapipe { +namespace tasks { +namespace metadata { + +namespace { + +using ::absl::StatusCode; + +// Wrapper function around calls to unzip to avoid repeating conversion logic +// from error code to Status. +absl::Status UnzipErrorToStatus(int error) { + if (error != UNZ_OK) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to read the file in zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + return absl::OkStatus(); +} + +// Stores a file name, position in zip buffer and size. +struct ZipFileInfo { + std::string name; + ZPOS64_T position; + ZPOS64_T size; +}; + +// Returns the ZipFileInfo corresponding to the current file in the provided +// unzFile object. +absl::StatusOr GetCurrentZipFileInfo(const unzFile& zf) { + // Open file in raw mode, as data is expected to be uncompressed. 
+ int method; + MP_RETURN_IF_ERROR(UnzipErrorToStatus( + unzOpenCurrentFile2(zf, &method, /*level=*/nullptr, /*raw=*/1))); + absl::Cleanup unzipper_closer = [zf]() { + auto status = UnzipErrorToStatus(unzCloseCurrentFile(zf)); + if (!status.ok()) { + LOG(ERROR) << "Failed to close the current zip file: " << status; + } + }; + if (method != Z_NO_COMPRESSION) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Expected uncompressed zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + + // Get file info a first time to get filename size. + unz_file_info64 file_info; + MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( + zf, &file_info, /*szFileName=*/nullptr, /*szFileNameBufferSize=*/0, + /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, + /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); + + // Second call to get file name. + auto file_name_size = file_info.size_filename; + char* c_file_name = (char*)malloc(file_name_size); + MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzGetCurrentFileInfo64( + zf, &file_info, c_file_name, file_name_size, + /*extraField=*/nullptr, /*extraFieldBufferSize=*/0, + /*szComment=*/nullptr, /*szCommentBufferSize=*/0))); + std::string file_name = std::string(c_file_name, file_name_size); + free(c_file_name); + + // Get position in file. + auto position = unzGetCurrentFileZStreamPos64(zf); + if (position == 0) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to read file in zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + + // Perform the cleanup manually for error propagation. + std::move(unzipper_closer).Cancel(); + // Close file and return. + MP_RETURN_IF_ERROR(UnzipErrorToStatus(unzCloseCurrentFile(zf))); + + ZipFileInfo result{}; + result.name = file_name; + result.position = position; + result.size = file_info.uncompressed_size; + return result; +} + +} // namespace + +absl::Status ExtractFilesfromZipFile( + const char* buffer_data, const size_t buffer_size, + absl::flat_hash_map* files) { + // Create in-memory read-only zip file. + ZipReadOnlyMemFile mem_file = ZipReadOnlyMemFile(buffer_data, buffer_size); + // Open zip. + unzFile zf = unzOpen2_64(/*path=*/nullptr, &mem_file.GetFileFunc64Def()); + if (zf == nullptr) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to open zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + absl::Cleanup unzipper_closer = [zf]() { + if (unzClose(zf) != UNZ_OK) { + LOG(ERROR) << "Unable to close zip archive."; + } + }; + // Get number of files. + unz_global_info global_info; + if (unzGetGlobalInfo(zf, &global_info) != UNZ_OK) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to get zip archive info.", + MediaPipeTasksStatus::kFileZipError); + } + + // Browse through files in archive. + if (global_info.number_entry > 0) { + int error = unzGoToFirstFile(zf); + while (error == UNZ_OK) { + ASSIGN_OR_RETURN(auto zip_file_info, GetCurrentZipFileInfo(zf)); + // Store result in map. + (*files)[zip_file_info.name] = absl::string_view( + buffer_data + zip_file_info.position, zip_file_info.size); + error = unzGoToNextFile(zf); + } + if (error != UNZ_END_OF_LIST_OF_FILE) { + return CreateStatusWithPayload( + StatusCode::kUnknown, + "Unable to read associated file in zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + } + // Perform the cleanup manually for error propagation. + std::move(unzipper_closer).Cancel(); + // Close zip. 
+ if (unzClose(zf) != UNZ_OK) { + return CreateStatusWithPayload(StatusCode::kUnknown, + "Unable to close zip archive.", + MediaPipeTasksStatus::kFileZipError); + } + return absl::OkStatus(); +} + +void SetExternalFile(const absl::string_view& file_content, + core::proto::ExternalFile* model_file, bool is_copy) { + if (is_copy) { + std::string str_content{file_content}; + model_file->set_file_content(str_content); + } else { + auto pointer = reinterpret_cast(file_content.data()); + model_file->mutable_file_pointer_meta()->set_pointer(pointer); + model_file->mutable_file_pointer_meta()->set_length(file_content.length()); + } +} + +} // namespace metadata +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/metadata/utils/zip_utils.h b/mediapipe/tasks/cc/metadata/utils/zip_utils.h new file mode 100644 index 000000000..10ad0a5a9 --- /dev/null +++ b/mediapipe/tasks/cc/metadata/utils/zip_utils.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" + +namespace mediapipe { +namespace tasks { +namespace metadata { + +// Extracts files from the zip file. +// Input: Pointer and length of the zip file in memory. +// Outputs: A map with the filename as key and a pointer to the file contents +// as value. The file contents returned by this function are only guaranteed to +// stay valid while buffer_data is alive. +absl::Status ExtractFilesfromZipFile( + const char* buffer_data, const size_t buffer_size, + absl::flat_hash_map* files); + +// Sets the ExternalFile object from file_content in memory. By default, +// `is_copy=false`, which sets `file_pointer_meta` in ExternalFile to point to +// the location of the file in memory. Otherwise, if `is_copy=true`, the memory +// is copied into `file_content` in ExternalFile. +void SetExternalFile(const absl::string_view& file_content, + core::proto::ExternalFile* model_file, + bool is_copy = false); + +} // namespace metadata +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_METADATA_UTILS_ZIP_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD new file mode 100644 index 000000000..336b1bb45 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -0,0 +1,107 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "text_classifier_graph", + srcs = ["text_classifier_graph.cc"], + deps = [ + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/tasks/cc/components:text_preprocessing_graph", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_calculator", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "text_classifier", + srcs = ["text_classifier.cc"], + hdrs = ["text_classifier.h"], + deps = [ + ":text_classifier_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_api_factory", + "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + ], +) + +cc_test( + name = "text_classifier_test", + srcs = ["text_classifier_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:bert_text_classifier_models", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + deps = [ + ":text_classifier", + ":text_classifier_test_utils", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + ], +) + +cc_library( + name = "text_classifier_test_utils", + srcs = ["text_classifier_test_utils.cc"], + hdrs = ["text_classifier_test_utils.h"], + visibility = ["//visibility:private"], + deps = [ + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", 
+ "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/lite:mutable_op_resolver", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], +) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD similarity index 58% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD rename to mediapipe/tasks/cc/text/text_classifier/proto/BUILD index f3927727e..f2b544d87 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/proto/BUILD @@ -14,30 +14,17 @@ load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") -package(default_visibility = [ - "//mediapipe/tasks:internal", -]) +package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) mediapipe_proto_library( - name = "hand_gesture_recognizer_subgraph_options_proto", - srcs = ["hand_gesture_recognizer_subgraph_options.proto"], + name = "text_classifier_graph_options_proto", + srcs = ["text_classifier_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", - "//mediapipe/tasks/cc/core/proto:base_options_proto", - ], -) - -mediapipe_proto_library( - name = "landmarks_to_matrix_calculator_proto", - srcs = ["landmarks_to_matrix_calculator.proto"], - deps = [ - "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto similarity index 61% rename from mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto rename to mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto index b3d82eda4..8f4d7eea6 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_options.proto +++ b/mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.proto @@ -15,26 +15,24 @@ limitations under the License. 
syntax = "proto2"; -package mediapipe.tasks.vision.hand_landmarker.proto; +package mediapipe.tasks.text.text_classifier.proto; import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto"; -import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto"; -message HandLandmarkerOptions { +option java_package = "com.google.mediapipe.tasks.text.textclassifier.proto"; +option java_outer_classname = "TextClassifierGraphOptionsProto"; + +message TextClassifierGraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkerOptions ext = 462713202; + optional TextClassifierGraphOptions ext = 462704549; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; - - optional hand_detector.proto.HandDetectorOptions hand_detector_options = 3; - - optional HandLandmarkerSubgraphOptions hand_landmarker_subgraph_options = 4; + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + optional components.processors.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc new file mode 100644 index 000000000..699f15bc0 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc @@ -0,0 +1,104 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/task_api_factory.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" +#include "tensorflow/lite/core/api/op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +namespace { + +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; + +constexpr char kTextStreamName[] = "text_in"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kClassificationResultStreamName[] = "classification_result_out"; +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// type "TextClassifierGraph". +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kSubgraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); + subgraph.Out(kClassificationResultTag) + .SetName(kClassificationResultStreamName) >> + graph.Out(kClassificationResultTag); + return graph.GetConfig(); +} + +// Converts the user-facing TextClassifierOptions struct to the internal +// TextClassifierGraphOptions proto. +std::unique_ptr +ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + auto classifier_options_proto = + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( + &(options->classifier_options))); + options_proto->mutable_classifier_options()->Swap( + classifier_options_proto.get()); + return options_proto; +} + +} // namespace + +absl::StatusOr> TextClassifier::Create( + std::unique_ptr options) { + auto options_proto = ConvertTextClassifierOptionsToProto(options.get()); + return core::TaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver)); +} + +absl::StatusOr TextClassifier::Classify( + absl::string_view text) { + ASSIGN_OR_RETURN( + auto output_packets, + runner_->Process( + {{kTextStreamName, MakePacket(std::string(text))}})); + return output_packets[kClassificationResultStreamName] + .Get(); +} + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h new file mode 100644 index 000000000..b027a9787 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h @@ -0,0 +1,96 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +// The options for configuring a MediaPipe text classifier task. +struct TextClassifierOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // Options for configuring the classifier behavior, such as score threshold, + // number of results, etc. + components::processors::ClassifierOptions classifier_options; +}; + +// Performs classification on text. +// +// This API expects a TFLite model with (optional) TFLite Model Metadata that +// contains the mandatory (described below) input tensors, output tensor, +// and the optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for +// models with int32 input tensors because it contains the input process unit +// for the model's Tokenizer. No metadata is required for models with string +// input tensors. +// +// Input tensors: +// (kTfLiteInt32) +// - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing +// the input ids, segment ids, and mask ids +// - or 1 input tensor of size `[batch_size x max_seq_len]` representing the +// input ids +// or (kTfLiteString) +// - 1 input tensor that is shapeless or has shape [1] containing the input +// string +// At least one output tensor with: +// (kTfLiteFloat32/kBool) +// - `[1 x N]` array with `N` represents the number of categories. +// - optional (but recommended) label items as AssociatedFiles with type +// TENSOR_AXIS_LABELS, containing one label per line. The first such +// AssociatedFile (if any) is used to fill the `category_name` field of the +// results. The `display_name` field is filled from the AssociatedFile (if +// any) whose locale matches the `display_names_locale` field of the +// `TextClassifierOptions` used at creation time ("en" by default, i.e. +// English). If none of these are available, only the `index` field of the +// results will be filled. +class TextClassifier : core::BaseTaskApi { + public: + using BaseTaskApi::BaseTaskApi; + + // Creates a TextClassifier from the provided `options`. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs classification on the input `text`. 
+ absl::StatusOr Classify( + absl::string_view text); + + // Shuts down the TextClassifier when all the work is done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_ diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc new file mode 100644 index 000000000..9706db4d8 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -0,0 +1,162 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" +#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::core::ModelResources; + +constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kTextTag[] = "TEXT"; +constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; +constexpr char kTensorsTag[] = "TENSORS"; + +} // namespace + +// A "TextClassifierGraph" performs Natural Language classification (including +// BERT-based text classification). +// - Accepts input text and outputs classification results on CPU. +// +// Inputs: +// TEXT - std::string +// Input text to perform classification on. +// +// Outputs: +// CLASSIFICATION_RESULT - ClassificationResult +// The aggregated classification result object that has 3 dimensions: +// (classification head, classification timestamp, classification category). 
+// +// Example: +// node { +// calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph" +// input_stream: "TEXT:text_in" +// output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// options { +// [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// } +// } +// } +class TextClassifierGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN( + const ModelResources* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN( + Source classification_result_out, + BuildTextClassifierTask( + sc->Options(), *model_resources, + graph[Input(kTextTag)], graph)); + classification_result_out >> + graph[Output(kClassificationResultTag)]; + return graph.GetConfig(); + } + + private: + // Adds a mediapipe TextClassifier task graph into the provided + // builder::Graph instance. The TextClassifier task takes an input + // text (std::string) and returns one classification result per output head + // specified by the model. + // + // task_options: the mediapipe tasks TextClassifierGraphOptions proto. + // model_resources: the ModelResources object initialized from a + // TextClassifier model file with model metadata. + // text_in: (std::string) stream to run text classification on. + // graph: the mediapipe builder::Graph instance to be updated. + absl::StatusOr> BuildTextClassifierTask( + const proto::TextClassifierGraphOptions& task_options, + const ModelResources& model_resources, Source text_in, + Graph& graph) { + // Adds preprocessing calculators and connects them to the text input + // stream. + auto& preprocessing = + graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph"); + MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph( + model_resources, + preprocessing.GetOptions< + tasks::components::proto::TextPreprocessingGraphOptions>())); + text_in >> preprocessing.In(kTextTag); + + // Adds both InferenceCalculator and ModelResourcesCalculator. + auto& inference = AddInference( + model_resources, task_options.base_options().acceleration(), graph); + // The metadata extractor side-output comes from the + // ModelResourcesCalculator. + inference.SideOut(kMetadataExtractorTag) >> + preprocessing.SideIn(kMetadataExtractorTag); + preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag); + + // Adds postprocessing calculators and connects them to the graph output. + auto& postprocessing = graph.AddNode( + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); + inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); + + // Outputs the aggregated classification result as the subgraph output + // stream. 
+ return postprocessing[Output( + kClassificationResultTag)]; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::text::text_classifier::TextClassifierGraph); + +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc new file mode 100644 index 000000000..62837be8c --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace text_classifier { +namespace { + +using ::mediapipe::EqualsProto; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::kMediaPipeTasksPayload; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr float kEpsilon = 0.001; +constexpr int kMaxSeqLen = 128; +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; +constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; +constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite"; +constexpr char kTestRegexModelPath[] = + "test_model_text_classifier_with_regex_tokenizer.tflite"; +constexpr char kStringToBoolModelPath[] = + "test_model_text_classifier_bool_output.tflite"; + +std::string GetFullPath(absl::string_view file_name) { + return JoinPath("./", kTestDataDirectory, file_name); +} + +class TextClassifierTest : public tflite_shims::testing::Test {}; + +TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK(TextClassifier::Create(std::move(options))); +} + +TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) { + auto options = std::make_unique(); + StatusOr> classifier = + TextClassifier::Create(std::move(options)); + + EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument); + 
EXPECT_THAT( + classifier.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(TextClassifierTest, CreateFailsWithMissingModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kInvalidModelPath); + StatusOr> classifier = + TextClassifier::Create(std::move(options)); + + EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound); + EXPECT_THAT(classifier.status().message(), + HasSubstr("Unable to open file at")); + EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); + MP_ASSERT_OK(TextClassifier::Create(std::move(options))); +} + +} // namespace +} // namespace text_classifier +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc new file mode 100644 index 000000000..d12370372 --- /dev/null +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.cc @@ -0,0 +1,131 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" +#include "tensorflow/lite/string_util.h" + +namespace mediapipe { +namespace tasks { +namespace text { +namespace { + +using ::mediapipe::tasks::CreateStatusWithPayload; +using ::tflite::GetInput; +using ::tflite::GetOutput; +using ::tflite::GetString; +using ::tflite::StringRef; + +constexpr absl::string_view kInputStr = "hello"; +constexpr bool kBooleanData[] = {true, true, false}; +constexpr size_t kBooleanDataSize = std::size(kBooleanData); + +// Checks and returns type of a tensor, fails if tensor type is not T. 
+template +absl::StatusOr AssertAndReturnTypedTensor(const TfLiteTensor* tensor) { + if (!tensor->data.raw) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Tensor (%s) has no raw data.", tensor->name)); + } + + // Checks if the data type of the tensor is T and returns the pointer cast to + // T if applicable; otherwise returns an error status. + // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType. + if (tensor->type == tflite::typeToTfLiteType()) { + return reinterpret_cast(tensor->data.raw); + } + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.", + tensor->name, tflite::typeToTfLiteType(), + tensor->bytes)); +} + +// Populates tensor with array of data; fails if data type doesn't match tensor +// type or they don't have the same number of elements. +template >>> +absl::Status PopulateTensor(const T* data, int num_elements, + TfLiteTensor* tensor) { + ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor(tensor)); + size_t bytes = num_elements * sizeof(T); + if (tensor->bytes != bytes) { + return CreateStatusWithPayload( + absl::StatusCode::kInternal, + absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes, + bytes)); + } + std::memcpy(v, data, bytes); + return absl::OkStatus(); +} + +TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TfLiteIntArray* dims = TfLiteIntArrayCreate(1); + dims->data[0] = kBooleanDataSize; + return context->ResizeTensor(context, output, dims); +} + +TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input_tensor = GetInput(context, node, 0); + TF_LITE_ENSURE(context, input_tensor != nullptr); + StringRef input_str_ref = GetString(input_tensor, 0); + std::string input_str(input_str_ref.str, input_str_ref.len); + if (input_str != kInputStr) { + return kTfLiteError; + } + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE(context, output != nullptr); + TF_LITE_ENSURE(context, PopulateTensor(kBooleanData, 3, output).ok()); + return kTfLiteOk; +} + +// This custom op takes a string tensor in and outputs a bool tensor with +// value {true, true, false}. It is used to mimic a real text classification +// model that classifies a string into scores for different categories.
+TfLiteRegistration* RegisterStringToBool() { + // Dummy implementation of custom OP + // This op takes string as input and outputs bool[] + static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr, + /* prepare= */ PrepareStringToBool, + /* invoke= */ InvokeStringToBool}; + return &r; +} +} // namespace + +std::unique_ptr CreateCustomResolver() { + tflite::MutableOpResolver resolver; + resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool()); + return std::make_unique(resolver); +} + +} // namespace text +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h similarity index 59% rename from mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h rename to mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h index a55661fa3..a427b561c 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h @@ -13,22 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ +#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ -#include "tensorflow/lite/kernels/register.h" +#include + +#include "tensorflow/lite/mutable_op_resolver.h" namespace mediapipe { namespace tasks { -namespace vision { -class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver { - public: - HandDetectorOpResolver(); - HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete; -}; +namespace text { -} // namespace vision +// Create a custom MutableOpResolver to provide custom OP implementations to +// mimic classification behavior. 
+std::unique_ptr CreateCustomResolver(); + +} // namespace text } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ +#endif // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/tokenizers/BUILD b/mediapipe/tasks/cc/text/tokenizers/BUILD index 048c7021d..5ce08b2d7 100644 --- a/mediapipe/tasks/cc/text/tokenizers/BUILD +++ b/mediapipe/tasks/cc/text/tokenizers/BUILD @@ -73,7 +73,18 @@ cc_library( ], ) -# TODO: This test fails in OSS +cc_test( + name = "sentencepiece_tokenizer_test", + srcs = ["sentencepiece_tokenizer_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:albert_model", + ], + deps = [ + ":sentencepiece_tokenizer", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/core:utils", + ], +) cc_library( name = "tokenizer_utils", @@ -97,7 +108,32 @@ cc_library( ], ) -# TODO: This test fails in OSS +cc_test( + name = "tokenizer_utils_test", + srcs = ["tokenizer_utils_test.cc"], + data = [ + "//mediapipe/tasks/testdata/text:albert_model", + "//mediapipe/tasks/testdata/text:mobile_bert_model", + "//mediapipe/tasks/testdata/text:text_classifier_models", + ], + linkopts = ["-ldl"], + deps = [ + ":bert_tokenizer", + ":regex_tokenizer", + ":sentencepiece_tokenizer", + ":tokenizer_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata:metadata_extractor", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + ], +) cc_library( name = "regex_tokenizer", diff --git a/mediapipe/tasks/cc/vision/core/BUILD b/mediapipe/tasks/cc/vision/core/BUILD index 12d789901..e8e197a1d 100644 --- a/mediapipe/tasks/cc/vision/core/BUILD +++ b/mediapipe/tasks/cc/vision/core/BUILD @@ -21,12 +21,23 @@ cc_library( hdrs = ["running_mode.h"], ) +cc_library( + name = "image_processing_options", + hdrs = ["image_processing_options.h"], + deps = [ + "//mediapipe/tasks/cc/components/containers:rect", + ], +) + cc_library( name = "base_vision_task_api", hdrs = ["base_vision_task_api.h"], deps = [ + ":image_processing_options", ":running_mode", "//mediapipe/calculators/core:flow_limiter_calculator", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:rect", "//mediapipe/tasks/cc/core:base_task_api", "//mediapipe/tasks/cc/core:task_runner", "@com_google_absl//absl/status", diff --git a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h index 4586cbbdd..c3c0a0261 100644 --- a/mediapipe/tasks/cc/vision/core/base_vision_task_api.h +++ b/mediapipe/tasks/cc/vision/core/base_vision_task_api.h @@ -16,15 +16,20 @@ limitations under the License. 
#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_ #define MEDIAPIPE_TASKS_CC_VISION_CORE_BASE_VISION_TASK_API_H_ +#include #include +#include #include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -87,6 +92,60 @@ class BaseVisionTaskApi : public tasks::core::BaseTaskApi { return runner_->Send(std::move(inputs)); } + // Convert from ImageProcessingOptions to NormalizedRect, performing sanity + // checks on-the-fly. If the input ImageProcessingOptions is not present, + // returns a default NormalizedRect covering the whole image with rotation set + // to 0. If 'roi_allowed' is false, an error will be returned if the input + // ImageProcessingOptions has its 'region_or_interest' field set. + static absl::StatusOr ConvertToNormalizedRect( + std::optional options, bool roi_allowed = true) { + mediapipe::NormalizedRect normalized_rect; + normalized_rect.set_rotation(0); + normalized_rect.set_x_center(0.5); + normalized_rect.set_y_center(0.5); + normalized_rect.set_width(1.0); + normalized_rect.set_height(1.0); + if (!options.has_value()) { + return normalized_rect; + } + + if (options->rotation_degrees % 90 != 0) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Expected rotation to be a multiple of 90°.", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + // Convert to radians counter-clockwise. + normalized_rect.set_rotation(-options->rotation_degrees * M_PI / 180.0); + + if (options->region_of_interest.has_value()) { + if (!roi_allowed) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "This task doesn't support region-of-interest.", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + auto& roi = *options->region_of_interest; + if (roi.left >= roi.right || roi.top >= roi.bottom) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Expected Rect with left < right and top < bottom.", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + if (roi.left < 0 || roi.top < 0 || roi.right > 1 || roi.bottom > 1) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Expected Rect values to be in [0,1].", + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError); + } + normalized_rect.set_x_center((roi.left + roi.right) / 2.0); + normalized_rect.set_y_center((roi.top + roi.bottom) / 2.0); + normalized_rect.set_width(roi.right - roi.left); + normalized_rect.set_height(roi.bottom - roi.top); + } + return normalized_rect; + } + private: RunningMode running_mode_; }; diff --git a/mediapipe/tasks/cc/vision/core/image_processing_options.h b/mediapipe/tasks/cc/vision/core/image_processing_options.h new file mode 100644 index 000000000..7e764c1fe --- /dev/null +++ b/mediapipe/tasks/cc/vision/core/image_processing_options.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
+#define MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
+
+#include <optional>
+
+#include "mediapipe/tasks/cc/components/containers/rect.h"
+
+namespace mediapipe {
+namespace tasks {
+namespace vision {
+namespace core {
+
+// Options for image processing.
+//
+// If both region-of-interest and rotation are specified, the crop around the
+// region-of-interest is extracted first, then the specified rotation is
+// applied to the crop.
+struct ImageProcessingOptions {
+  // The optional region-of-interest to crop from the image. If not specified,
+  // the full image is used.
+  //
+  // Coordinates must be in [0,1] with 'left' < 'right' and 'top' < 'bottom'.
+  std::optional<components::containers::Rect> region_of_interest = std::nullopt;
+
+  // The rotation to apply to the image (or cropped region-of-interest), in
+  // degrees clockwise.
+  //
+  // The rotation must be a multiple (positive or negative) of 90°.
+  int rotation_degrees = 0;
+};
+
+}  // namespace core
+}  // namespace vision
+}  // namespace tasks
+}  // namespace mediapipe
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_CORE_IMAGE_PROCESSING_OPTIONS_H_
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
new file mode 100644
index 000000000..6296017d4
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/BUILD
@@ -0,0 +1,165 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
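
Before the gesture recognizer BUILD file continues below, here is a small, hedged sketch of how the `ImageProcessingOptions` struct defined above maps onto the `NormalizedRect` produced by `ConvertToNormalizedRect()` in `base_vision_task_api.h`. The `Rect` field names and namespace follow `components/containers/rect.h`; the numeric values are illustrative only.

```
// Hedged illustration: how ImageProcessingOptions values become a
// NormalizedRect via ConvertToNormalizedRect() (see base_vision_task_api.h).
#include <optional>

#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"

void ImageProcessingOptionsExample() {
  mediapipe::tasks::components::containers::Rect roi;
  roi.left = 0.25f;
  roi.top = 0.25f;
  roi.right = 0.75f;
  roi.bottom = 0.75f;

  mediapipe::tasks::vision::core::ImageProcessingOptions options;
  options.region_of_interest = roi;  // only honored when roi_allowed == true
  options.rotation_degrees = 90;     // must be a multiple of 90

  // ConvertToNormalizedRect(options) is then expected to return a rect with:
  //   x_center = (0.25 + 0.75) / 2 = 0.5
  //   y_center = (0.25 + 0.75) / 2 = 0.5
  //   width    = 0.75 - 0.25       = 0.5
  //   height   = 0.75 - 0.25       = 0.5
  //   rotation = -90 * M_PI / 180  = -pi/2 radians (counter-clockwise)
  // With std::nullopt instead, the default rect covers the whole image:
  //   center (0.5, 0.5), width/height 1.0, rotation 0.
}
```

A rect with `left >= right`, `top >= bottom`, or coordinates outside [0,1] is rejected with an invalid-argument error, as enforced in `ConvertToNormalizedRect()` above.
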
+ +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +cc_library( + name = "handedness_util", + srcs = ["handedness_util.cc"], + hdrs = ["handedness_util.h"], + deps = [ + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/port:ret_check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "handedness_util_test", + srcs = ["handedness_util_test.cc"], + deps = [ + ":handedness_util", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_library( + name = "hand_gesture_recognizer_graph", + srcs = ["hand_gesture_recognizer_graph.cc"], + deps = [ + "//mediapipe/calculators/core:begin_loop_calculator", + "//mediapipe/calculators/core:concatenate_vector_calculator", + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/core:get_vector_item_calculator", + "//mediapipe/calculators/tensor:tensor_converter_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator", + "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:handedness_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator", + "//mediapipe/tasks/cc/vision/gesture_recognizer/calculators:landmarks_to_matrix_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "gesture_recognizer_graph", + srcs = ["gesture_recognizer_graph.cc"], + deps = [ + ":hand_gesture_recognizer_graph", + "//mediapipe/calculators/core:vector_indices_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/framework/api2:builder", 
+ "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarks_detector_graph", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "//mediapipe/tasks/metadata:metadata_schema_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) + +cc_library( + name = "gesture_recognizer", + srcs = ["gesture_recognizer.cc"], + hdrs = ["gesture_recognizer.h"], + deps = [ + ":gesture_recognizer_graph", + ":hand_gesture_recognizer_graph", + "//mediapipe/framework:packet", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components:image_preprocessing", + "//mediapipe/tasks/cc/components/containers:gesture_recognition_result", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:model_resources", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@org_tensorflow//tensorflow/lite/core/api:op_resolver", + 
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], +) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD similarity index 81% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD index 4863c8682..08f7f45d0 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/BUILD @@ -12,11 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + package(default_visibility = [ "//mediapipe/app/xeno:__subpackages__", "//mediapipe/tasks:internal", ]) +mediapipe_proto_library( + name = "landmarks_to_matrix_calculator_proto", + srcs = ["landmarks_to_matrix_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + cc_library( name = "handedness_to_matrix_calculator", srcs = ["handedness_to_matrix_calculator.cc"], @@ -25,7 +37,7 @@ cc_library( "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:matrix", "//mediapipe/framework/port:ret_check", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer:handedness_util", + "//mediapipe/tasks/cc/vision/gesture_recognizer:handedness_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -53,11 +65,12 @@ cc_library( name = "landmarks_to_matrix_calculator", srcs = ["landmarks_to_matrix_calculator.cc"], deps = [ + ":landmarks_to_matrix_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:ret_check", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -74,6 +87,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "@com_google_absl//absl/strings", diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc similarity index 90% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc index 746293d21..b6c973a1b 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator.cc @@ -26,14 +26,16 @@ limitations under the License. 
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" +// TODO Update to use API2 namespace mediapipe { -namespace tasks { -namespace vision { +namespace api2 { namespace { +using ::mediapipe::tasks::vision::gesture_recognizer::GetLeftHandScore; + constexpr char kHandednessTag[] = "HANDEDNESS"; constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX"; @@ -71,6 +73,8 @@ class HandednessToMatrixCalculator : public CalculatorBase { return absl::OkStatus(); } + // TODO remove this after change to API2, because Setting offset + // to 0 is the default in API2 absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); return absl::OkStatus(); @@ -95,6 +99,5 @@ absl::Status HandednessToMatrixCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -} // namespace vision -} // namespace tasks +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc similarity index 97% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc index c93c48ac5..17b16bf80 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/handedness_to_matrix_calculator_test.cc @@ -28,8 +28,6 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { -namespace tasks { -namespace vision { namespace { @@ -95,6 +93,4 @@ INSTANTIATE_TEST_CASE_P( } // namespace -} // namespace vision -} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc similarity index 83% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc index 990e99920..277bb170a 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -26,20 +27,20 @@ limitations under the License. 
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +// TODO Update to use API2 namespace mediapipe { -namespace tasks { -namespace vision { - -using proto::LandmarksToMatrixCalculatorOptions; +namespace api2 { namespace { constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; constexpr int kFeaturesPerLandmark = 3; @@ -64,6 +65,25 @@ absl::StatusOr NormalizeLandmarkAspectRatio( return normalized_landmarks; } +template +absl::StatusOr RotateLandmarks(const LandmarkListT& landmarks, + float rotation) { + float cos = std::cos(rotation); + // Negate because Y-axis points down and not up. + float sin = std::sin(-rotation); + LandmarkListT rotated_landmarks; + for (int i = 0; i < landmarks.landmark_size(); ++i) { + const auto& old_landmark = landmarks.landmark(i); + float x = old_landmark.x() - 0.5; + float y = old_landmark.y() - 0.5; + auto* new_landmark = rotated_landmarks.add_landmark(); + new_landmark->set_x(x * cos - y * sin + 0.5); + new_landmark->set_y(y * cos + x * sin + 0.5); + new_landmark->set_z(old_landmark.z()); + } + return rotated_landmarks; +} + template absl::StatusOr NormalizeObject(const LandmarkListT& landmarks, int origin_offset) { @@ -136,6 +156,13 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { NormalizeLandmarkAspectRatio(landmarks, width, height)); } + if (cc->Inputs().HasTag(kNormRectTag)) { + RET_CHECK(!cc->Inputs().Tag(kNormRectTag).IsEmpty()); + const auto rotation = + cc->Inputs().Tag(kNormRectTag).Get().rotation(); + ASSIGN_OR_RETURN(landmarks, RotateLandmarks(landmarks, rotation)); + } + const auto& options = cc->Options(); if (options.object_normalization()) { ASSIGN_OR_RETURN( @@ -165,6 +192,8 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { // WORLD_LANDMARKS - World 3d landmarks of one object. Use *either* // LANDMARKS or WORLD_LANDMARKS. // IMAGE_SIZE - (width, height) of the image +// NORM_RECT - Optional NormalizedRect object whose 'rotation' field is used +// to rotate the landmarks. // Output: // LANDMARKS_MATRIX - Matrix for the landmarks. 
// @@ -175,7 +204,7 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) { // input_stream: "IMAGE_SIZE:image_size" // output_stream: "LANDMARKS_MATRIX:landmarks_matrix" // options { -// [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions.ext] { +// [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { // object_normalization: true // object_normalization_origin_offset: 0 // } @@ -187,6 +216,7 @@ class LandmarksToMatrixCalculator : public CalculatorBase { cc->Inputs().Tag(kLandmarksTag).Set().Optional(); cc->Inputs().Tag(kWorldLandmarksTag).Set().Optional(); cc->Inputs().Tag(kImageSizeTag).Set>().Optional(); + cc->Inputs().Tag(kNormRectTag).Set().Optional(); cc->Outputs().Tag(kLandmarksMatrixTag).Set(); return absl::OkStatus(); } @@ -221,6 +251,5 @@ absl::Status LandmarksToMatrixCalculator::Process(CalculatorContext* cc) { return absl::OkStatus(); } -} // namespace vision -} // namespace tasks +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto similarity index 97% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto index 6b004e203..10b034447 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto2"; -package mediapipe.tasks.vision.proto; +package mediapipe; import "mediapipe/framework/calculator.proto"; diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc similarity index 81% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc index 05d238f66..fe6f1162b 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -23,13 +24,12 @@ limitations under the License. 
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { -namespace tasks { -namespace vision { namespace { @@ -37,6 +37,7 @@ constexpr char kLandmarksTag[] = "LANDMARKS"; constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; +constexpr char kNormRectTag[] = "NORM_RECT"; template LandmarkListT BuildPseudoLandmarks(int num_landmarks, int offset = 0) { @@ -56,6 +57,7 @@ struct Landmarks2dToMatrixCalculatorTestCase { int object_normalization_origin_offset = -1; float expected_cell_0_2; float expected_cell_1_5; + float rotation; }; using Landmarks2dToMatrixCalculatorTest = @@ -70,10 +72,10 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { calculator: "LandmarksToMatrixCalculator" input_stream: "LANDMARKS:landmarks" input_stream: "IMAGE_SIZE:image_size" + input_stream: "NORM_RECT:norm_rect" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { - [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions - .ext] { + [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { object_normalization: $0 object_normalization_origin_offset: $1 } @@ -94,6 +96,11 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) { runner.MutableInputs() ->Tag(kImageSizeTag) .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); + auto norm_rect = std::make_unique(); + norm_rect->set_rotation(test_case.rotation); + runner.MutableInputs() + ->Tag(kNormRectTag) + .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -112,12 +119,20 @@ INSTANTIATE_TEST_CASE_P( .base_offset = 0, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.1875f}, + .expected_cell_1_5 = 0.1875f, + .rotation = 0}, {.test_name = "TestWithOffset21", .base_offset = 21, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.1875f}}), + .expected_cell_1_5 = 0.1875f, + .rotation = 0}, + {.test_name = "TestWithRotation", + .base_offset = 0, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.075f, + .expected_cell_1_5 = -0.25f, + .rotation = M_PI / 2.0}}), [](const testing::TestParamInfo< Landmarks2dToMatrixCalculatorTest::ParamType>& info) { return info.param.test_name; @@ -129,6 +144,7 @@ struct LandmarksWorld3dToMatrixCalculatorTestCase { int object_normalization_origin_offset = -1; float expected_cell_0_2; float expected_cell_1_5; + float rotation; }; using LandmarksWorld3dToMatrixCalculatorTest = @@ -143,10 +159,10 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { calculator: "LandmarksToMatrixCalculator" input_stream: "WORLD_LANDMARKS:landmarks" input_stream: "IMAGE_SIZE:image_size" + input_stream: "NORM_RECT:norm_rect" output_stream: "LANDMARKS_MATRIX:landmarks_matrix" options { - [mediapipe.tasks.vision.proto.LandmarksToMatrixCalculatorOptions - .ext] { + [mediapipe.LandmarksToMatrixCalculatorOptions.ext] { object_normalization: $0 object_normalization_origin_offset: $1 } @@ -166,6 +182,11 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) { runner.MutableInputs() 
->Tag(kImageSizeTag) .packets.push_back(Adopt(image_size.release()).At(Timestamp(0))); + auto norm_rect = std::make_unique(); + norm_rect->set_rotation(test_case.rotation); + runner.MutableInputs() + ->Tag(kNormRectTag) + .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -184,17 +205,26 @@ INSTANTIATE_TEST_CASE_P( .base_offset = 0, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.25}, + .expected_cell_1_5 = 0.25, + .rotation = 0}, {.test_name = "TestWithOffset21", .base_offset = 21, .object_normalization_origin_offset = 0, .expected_cell_0_2 = 0.1f, - .expected_cell_1_5 = 0.25}, + .expected_cell_1_5 = 0.25, + .rotation = 0}, {.test_name = "NoObjectNormalization", .base_offset = 0, .object_normalization_origin_offset = -1, .expected_cell_0_2 = 0.021f, - .expected_cell_1_5 = 0.052f}}), + .expected_cell_1_5 = 0.052f, + .rotation = 0}, + {.test_name = "TestWithRotation", + .base_offset = 0, + .object_normalization_origin_offset = 0, + .expected_cell_0_2 = 0.1f, + .expected_cell_1_5 = -0.25f, + .rotation = M_PI / 2.0}}), [](const testing::TestParamInfo< LandmarksWorld3dToMatrixCalculatorTest::ParamType>& info) { return info.param.test_name; @@ -202,6 +232,4 @@ INSTANTIATE_TEST_CASE_P( } // namespace -} // namespace vision -} // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc new file mode 100644 index 000000000..d4ab16ac8 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.cc @@ -0,0 +1,306 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +namespace { + +using GestureRecognizerGraphOptionsProto = ::mediapipe::tasks::vision:: + gesture_recognizer::proto::GestureRecognizerGraphOptions; + +using ::mediapipe::tasks::components::containers::GestureRecognitionResult; + +constexpr char kHandGestureSubgraphTypeName[] = + "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kHandGesturesTag[] = "HAND_GESTURES"; +constexpr char kHandGesturesStreamName[] = "hand_gestures"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessStreamName[] = "handedness"; +constexpr char kHandLandmarksTag[] = "LANDMARKS"; +constexpr char kHandLandmarksStreamName[] = "landmarks"; +constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running +// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the +// number of frames in flight. 
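
The comment above describes what `CreateGraphConfig()` (implemented just below) builds. Purely as a hedged illustration of the live-stream case, here is roughly the text-proto shape of a flow-limited config, expressed with `ParseTextProtoOrDie`; the exact streams and options emitted by `AddFlowLimiterCalculator` may differ.

```
// Hedged sketch only: approximate topology of the flow-limited graph that
// CreateGraphConfig() produces in live-stream mode. A FlowLimiterCalculator
// throttles both inputs, and the HAND_GESTURES stream is the back edge that
// releases the next frame.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

mediapipe::CalculatorGraphConfig FlowLimitedConfigSketch() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_stream: "image_in"
    input_stream: "norm_rect_in"
    output_stream: "hand_gestures"
    node {
      calculator: "FlowLimiterCalculator"
      input_stream: "image_in"
      input_stream: "norm_rect_in"
      input_stream: "FINISHED:hand_gestures"
      input_stream_info: { tag_index: "FINISHED" back_edge: true }
      output_stream: "throttled_image_in"
      output_stream: "throttled_norm_rect_in"
    }
    node {
      calculator: "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
      input_stream: "IMAGE:throttled_image_in"
      input_stream: "NORM_RECT:throttled_norm_rect_in"
      output_stream: "HAND_GESTURES:hand_gestures"
    }
  )pb");
}
```
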
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options, + bool enable_flow_limiting) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >> + graph.Out(kHandGesturesTag); + subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >> + graph.Out(kHandednessTag); + subgraph.Out(kHandLandmarksTag).SetName(kHandLandmarksStreamName) >> + graph.Out(kHandLandmarksTag); + subgraph.Out(kHandWorldLandmarksTag).SetName(kHandWorldLandmarksStreamName) >> + graph.Out(kHandWorldLandmarksTag); + subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator( + graph, subgraph, {kImageTag, kNormRectTag}, kHandGesturesTag); + } + graph.In(kImageTag) >> subgraph.In(kImageTag); + graph.In(kNormRectTag) >> subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing GestureRecognizerOptions struct to the internal +// GestureRecognizerGraphOptions proto. +std::unique_ptr +ConvertGestureRecognizerGraphOptionsProto(GestureRecognizerOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + + // Configure hand detector options. + auto* hand_detector_graph_options = + options_proto->mutable_hand_landmarker_graph_options() + ->mutable_hand_detector_graph_options(); + hand_detector_graph_options->set_num_hands(options->num_hands); + hand_detector_graph_options->set_min_detection_confidence( + options->min_hand_detection_confidence); + + // Configure hand landmark detector options. + auto* hand_landmarker_graph_options = + options_proto->mutable_hand_landmarker_graph_options(); + hand_landmarker_graph_options->set_min_tracking_confidence( + options->min_tracking_confidence); + auto* hand_landmarks_detector_graph_options = + hand_landmarker_graph_options + ->mutable_hand_landmarks_detector_graph_options(); + hand_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_hand_presence_confidence); + + // Configure hand gesture recognizer options. 
+ auto* hand_gesture_recognizer_graph_options = + options_proto->mutable_hand_gesture_recognizer_graph_options(); + if (options->min_gesture_confidence >= 0) { + hand_gesture_recognizer_graph_options + ->mutable_canned_gesture_classifier_graph_options() + ->mutable_classifier_options() + ->set_score_threshold(options->min_gesture_confidence); + } + return options_proto; +} + +} // namespace + +absl::StatusOr> GestureRecognizer::Create( + std::unique_ptr options) { + auto options_proto = ConvertGestureRecognizerGraphOptionsProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = [=](absl::StatusOr + status_or_packets) { + if (!status_or_packets.ok()) { + Image image; + result_callback(status_or_packets.status(), image, + Timestamp::Unset().Value()); + return; + } + if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { + return; + } + Packet image_packet = status_or_packets.value()[kImageOutStreamName]; + if (status_or_packets.value()[kHandGesturesStreamName].IsEmpty()) { + Packet empty_packet = + status_or_packets.value()[kHandGesturesStreamName]; + result_callback( + {{{}, {}, {}, {}}}, image_packet.Get(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } + Packet gesture_packet = + status_or_packets.value()[kHandGesturesStreamName]; + Packet handedness_packet = + status_or_packets.value()[kHandednessStreamName]; + Packet hand_landmarks_packet = + status_or_packets.value()[kHandLandmarksStreamName]; + Packet hand_world_landmarks_packet = + status_or_packets.value()[kHandWorldLandmarksStreamName]; + result_callback( + {{gesture_packet.Get>(), + handedness_packet.Get>(), + hand_landmarks_packet.Get>(), + hand_world_landmarks_packet.Get>()}}, + image_packet.Get(), + gesture_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + }; + } + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr GestureRecognizer::Recognize( + mediapipe::Image image, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "GPU input images are currently not supported.", + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + if (output_packets[kHandGesturesStreamName].IsEmpty()) { + return {{{}, {}, {}, {}}}; + } + return { + {/* gestures= */ {output_packets[kHandGesturesStreamName] + .Get>()}, + /* handedness= */ + {output_packets[kHandednessStreamName] + .Get>()}, + /* hand_landmarks= */ + {output_packets[kHandLandmarksStreamName] + .Get>()}, + /* hand_world_landmarks */ + {output_packets[kHandWorldLandmarksStreamName] + .Get>()}}, + }; +} + +absl::StatusOr GestureRecognizer::RecognizeForVideo( + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not 
supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kHandGesturesStreamName].IsEmpty()) { + return {{{}, {}, {}, {}}}; + } + return { + {/* gestures= */ {output_packets[kHandGesturesStreamName] + .Get>()}, + /* handedness= */ + {output_packets[kHandednessStreamName] + .Get>()}, + /* hand_landmarks= */ + {output_packets[kHandLandmarksStreamName] + .Get>()}, + /* hand_world_landmarks */ + {output_packets[kHandWorldLandmarksStreamName] + .Get>()}}, + }; +} + +absl::Status GestureRecognizer::RecognizeAsync( + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +} // namespace gesture_recognizer +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h new file mode 100644 index 000000000..3e281b26e --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h @@ -0,0 +1,195 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +struct GestureRecognizerOptions { + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // GestureRecognizer has three running modes: + // 1) The image mode for recognizing hand gestures on single image inputs. + // 2) The video mode for recognizing hand gestures on the decoded frames of a + // video. + // 3) The live stream mode for recognizing hand gestures on the live stream of + // input data, such as from camera. In this mode, the "result_callback" + // below must be specified to receive the detection results asynchronously. + core::RunningMode running_mode = core::RunningMode::IMAGE; + + // The maximum number of hands can be detected by the GestureRecognizer. + int num_hands = 1; + + // The minimum confidence score for the hand detection to be considered + // successful. + float min_hand_detection_confidence = 0.5; + + // The minimum confidence score of hand presence score in the hand landmark + // detection. + float min_hand_presence_confidence = 0.5; + + // The minimum confidence score for the hand tracking to be considered + // successful. + float min_tracking_confidence = 0.5; + + // The minimum confidence score for the gestures to be considered + // successful. If < 0, the gesture confidence thresholds in the model + // metadata are used. + // TODO Note this option is subject to change, after scoring + // merging calculator is implemented. + float min_gesture_confidence = -1; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. + std::function, + const Image&, int64)> + result_callback = nullptr; +}; + +// Performs hand gesture recognition on the given image. +// +// TODO add the link to DevSite. +// This API expects a pre-trained hand gesture model asset bundle, or a custom +// one created using Model Maker. See . +// +// Inputs: +// Image +// - The image that gesture recognition runs on. +// std::optional +// - If provided, can be used to specify the rotation to apply to the image +// before performing gesture recognition, by setting its 'rotation' field +// in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note +// that specifying a region-of-interest using the 'x_center', 'y_center', +// 'width' and 'height' fields is NOT supported and will result in an +// invalid argument error being returned. 
+// Outputs: +// GestureRecognitionResult +// - The hand gesture recognition results. +class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates a GestureRecognizer from a GestureRecognizerhOptions to process + // image data or streaming data. Gesture recognizer can be created with one of + // the following three running modes: + // 1) Image mode for recognizing gestures on single image inputs. + // Users provide mediapipe::Image to the `Recognize` method, and will + // receive the recognized hand gesture results as the return value. + // 2) Video mode for recognizing gestures on the decoded frames of a video. + // 3) Live stream mode for recognizing gestures on the live stream of the + // input data, such as from camera. Users call `RecognizeAsync` to push the + // image data into the GestureRecognizer, the recognized results along with + // the input timestamp and the image that gesture recognizer runs on will + // be available in the result callback when the gesture recognizer finishes + // the work. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs hand gesture recognition on the given image. + // Only use this method when the GestureRecognizer is created with the image + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. + // TODO: Describes how the input image will be preprocessed + // after the yuv support is implemented. + absl::StatusOr Recognize( + Image image, + std::optional image_processing_options = + std::nullopt); + + // Performs gesture recognition on the provided video frame. + // Only use this method when the GestureRecognizer is created with the video + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + absl::StatusOr + RecognizeForVideo(Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); + + // Sends live image data to perform gesture recognition, and the results will + // be available via the "result_callback" provided in the + // GestureRecognizerOptions. Only use this method when the GestureRecognizer + // is created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the gesture recognizer. The input timestamps must be monotonically + // increasing. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing recognition, by + // setting its 'rotation_degrees' field. 
Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // The "result_callback" provides + // - A vector of GestureRecognitionResult, each is the recognized results + // for a input frame. + // - The const reference to the corresponding input image that the gesture + // recognizer runs on. Note that the const reference to the image will no + // longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status RecognizeAsync(Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); + + // Shuts down the GestureRecognizer when all works are done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace gesture_recognizer +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_ diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc new file mode 100644 index 000000000..7ab4847dd --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -0,0 +1,292 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
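
With `gesture_recognizer.h` complete above, a rough end-to-end usage sketch for the image running mode follows. The model bundle path and option values are illustrative, `BaseOptions::model_asset_path` is assumed to be the field for pointing at the `.task` bundle, and error handling is written out explicitly to keep the sketch self-contained.

```
// Hedged usage sketch (image mode). File name, option values and the
// BaseOptions field used for the model bundle are illustrative assumptions.
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer.h"

using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
using ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizer;
using ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerOptions;

absl::Status RecognizeOnce(mediapipe::Image image) {
  auto options = std::make_unique<GestureRecognizerOptions>();
  // Assumed field name; points at the gesture recognizer .task bundle.
  options->base_options.model_asset_path = "gesture_recognizer.task";
  options->num_hands = 2;
  options->min_hand_detection_confidence = 0.5f;

  auto recognizer_or = GestureRecognizer::Create(std::move(options));
  if (!recognizer_or.ok()) return recognizer_or.status();
  std::unique_ptr<GestureRecognizer> recognizer = std::move(*recognizer_or);

  // Optional: rotate the input by a multiple of 90 degrees before
  // recognition. Region-of-interest is not supported by this task.
  ImageProcessingOptions image_processing_options;
  image_processing_options.rotation_degrees = -90;

  auto result_or = recognizer->Recognize(image, image_processing_options);
  if (!result_or.ok()) return result_or.status();
  // result_or->gestures, ->handedness, ->hand_landmarks and
  // ->hand_world_landmarks hold the per-hand outputs.
  return recognizer->Close();
}
```

For video and live-stream modes, `RecognizeForVideo` and `RecognizeAsync` take the same arguments plus a monotonically increasing timestamp in milliseconds, as documented above.
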
+==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + GestureRecognizerGraphOptions; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + HandGestureRecognizerGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kHandGesturesTag[] = "HAND_GESTURES"; +constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task"; +constexpr char kHandGestureRecognizerBundleAssetName[] = + "hand_gesture_recognizer.task"; + +struct GestureRecognizerOutputs { + Source> gesture; + Source> handedness; + Source> hand_landmarks; + Source> hand_world_landmarks; + Source image; +}; + +// Sets the base options in the sub tasks. 
+absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + GestureRecognizerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto hand_landmarker_file, + resources.GetModelFile(kHandLandmarkerBundleAssetName)); + auto* hand_landmarker_graph_options = + options->mutable_hand_landmarker_graph_options(); + SetExternalFile(hand_landmarker_file, + hand_landmarker_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_landmarker_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + + ASSIGN_OR_RETURN( + const auto hand_gesture_recognizer_file, + resources.GetModelFile(kHandGestureRecognizerBundleAssetName)); + auto* hand_gesture_recognizer_graph_options = + options->mutable_hand_gesture_recognizer_graph_options(); + SetExternalFile(hand_gesture_recognizer_file, + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + if (!hand_gesture_recognizer_graph_options->base_options() + .acceleration() + .has_xnnpack() && + !hand_gesture_recognizer_graph_options->base_options() + .acceleration() + .has_tflite()) { + hand_gesture_recognizer_graph_options->mutable_base_options() + ->mutable_acceleration() + ->mutable_xnnpack(); + LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. Sets " + << "HandGestureRecognizerGraph acceleartion to Xnnpack."; + } + hand_gesture_recognizer_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + +} // namespace + +// A "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" performs +// hand gesture recognition. +// +// Inputs: +// IMAGE - Image +// Image to perform hand gesture recognition on. +// NORM_RECT - NormalizedRect +// Describes image rotation and region of image to perform landmarks +// detection on. +// +// Outputs: +// HAND_GESTURES - std::vector +// Recognized hand gestures with sorted order such that the winning label is +// the first item in the list. +// LANDMARKS: - std::vector +// Detected hand landmarks. +// WORLD_LANDMARKS - std::vector +// Detected hand landmarks in world coordinates. +// HAND_RECT_NEXT_FRAME - std::vector +// The predicted Rect enclosing the hand RoI for landmark detection on the +// next frame. +// HANDEDNESS - std::vector +// Classification of handedness. +// IMAGE - mediapipe::Image +// The image that gesture recognizer runs on and has the pixel data stored +// on the target storage (CPU vs GPU). +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. 
+// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph" +// input_stream: "IMAGE:image_in" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "HAND_GESTURES:hand_gestures" +// output_stream: "LANDMARKS:hand_landmarks" +// output_stream: "WORLD_LANDMARKS:world_hand_landmarks" +// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" +// output_stream: "HANDEDNESS:handedness" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.gesture_recognizer.proto.GestureRecognizerGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "hand_gesture.tflite" +// } +// } +// hand_landmark_detector_options { +// base_options { +// model_asset { +// file_name: "hand_landmark.tflite" +// } +// } +// } +// } +// } +// } +class GestureRecognizerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // When the model resources cache service is available, filling in + // the file pointer meta in the subtasks' base options. Otherwise, + // providing the file contents instead. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN(auto hand_gesture_recognition_output, + BuildGestureRecognizerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); + hand_gesture_recognition_output.gesture >> + graph[Output>(kHandGesturesTag)]; + hand_gesture_recognition_output.handedness >> + graph[Output>(kHandednessTag)]; + hand_gesture_recognition_output.hand_landmarks >> + graph[Output>(kLandmarksTag)]; + hand_gesture_recognition_output.hand_world_landmarks >> + graph[Output>(kWorldLandmarksTag)]; + hand_gesture_recognition_output.image >> graph[Output(kImageTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr BuildGestureRecognizerGraph( + GestureRecognizerGraphOptions& graph_options, Source image_in, + Source norm_rect_in, Graph& graph) { + auto& image_property = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_property.In("IMAGE"); + auto image_size = image_property.Out("SIZE"); + + // Hand landmarker graph. + auto& hand_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"); + auto& hand_landmarker_graph_options = + hand_landmarker_graph.GetOptions(); + hand_landmarker_graph_options.Swap( + graph_options.mutable_hand_landmarker_graph_options()); + + image_in >> hand_landmarker_graph.In(kImageTag); + norm_rect_in >> hand_landmarker_graph.In(kNormRectTag); + auto hand_landmarks = + hand_landmarker_graph[Output>( + kLandmarksTag)]; + auto hand_world_landmarks = + hand_landmarker_graph[Output>( + kWorldLandmarksTag)]; + auto handedness = + hand_landmarker_graph[Output>( + kHandednessTag)]; + + auto& vector_indices = + graph.AddNode("NormalizedLandmarkListVectorIndicesCalculator"); + hand_landmarks >> vector_indices.In("VECTOR"); + auto hand_landmarks_id = vector_indices.Out("INDICES"); + + // Hand gesture recognizer subgraph. + auto& hand_gesture_subgraph = graph.AddNode( + "mediapipe.tasks.vision.gesture_recognizer." 
+ "MultipleHandGestureRecognizerGraph"); + hand_gesture_subgraph.GetOptions().Swap( + graph_options.mutable_hand_gesture_recognizer_graph_options()); + hand_landmarks >> hand_gesture_subgraph.In(kLandmarksTag); + hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag); + handedness >> hand_gesture_subgraph.In(kHandednessTag); + image_size >> hand_gesture_subgraph.In(kImageSizeTag); + norm_rect_in >> hand_gesture_subgraph.In(kNormRectTag); + hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag); + auto hand_gestures = + hand_gesture_subgraph[Output>( + kHandGesturesTag)]; + + return {{.gesture = hand_gestures, + .handedness = handedness, + .hand_landmarks = hand_landmarks, + .hand_world_landmarks = hand_world_landmarks, + .image = hand_landmarker_graph[Output(kImageTag)]}}; + } +}; + +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::gesture_recognizer::GestureRecognizerGraph); // NOLINT +// clang-format on + +} // namespace gesture_recognizer +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc new file mode 100644 index 000000000..7b7746956 --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -0,0 +1,494 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.pb.h" +#include "mediapipe/tasks/metadata/metadata_schema_generated.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace gesture_recognizer { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::processors:: + ConfigureTensorsToClassificationCalculator; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; +using ::mediapipe::tasks::vision::gesture_recognizer::proto:: + HandGestureRecognizerGraphOptions; + +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; +constexpr char kHandGesturesTag[] = "HAND_GESTURES"; +constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kHandednessMatrixTag[] = "HANDEDNESS_MATRIX"; +constexpr char kCloneTag[] = "CLONE"; +constexpr char kItemTag[] = "ITEM"; +constexpr char kVectorTag[] = "VECTOR"; +constexpr char kIndexTag[] = "INDEX"; +constexpr char kIterableTag[] = "ITERABLE"; +constexpr char kBatchEndTag[] = "BATCH_END"; +constexpr char kGestureEmbedderTFLiteName[] = "gesture_embedder.tflite"; +constexpr char kCannedGestureClassifierTFLiteName[] = + "canned_gesture_classifier.tflite"; + +struct SubTaskModelResources { + const core::ModelResources* gesture_embedder_model_resource; + const core::ModelResources* canned_gesture_classifier_model_resource; +}; + +Source> ConvertMatrixToTensor(Source matrix, + Graph& 
graph) { + auto& node = graph.AddNode("TensorConverterCalculator"); + matrix >> node.In("MATRIX"); + return node[Output>{"TENSORS"}]; +} + +// Sets the base options in the sub tasks. +absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + HandGestureRecognizerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto gesture_embedder_file, + resources.GetModelFile(kGestureEmbedderTFLiteName)); + auto* gesture_embedder_graph_options = + options->mutable_gesture_embedder_graph_options(); + SetExternalFile(gesture_embedder_file, + gesture_embedder_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + gesture_embedder_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + gesture_embedder_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + + ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file, + resources.GetModelFile(kCannedGestureClassifierTFLiteName)); + auto* canned_gesture_classifier_graph_options = + options->mutable_canned_gesture_classifier_graph_options(); + SetExternalFile( + canned_gesture_classifier_file, + canned_gesture_classifier_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + canned_gesture_classifier_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + canned_gesture_classifier_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + +} // namespace + +// A +// "mediapipe.tasks.vision.gesture_recognizer.SingleHandGestureRecognizerGraph" +// performs single hand gesture recognition. This graph is used as a building +// block for mediapipe.tasks.vision.GestureRecognizerGraph. +// +// Inputs: +// HANDEDNESS - ClassificationList +// Classification of handedness. +// LANDMARKS - NormalizedLandmarkList +// Detected hand landmarks in normalized image coordinates. +// WORLD_LANDMARKS - LandmarkList +// Detected hand landmarks in world coordinates. +// IMAGE_SIZE - std::pair +// The size of image from which the landmarks detected from. +// NORM_RECT - NormalizedRect +// NormalizedRect whose 'rotation' field is used to rotate the +// landmarks before processing them. +// +// Outputs: +// HAND_GESTURES - ClassificationList +// Recognized hand gestures with sorted order such that the winning label is +// the first item in the list. 
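+//
+// The base_options of this graph are expected to reference a model asset
+// bundle holding the two sub-task models named above
+// ("gesture_embedder.tflite" and "canned_gesture_classifier.tflite");
+// SetSubTaskBaseOptions() above distributes the bundled model files, the
+// acceleration settings and the stream-mode flag into the gesture embedder
+// and canned gesture classifier sub-options.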
+// +// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerGraph" +// input_stream: "HANDEDNESS:handedness" +// input_stream: "LANDMARKS:landmarks" +// input_stream: "WORLD_LANDMARKS:world_landmarks" +// input_stream: "IMAGE_SIZE:image_size" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "HAND_GESTURES:hand_gestures" +// options { +// [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "hand_gesture.tflite" +// } +// } +// } +// } +// } +class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources( + sc)); + // When the model resources cache service is available, filling in + // the file pointer meta in the subtasks' base options. Otherwise, + // providing the file contents instead. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN(const auto sub_task_model_resources, + CreateSubTaskModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN(auto hand_gestures, + BuildGestureRecognizerGraph( + sc->Options(), + sub_task_model_resources, + graph[Input(kHandednessTag)], + graph[Input(kLandmarksTag)], + graph[Input(kWorldLandmarksTag)], + graph[Input>(kImageSizeTag)], + graph[Input(kNormRectTag)], graph)); + hand_gestures >> graph[Output(kHandGesturesTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr CreateSubTaskModelResources( + SubgraphContext* sc) { + auto* options = sc->MutableOptions(); + SubTaskModelResources sub_task_model_resources; + auto& gesture_embedder_model_asset = + *options->mutable_gesture_embedder_graph_options() + ->mutable_base_options() + ->mutable_model_asset(); + ASSIGN_OR_RETURN( + sub_task_model_resources.gesture_embedder_model_resource, + CreateModelResources(sc, + std::make_unique( + std::move(gesture_embedder_model_asset)), + "_gesture_embedder")); + auto& canned_gesture_classifier_model_asset = + *options->mutable_canned_gesture_classifier_graph_options() + ->mutable_base_options() + ->mutable_model_asset(); + ASSIGN_OR_RETURN( + sub_task_model_resources.canned_gesture_classifier_model_resource, + CreateModelResources( + sc, + std::make_unique( + std::move(canned_gesture_classifier_model_asset)), + "_canned_gesture_classifier")); + return sub_task_model_resources; + } + + absl::StatusOr> BuildGestureRecognizerGraph( + const HandGestureRecognizerGraphOptions& graph_options, + const SubTaskModelResources& sub_task_model_resources, + Source handedness, + Source hand_landmarks, + Source hand_world_landmarks, + Source> image_size, Source norm_rect, + Graph& graph) { + // Converts the ClassificationList to a matrix. + auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator"); + handedness >> handedness_to_matrix.In(kHandednessTag); + auto handedness_matrix = + handedness_to_matrix[Output(kHandednessMatrixTag)]; + + // Converts the handedness matrix to a tensor for the inference + // calculator. + auto handedness_tensors = ConvertMatrixToTensor(handedness_matrix, graph); + + // Converts the screen landmarks to a matrix. 
+ LandmarksToMatrixCalculatorOptions landmarks_options; + landmarks_options.set_object_normalization(true); + landmarks_options.set_object_normalization_origin_offset(0); + auto& hand_landmarks_to_matrix = + graph.AddNode("LandmarksToMatrixCalculator"); + hand_landmarks_to_matrix.GetOptions() = + landmarks_options; + hand_landmarks >> hand_landmarks_to_matrix.In(kLandmarksTag); + image_size >> hand_landmarks_to_matrix.In(kImageSizeTag); + norm_rect >> hand_landmarks_to_matrix.In(kNormRectTag); + auto hand_landmarks_matrix = + hand_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; + + // Converts the landmarks matrix to a tensor for the inference calculator. + auto hand_landmarks_tensor = + ConvertMatrixToTensor(hand_landmarks_matrix, graph); + + // Converts the world landmarks to a matrix. + auto& hand_world_landmarks_to_matrix = + graph.AddNode("LandmarksToMatrixCalculator"); + hand_world_landmarks_to_matrix + .GetOptions() = landmarks_options; + hand_world_landmarks >> + hand_world_landmarks_to_matrix.In(kWorldLandmarksTag); + image_size >> hand_world_landmarks_to_matrix.In(kImageSizeTag); + norm_rect >> hand_world_landmarks_to_matrix.In(kNormRectTag); + auto hand_world_landmarks_matrix = + hand_world_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; + + // Converts the world landmarks matrix to a tensor for the inference + // calculator. + auto hand_world_landmarks_tensor = + ConvertMatrixToTensor(hand_world_landmarks_matrix, graph); + + // Converts a tensor into a vector of tensors for the inference + // calculator. + auto& concatenate_tensor_vector = + graph.AddNode("ConcatenateTensorVectorCalculator"); + hand_landmarks_tensor >> concatenate_tensor_vector.In(0); + handedness_tensors >> concatenate_tensor_vector.In(1); + hand_world_landmarks_tensor >> concatenate_tensor_vector.In(2); + auto concatenated_tensors = concatenate_tensor_vector.Out(""); + + // Inference for static hand gesture recognition. 
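+    // Two inference stages follow: the gesture embedder model first turns
+    // the concatenated landmark and handedness tensors into an embedding,
+    // and the canned gesture classifier then scores that embedding; the
+    // resulting scores are decoded by TensorsToClassificationCalculator.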
+ auto& gesture_embedder_inference = + AddInference(*sub_task_model_resources.gesture_embedder_model_resource, + graph_options.gesture_embedder_graph_options() + .base_options() + .acceleration(), + graph); + concatenated_tensors >> gesture_embedder_inference.In(kTensorsTag); + auto embedding_tensors = gesture_embedder_inference.Out(kTensorsTag); + + auto& canned_gesture_classifier_inference = AddInference( + *sub_task_model_resources.canned_gesture_classifier_model_resource, + graph_options.canned_gesture_classifier_graph_options() + .base_options() + .acceleration(), + graph); + embedding_tensors >> canned_gesture_classifier_inference.In(kTensorsTag); + auto inference_output_tensors = + canned_gesture_classifier_inference.Out(kTensorsTag); + + auto& tensors_to_classification = + graph.AddNode("TensorsToClassificationCalculator"); + MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( + graph_options.canned_gesture_classifier_graph_options() + .classifier_options(), + *sub_task_model_resources.canned_gesture_classifier_model_resource + ->GetMetadataExtractor(), + 0, + &tensors_to_classification.GetOptions< + mediapipe::TensorsToClassificationCalculatorOptions>())); + inference_output_tensors >> tensors_to_classification.In(kTensorsTag); + auto classification_list = + tensors_to_classification[Output( + "CLASSIFICATIONS")]; + return classification_list; + } +}; + +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::gesture_recognizer::SingleHandGestureRecognizerGraph); // NOLINT +// clang-format on + +// A +// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph" +// performs multi hand gesture recognition. This graph is used as a building +// block for mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph. +// +// Inputs: +// HANDEDNESS - std::vector +// A vector of Classification of handedness. +// LANDMARKS - std::vector +// A vector hand landmarks in normalized image coordinates. +// WORLD_LANDMARKS - std::vector +// A vector hand landmarks in world coordinates. +// IMAGE_SIZE - std::pair +// The size of image from which the landmarks detected from. +// NORM_RECT - NormalizedRect +// NormalizedRect whose 'rotation' field is used to rotate the +// landmarks before processing them. +// HAND_TRACKING_IDS - std::vector +// A vector of the tracking ids of the hands. The tracking id is the vector +// index corresponding to the same hand if the graph runs multiple times. +// +// Outputs: +// HAND_GESTURES - std::vector +// A vector of recognized hand gestures. Each vector element is the +// ClassificationList of the hand in input vector. 
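+//
+// Internally, each id in HAND_TRACKING_IDS drives one iteration of a
+// BeginLoop/EndLoop pair: the matching handedness, landmarks and world
+// landmarks are picked out of the input vectors, routed through a
+// SingleHandGestureRecognizerGraph, and the per-hand ClassificationLists are
+// then collected back into the HAND_GESTURES vector.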
+// +// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.gesture_recognizer.MultipleHandGestureRecognizerGraph" +// input_stream: "HANDEDNESS:handedness" +// input_stream: "LANDMARKS:landmarks" +// input_stream: "WORLD_LANDMARKS:world_landmarks" +// input_stream: "IMAGE_SIZE:image_size" +// input_stream: "NORM_RECT:norm_rect" +// input_stream: "HAND_TRACKING_IDS:hand_tracking_ids" +// output_stream: "HAND_GESTURES:hand_gestures" +// options { +// [mediapipe.tasks.vision.gesture_recognizer.proto.MultipleHandGestureRecognizerGraph.ext] +// { +// base_options { +// model_asset { +// file_name: "hand_gesture.tflite" +// } +// } +// } +// } +// } +class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + ASSIGN_OR_RETURN( + auto multi_hand_gestures, + BuildMultiGestureRecognizerSubraph( + sc->Options(), + graph[Input>(kHandednessTag)], + graph[Input>(kLandmarksTag)], + graph[Input>(kWorldLandmarksTag)], + graph[Input>(kImageSizeTag)], + graph[Input(kNormRectTag)], + graph[Input>(kHandTrackingIdsTag)], graph)); + multi_hand_gestures >> + graph[Output>(kHandGesturesTag)]; + return graph.GetConfig(); + } + + private: + absl::StatusOr>> + BuildMultiGestureRecognizerSubraph( + const HandGestureRecognizerGraphOptions& graph_options, + Source> multi_handedness, + Source> multi_hand_landmarks, + Source> multi_hand_world_landmarks, + Source> image_size, Source norm_rect, + Source> multi_hand_tracking_ids, Graph& graph) { + auto& begin_loop_int = graph.AddNode("BeginLoopIntCalculator"); + image_size >> begin_loop_int.In(kCloneTag)[0]; + norm_rect >> begin_loop_int.In(kCloneTag)[1]; + multi_handedness >> begin_loop_int.In(kCloneTag)[2]; + multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[3]; + multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[4]; + multi_hand_tracking_ids >> begin_loop_int.In(kIterableTag); + auto image_size_clone = begin_loop_int.Out(kCloneTag)[0]; + auto norm_rect_clone = begin_loop_int.Out(kCloneTag)[1]; + auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[2]; + auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[3]; + auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[4]; + auto hand_tracking_id = begin_loop_int.Out(kItemTag); + auto batch_end = begin_loop_int.Out(kBatchEndTag); + + auto& get_handedness_at_index = + graph.AddNode("GetClassificationListVectorItemCalculator"); + multi_handedness_clone >> get_handedness_at_index.In(kVectorTag); + hand_tracking_id >> get_handedness_at_index.In(kIndexTag); + auto handedness = get_handedness_at_index.Out(kItemTag); + + auto& get_landmarks_at_index = + graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); + multi_hand_landmarks_clone >> get_landmarks_at_index.In(kVectorTag); + hand_tracking_id >> get_landmarks_at_index.In(kIndexTag); + auto hand_landmarks = get_landmarks_at_index.Out(kItemTag); + + auto& get_world_landmarks_at_index = + graph.AddNode("GetLandmarkListVectorItemCalculator"); + multi_hand_world_landmarks_clone >> + get_world_landmarks_at_index.In(kVectorTag); + hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag); + auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag); + + auto& hand_gesture_recognizer_graph = graph.AddNode( + "mediapipe.tasks.vision.gesture_recognizer." 
+ "SingleHandGestureRecognizerGraph"); + hand_gesture_recognizer_graph + .GetOptions() + .CopyFrom(graph_options); + handedness >> hand_gesture_recognizer_graph.In(kHandednessTag); + hand_landmarks >> hand_gesture_recognizer_graph.In(kLandmarksTag); + hand_world_landmarks >> + hand_gesture_recognizer_graph.In(kWorldLandmarksTag); + image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag); + norm_rect_clone >> hand_gesture_recognizer_graph.In(kNormRectTag); + auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag); + + auto& end_loop_classification_lists = + graph.AddNode("EndLoopClassificationListCalculator"); + batch_end >> end_loop_classification_lists.In(kBatchEndTag); + hand_gestures >> end_loop_classification_lists.In(kItemTag); + auto multi_hand_gestures = + end_loop_classification_lists[Output>( + kIterableTag)]; + + return multi_hand_gestures; + } +}; + +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::gesture_recognizer::MultipleHandGestureRecognizerGraph); // NOLINT +// clang-format on + +} // namespace gesture_recognizer +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc similarity index 93% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc index 00e19cdb5..60ccae92c 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" #include @@ -25,6 +25,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace {} // namespace @@ -58,6 +59,7 @@ absl::StatusOr GetLeftHandScore( } } +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h similarity index 79% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h index 74e04b8cc..ae4137d0f 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ -#define MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#ifndef MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ #include "absl/status/statusor.h" #include "mediapipe/framework/formats/classification.pb.h" @@ -22,6 +22,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { bool IsLeftHand(const mediapipe::Classification& c); @@ -30,8 +31,9 @@ bool IsRightHand(const mediapipe::Classification& c); absl::StatusOr GetLeftHandScore( const mediapipe::ClassificationList& classification_list); +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe -#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ +#endif // MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZER_HADNDEDNESS_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc similarity index 94% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc rename to mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc index 51dfb5dea..40a201ae8 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util_test.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/handedness_util.h" +#include "mediapipe/tasks/cc/vision/gesture_recognizer/handedness_util.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/port/gmock.h" @@ -23,6 +23,7 @@ limitations under the License. namespace mediapipe { namespace tasks { namespace vision { +namespace gesture_recognizer { namespace { TEST(GetLeftHandScore, SingleLeftHandClassification) { @@ -72,6 +73,7 @@ TEST(GetLeftHandScore, LeftAndRightLowerCaseHandClassification) { } } // namespace +} // namespace gesture_recognizer } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD new file mode 100644 index 000000000..0db47da7a --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/BUILD @@ -0,0 +1,66 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "gesture_embedder_graph_options_proto", + srcs = ["gesture_embedder_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "gesture_classifier_graph_options_proto", + srcs = ["gesture_classifier_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "hand_gesture_recognizer_graph_options_proto", + srcs = ["hand_gesture_recognizer_graph_options.proto"], + deps = [ + ":gesture_classifier_graph_options_proto", + ":gesture_embedder_graph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) + +mediapipe_proto_library( + name = "gesture_recognizer_graph_options_proto", + srcs = ["gesture_recognizer_graph_options.proto"], + deps = [ + ":hand_gesture_recognizer_graph_options_proto", + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto similarity index 61% rename from mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto rename to mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto index f73443eaf..dcefa075f 100644 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto @@ -12,28 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO Refactor naming and class structure of hand related Tasks. 
+ syntax = "proto2"; -package mediapipe.tasks.vision.hand_gesture_recognizer.proto; +package mediapipe.tasks.vision.gesture_recognizer.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandGestureRecognizerSubgraphOptions { +option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; +option java_outer_classname = "GestureClassifierGraphOptionsProto"; + +message GestureClassifierGraphOptions { extend mediapipe.CalculatorOptions { - optional HandGestureRecognizerSubgraphOptions ext = 463370452; + optional GestureClassifierGraphOptions ext = 478825465; } // Base options for configuring hand gesture recognition subgraph, such as // specifying the TfLite model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // Options for configuring the gesture classifier behavior, such as score - // threshold, number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; - - // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be - // considered tracked successfully - optional float min_tracking_confidence = 3 [default = 0.0]; + optional components.processors.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto new file mode 100644 index 000000000..bff4e0a9c --- /dev/null +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto @@ -0,0 +1,33 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.gesture_recognizer.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto"; +option java_outer_classname = "GestureEmbedderGraphOptionsProto"; + +message GestureEmbedderGraphOptions { + extend mediapipe.CalculatorOptions { + optional GestureEmbedderGraphOptions ext = 478825422; + } + // Base options for configuring hand gesture recognition subgraph, such as + // specifying the TfLite model file with metadata, accelerator options, etc. 
+  optional core.proto.BaseOptions base_options = 1;
+}
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto
new file mode 100644
index 000000000..57d8a3746
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_recognizer_graph_options.proto
@@ -0,0 +1,43 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package mediapipe.tasks.vision.gesture_recognizer.proto;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/tasks/cc/core/proto/base_options.proto";
+import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto";
+import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto";
+option java_outer_classname = "GestureRecognizerGraphOptionsProto";
+
+message GestureRecognizerGraphOptions {
+  extend mediapipe.CalculatorOptions {
+    optional GestureRecognizerGraphOptions ext = 479097054;
+  }
+  // Base options for configuring gesture recognizer graph, such as specifying
+  // the TfLite model file with metadata, accelerator options, etc.
+  optional core.proto.BaseOptions base_options = 1;
+
+  // Options for configuring hand landmarker graph.
+  optional hand_landmarker.proto.HandLandmarkerGraphOptions
+      hand_landmarker_graph_options = 2;
+
+  // Options for configuring hand gesture recognizer graph.
+  optional HandGestureRecognizerGraphOptions
+      hand_gesture_recognizer_graph_options = 3;
+}
diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto
new file mode 100644
index 000000000..7df2fed37
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/gesture_recognizer/proto/hand_gesture_recognizer_graph_options.proto
@@ -0,0 +1,46 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// TODO Refactor naming and class structure of hand related Tasks.
+syntax = "proto2";
+
+package mediapipe.tasks.vision.gesture_recognizer.proto;
+
+import "mediapipe/framework/calculator.proto";
+import "mediapipe/tasks/cc/core/proto/base_options.proto";
+import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_classifier_graph_options.proto";
+import "mediapipe/tasks/cc/vision/gesture_recognizer/proto/gesture_embedder_graph_options.proto";
+
+option java_package = "com.google.mediapipe.tasks.vision.gesturerecognizer.proto";
+option java_outer_classname = "HandGestureRecognizerGraphOptionsProto";
+
+message HandGestureRecognizerGraphOptions {
+  extend mediapipe.CalculatorOptions {
+    optional HandGestureRecognizerGraphOptions ext = 463370452;
+  }
+  // Base options for configuring hand gesture recognition subgraph, such as
+  // specifying the TfLite model file with metadata, accelerator options, etc.
+  optional core.proto.BaseOptions base_options = 1;
+
+  // Options for GestureEmbedder.
+  optional GestureEmbedderGraphOptions gesture_embedder_graph_options = 2;
+
+  // Options for GestureClassifier of canned gestures.
+  optional GestureClassifierGraphOptions
+      canned_gesture_classifier_graph_options = 3;
+
+  // Options for GestureClassifier of custom gestures.
+  optional GestureClassifierGraphOptions
+      custom_gesture_classifier_graph_options = 4;
+}
diff --git a/mediapipe/tasks/cc/vision/hand_detector/BUILD b/mediapipe/tasks/cc/vision/hand_detector/BUILD
index 23cf5f72d..71cef6270 100644
--- a/mediapipe/tasks/cc/vision/hand_detector/BUILD
+++ b/mediapipe/tasks/cc/vision/hand_detector/BUILD
@@ -18,18 +18,6 @@ package(default_visibility = [
 
 licenses(["notice"])
 
-cc_library(
-    name = "hand_detector_op_resolver",
-    srcs = ["hand_detector_op_resolver.cc"],
-    hdrs = ["hand_detector_op_resolver.h"],
-    deps = [
-        "//mediapipe/util/tflite/operations:max_pool_argmax",
-        "//mediapipe/util/tflite/operations:max_unpooling",
-        "//mediapipe/util/tflite/operations:transpose_conv_bias",
-        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-    ],
-)
-
 cc_library(
     name = "hand_detector_graph",
     srcs = ["hand_detector_graph.cc"],
@@ -44,7 +32,7 @@ cc_library(
         "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
         "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
         "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
-        "//mediapipe/calculators/util:detection_letterbox_removal_calculator",
+        "//mediapipe/calculators/util:detection_projection_calculator",
         "//mediapipe/calculators/util:detections_to_rects_calculator",
         "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
         "//mediapipe/calculators/util:non_max_suppression_calculator",
@@ -63,7 +51,7 @@ cc_library(
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core:utils",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
-        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_cc_proto",
+        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/utils:image_tensor_specs",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc
index 7ead21bad..e876d7d09 100644
--- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc
+++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph.cc
@@ -40,12 +40,13 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" namespace mediapipe { namespace tasks { namespace vision { +namespace hand_detector { namespace { @@ -53,18 +54,24 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; constexpr char kImageTag[] = "IMAGE"; -constexpr char kDetectionsTag[] = "DETECTIONS"; -constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; +constexpr char kHandRectsTag[] = "HAND_RECTS"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; struct HandDetectionOuts { Source> palm_detections; Source> hand_rects; + Source> palm_rects; + Source image; }; void ConfigureTensorsToDetectionsCalculator( + const HandDetectorGraphOptions& tasks_options, mediapipe::TensorsToDetectionsCalculatorOptions* options) { // TODO use metadata to configure these fields. options->set_num_classes(1); @@ -77,7 +84,7 @@ void ConfigureTensorsToDetectionsCalculator( options->set_sigmoid_score(true); options->set_score_clipping_thresh(100.0); options->set_reverse_output_order(true); - options->set_min_score_thresh(0.5); + options->set_min_score_thresh(tasks_options.min_detection_confidence()); options->set_x_scale(192.0); options->set_y_scale(192.0); options->set_w_scale(192.0); @@ -134,29 +141,43 @@ void ConfigureRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.HandDetectorGraph" performs hand detection. The -// Hand Detection Graph is based on palm detection model, and scale the detected -// palm bounding box to enclose the detected whole hand. +// A "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" performs hand +// detection. The Hand Detection Graph is based on palm detection model, and +// scale the detected palm bounding box to enclose the detected whole hand. // Accepts CPU input images and outputs Landmark on CPU. // // Inputs: // IMAGE - Image // Image to perform detection on. +// NORM_RECT - NormalizedRect +// Describes image rotation and region of image to perform detection +// on. // // Outputs: -// DETECTIONS - std::vector +// PALM_DETECTIONS - std::vector // Detected palms with maximum `num_hands` specified in options. -// NORM_RECTS - std::vector +// HAND_RECTS - std::vector // Detected hand bounding boxes in normalized coordinates. +// PLAM_RECTS - std::vector +// Detected palm bounding boxes in normalized coordinates. +// IMAGE - Image +// The input image that the hand detector runs on and has the pixel data +// stored on the target storage (CPU vs GPU). +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. 
// // Example: // node { -// calculator: "mediapipe.tasks.vision.HandDetectorGraph" +// calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph" // input_stream: "IMAGE:image" -// output_stream: "DETECTIONS:palm_detections" -// output_stream: "NORM_RECTS:hand_rects_from_palm_detections" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "PALM_DETECTIONS:palm_detections" +// output_stream: "HAND_RECTS:hand_rects_from_palm_detections" +// output_stream: "PALM_RECTS:palm_rects" +// output_stream: "IMAGE:image_out" // options { -// [mediapipe.tasks.hand_detector.proto.HandDetectorOptions.ext] { +// [mediapipe.tasks.vision.hand_detector.proto.HandDetectorGraphOptions.ext] +// { // base_options { // model_asset { // file_name: "palm_detection.tflite" @@ -173,16 +194,20 @@ class HandDetectorGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN(auto hand_detection_outs, BuildHandDetectionSubgraph( - sc->Options(), *model_resources, - graph[Input(kImageTag)], graph)); + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); hand_detection_outs.palm_detections >> - graph[Output>(kDetectionsTag)]; + graph[Output>(kPalmDetectionsTag)]; hand_detection_outs.hand_rects >> - graph[Output>(kNormRectsTag)]; + graph[Output>(kHandRectsTag)]; + hand_detection_outs.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_detection_outs.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -196,9 +221,9 @@ class HandDetectorGraph : public core::ModelTaskGraph { // image_in: image stream to run hand detection on. // graph: the mediapipe builder::Graph instance to be updated. absl::StatusOr BuildHandDetectionSubgraph( - const HandDetectorOptions& subgraph_options, + const HandDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { // Add image preprocessing subgraph. The model expects aspect ratio // unchanged. auto& preprocessing = @@ -215,8 +240,9 @@ class HandDetectorGraph : public core::ModelTaskGraph { &preprocessing .GetOptions())); image_in >> preprocessing.In("IMAGE"); + norm_rect_in >> preprocessing.In("NORM_RECT"); auto preprocessed_tensors = preprocessing.Out("TENSORS"); - auto letterbox_padding = preprocessing.Out("LETTERBOX_PADDING"); + auto matrix = preprocessing.Out("MATRIX"); auto image_size = preprocessing.Out("IMAGE_SIZE"); // Adds SSD palm detection model. @@ -235,6 +261,7 @@ class HandDetectorGraph : public core::ModelTaskGraph { auto& tensors_to_detections = graph.AddNode("TensorsToDetectionsCalculator"); ConfigureTensorsToDetectionsCalculator( + subgraph_options, &tensors_to_detections .GetOptions()); model_output_tensors >> tensors_to_detections.In("TENSORS"); @@ -259,17 +286,12 @@ class HandDetectorGraph : public core::ModelTaskGraph { nms_detections >> detection_label_id_to_text.In(""); auto detections_with_text = detection_label_id_to_text.Out(""); - // Adjusts detection locations (already normalized to [0.f, 1.f]) on the - // letterboxed image (after image transformation with the FIT scale mode) to - // the corresponding locations on the same image with the letterbox removed - // (the input image to the graph before image transformation). 
- auto& detection_letterbox_removal = - graph.AddNode("DetectionLetterboxRemovalCalculator"); - detections_with_text >> detection_letterbox_removal.In("DETECTIONS"); - letterbox_padding >> detection_letterbox_removal.In("LETTERBOX_PADDING"); + // Projects detections back into the input image coordinates system. + auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); + detections_with_text >> detection_projection.In("DETECTIONS"); + matrix >> detection_projection.In("PROJECTION_MATRIX"); auto palm_detections = - detection_letterbox_removal[Output>( - "DETECTIONS")]; + detection_projection[Output>("DETECTIONS")]; // Converts each palm detection into a rectangle (normalized by image size) // that encloses the palm and is rotated such that the line connecting @@ -281,7 +303,8 @@ class HandDetectorGraph : public core::ModelTaskGraph { .GetOptions()); palm_detections >> detections_to_rects.In("DETECTIONS"); image_size >> detections_to_rects.In("IMAGE_SIZE"); - auto palm_rects = detections_to_rects.Out("NORM_RECTS"); + auto palm_rects = + detections_to_rects[Output>("NORM_RECTS")]; // Expands and shifts the rectangle that contains the palm so that it's // likely to cover the entire hand. @@ -308,13 +331,18 @@ class HandDetectorGraph : public core::ModelTaskGraph { clip_normalized_rect_vector_size[Output>( "")]; - return HandDetectionOuts{.palm_detections = palm_detections, - .hand_rects = clipped_hand_rects}; + return HandDetectionOuts{ + /* palm_detections= */ palm_detections, + /* hand_rects= */ clipped_hand_rects, + /* palm_rects= */ palm_rects, + /* image= */ preprocessing[Output(kImageTag)]}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandDetectorGraph); +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_detector::HandDetectorGraph); +} // namespace hand_detector } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc index a2fbd7c54..cbbc0e193 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_graph_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -35,18 +36,19 @@ limitations under the License. 
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" -#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" namespace mediapipe { namespace tasks { namespace vision { +namespace hand_detector { namespace { using ::file::Defaults; @@ -60,7 +62,8 @@ using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::TaskRunner; using ::mediapipe::tasks::core::proto::ExternalFile; using ::mediapipe::tasks::vision::DecodeImageFromFile; -using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorOptions; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; using ::mediapipe::tasks::vision::hand_detector::proto::HandDetectorResult; using ::testing::EqualsProto; using ::testing::TestParamInfo; @@ -73,16 +76,21 @@ using ::testing::proto::Partially; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite"; constexpr char kTestRightHandsImage[] = "right_hands.jpg"; +constexpr char kTestRightHandsRotatedImage[] = "right_hands_rotated.jpg"; constexpr char kTestModelResourcesTag[] = "test_model_resources"; constexpr char kOneHandResultFile[] = "hand_detector_result_one_hand.pbtxt"; +constexpr char kOneHandRotatedResultFile[] = + "hand_detector_result_one_hand_rotated.pbtxt"; constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt"; constexpr char kImageTag[] = "IMAGE"; constexpr char kImageName[] = "image"; -constexpr char kPalmDetectionsTag[] = "DETECTIONS"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; constexpr char kPalmDetectionsName[] = "palm_detections"; -constexpr char kHandNormRectsTag[] = "NORM_RECTS"; +constexpr char kHandRectsTag[] = "HAND_RECTS"; constexpr char kHandNormRectsName[] = "hand_norm_rects"; constexpr float kPalmDetectionBboxMaxDiff = 0.01; @@ -104,25 +112,27 @@ absl::StatusOr> CreateTaskRunner( Graph graph; auto& hand_detection = - graph.AddNode("mediapipe.tasks.vision.HandDetectorGraph"); + graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); options->set_min_detection_confidence(0.5); options->set_num_hands(num_hands); - hand_detection.GetOptions().Swap(options.get()); + hand_detection.GetOptions().Swap(options.get()); graph[Input(kImageTag)].SetName(kImageName) >> hand_detection.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + hand_detection.In(kNormRectTag); hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >> graph[Output>(kPalmDetectionsTag)]; - 
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> - graph[Output>(kHandNormRectsTag)]; + hand_detection.Out(kHandRectsTag).SetName(kHandNormRectsName) >> + graph[Output>(kHandRectsTag)]; - return TaskRunner::Create(graph.GetConfig(), - absl::make_unique()); + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); } HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) { @@ -140,6 +150,9 @@ struct TestParams { std::string hand_detection_model_name; // The filename of test image. std::string test_image_name; + // The rotation to apply to the test image before processing, in radians + // counter-clockwise. + float rotation; // The number of maximum detected hands. int num_hands; // The expected hand detector result. @@ -152,14 +165,22 @@ TEST_P(HandDetectionTest, DetectTwoHands) { MP_ASSERT_OK_AND_ASSIGN( Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, GetParam().test_image_name))); + NormalizedRect input_norm_rect; + input_norm_rect.set_rotation(GetParam().rotation); + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); MP_ASSERT_OK_AND_ASSIGN( auto model_resources, CreateModelResourcesForModel(GetParam().hand_detection_model_name)); MP_ASSERT_OK_AND_ASSIGN( auto task_runner, CreateTaskRunner(*model_resources, kPalmDetectionModel, GetParam().num_hands)); - auto output_packets = - task_runner->Process({{kImageName, MakePacket(std::move(image))}}); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); MP_ASSERT_OK(output_packets); const std::vector& palm_detections = (*output_packets)[kPalmDetectionsName].Get>(); @@ -186,20 +207,30 @@ INSTANTIATE_TEST_SUITE_P( Values(TestParams{.test_name = "DetectOneHand", .hand_detection_model_name = kPalmDetectionModel, .test_image_name = kTestRightHandsImage, + .rotation = 0, .num_hands = 1, .expected_result = GetExpectedHandDetectorResult(kOneHandResultFile)}, TestParams{.test_name = "DetectTwoHands", .hand_detection_model_name = kPalmDetectionModel, .test_image_name = kTestRightHandsImage, + .rotation = 0, .num_hands = 2, .expected_result = - GetExpectedHandDetectorResult(kTwoHandsResultFile)}), + GetExpectedHandDetectorResult(kTwoHandsResultFile)}, + TestParams{.test_name = "DetectOneHandWithRotation", + .hand_detection_model_name = kPalmDetectionModel, + .test_image_name = kTestRightHandsRotatedImage, + .rotation = M_PI / 2.0f, + .num_hands = 1, + .expected_result = GetExpectedHandDetectorResult( + kOneHandRotatedResultFile)}), [](const TestParamInfo& info) { return info.param.test_name; }); } // namespace +} // namespace hand_detector } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc b/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc deleted file mode 100644 index 262fb2c75..000000000 --- a/mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h" - -#include "mediapipe/util/tflite/operations/max_pool_argmax.h" -#include "mediapipe/util/tflite/operations/max_unpooling.h" -#include "mediapipe/util/tflite/operations/transpose_conv_bias.h" - -namespace mediapipe { -namespace tasks { -namespace vision { -HandDetectorOpResolver::HandDetectorOpResolver() { - AddCustom("MaxPoolingWithArgmax2D", - mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); - AddCustom("MaxUnpooling2D", - mediapipe::tflite_operations::RegisterMaxUnpooling2D()); - AddCustom("Convolution2DTransposeBias", - mediapipe::tflite_operations::RegisterConvolution2DTransposeBias()); -} -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD index 2d22aab10..77f3b2649 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_detector_options_proto", - srcs = ["hand_detector_options.proto"], + name = "hand_detector_graph_options_proto", + srcs = ["hand_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", diff --git a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto similarity index 75% rename from mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto rename to mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto index ae22c7991..a009f2365 100644 --- a/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.proto +++ b/mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto @@ -20,25 +20,21 @@ package mediapipe.tasks.vision.hand_detector.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -option java_package = "com.google.mediapipe.tasks.vision.handdetector"; -option java_outer_classname = "HandDetectorOptionsProto"; +option java_package = "com.google.mediapipe.tasks.vision.handdetector.proto"; +option java_outer_classname = "HandDetectorGraphOptionsProto"; -message HandDetectorOptions { +message HandDetectorGraphOptions { extend mediapipe.CalculatorOptions { - optional HandDetectorOptions ext = 464864288; + optional HandDetectorGraphOptions ext = 464864288; } // Base options for configuring Task library, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. 
- optional string display_names_locale = 2 [default = "en"]; - // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered // successfully detecting a hand in the image. - optional float min_detection_confidence = 3 [default = 0.5]; + optional float min_detection_confidence = 2 [default = 0.5]; // The maximum number of hands output by the detector. - optional int32 num_hands = 4; + optional int32 num_hands = 3; } diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD deleted file mode 100644 index bb5b86212..000000000 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/BUILD +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2022 The MediaPipe Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package(default_visibility = [ - "//mediapipe/tasks:internal", -]) - -licenses(["notice"]) - -cc_library( - name = "handedness_util", - srcs = ["handedness_util.cc"], - hdrs = ["handedness_util.h"], - deps = [ - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/port:ret_check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -cc_test( - name = "handedness_util_test", - srcs = ["handedness_util_test.cc"], - deps = [ - ":handedness_util", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/port:gtest_main", - ], -) - -cc_library( - name = "hand_gesture_recognizer_subgraph", - srcs = ["hand_gesture_recognizer_subgraph.cc"], - deps = [ - "//mediapipe/calculators/core:concatenate_vector_calculator", - "//mediapipe/calculators/tensor:tensor_converter_calculator", - "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto", - "//mediapipe/framework/api2:builder", - "//mediapipe/framework/api2:port", - "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/framework/formats:landmark_cc_proto", - "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:tensor", - "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components:image_preprocessing", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/core:model_resources", - "//mediapipe/tasks/cc/core:model_task_graph", - "//mediapipe/tasks/cc/core:utils", - "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:handedness_to_matrix_calculator", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/calculators:landmarks_to_matrix_calculator", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:hand_gesture_recognizer_subgraph_options_cc_proto", - "//mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto:landmarks_to_matrix_calculator_cc_proto", - 
"//mediapipe/tasks/cc/vision/hand_landmarker:hand_landmarker_subgraph", - "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", - "//mediapipe/tasks/metadata:metadata_schema_cc", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], - alwayslink = 1, -) diff --git a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc b/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc deleted file mode 100644 index e124d3410..000000000 --- a/mediapipe/tasks/cc/vision/hand_gesture_recognizer/hand_gesture_recognizer_subgraph.cc +++ /dev/null @@ -1,368 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" -#include "mediapipe/framework/api2/builder.h" -#include "mediapipe/framework/api2/port.h" -#include "mediapipe/framework/formats/classification.pb.h" -#include "mediapipe/framework/formats/landmark.pb.h" -#include "mediapipe/framework/formats/matrix.h" -#include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/core/model_resources.h" -#include "mediapipe/tasks/cc/core/model_task_graph.h" -#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" -#include "mediapipe/tasks/cc/core/utils.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/hand_gesture_recognizer_subgraph_options.pb.h" -#include "mediapipe/tasks/cc/vision/hand_gesture_recognizer/proto/landmarks_to_matrix_calculator.pb.h" -#include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" -#include "mediapipe/tasks/metadata/metadata_schema_generated.h" - -namespace mediapipe { -namespace tasks { -namespace vision { - -namespace { - -using ::mediapipe::api2::Input; -using ::mediapipe::api2::Output; -using ::mediapipe::api2::builder::Graph; -using ::mediapipe::api2::builder::Source; -using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: - HandGestureRecognizerSubgraphOptions; -using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; - -constexpr char kHandednessTag[] = "HANDEDNESS"; -constexpr char kLandmarksTag[] = "LANDMARKS"; -constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; -constexpr char kImageSizeTag[] = "IMAGE_SIZE"; -constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS"; -constexpr char kHandGesturesTag[] = "HAND_GESTURES"; -constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX"; -constexpr char kTensorsTag[] = "TENSORS"; -constexpr char kHandednessMatrixTag[] = 
"HANDEDNESS_MATRIX"; -constexpr char kCloneTag[] = "CLONE"; -constexpr char kItemTag[] = "ITEM"; -constexpr char kVectorTag[] = "VECTOR"; -constexpr char kIndexTag[] = "INDEX"; -constexpr char kIterableTag[] = "ITERABLE"; -constexpr char kBatchEndTag[] = "BATCH_END"; - -absl::Status SanityCheckOptions( - const HandGestureRecognizerSubgraphOptions& options) { - if (options.min_tracking_confidence() < 0 || - options.min_tracking_confidence() > 1) { - return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, - "Invalid `min_tracking_confidence` option: " - "value must be in the range [0.0, 1.0]", - MediaPipeTasksStatus::kInvalidArgumentError); - } - return absl::OkStatus(); -} - -Source> ConvertMatrixToTensor(Source matrix, - Graph& graph) { - auto& node = graph.AddNode("TensorConverterCalculator"); - matrix >> node.In("MATRIX"); - return node[Output>{"TENSORS"}]; -} - -} // namespace - -// A "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" performs -// single hand gesture recognition. This graph is used as a building block for -// mediapipe.tasks.vision.HandGestureRecognizerGraph. -// -// Inputs: -// HANDEDNESS - ClassificationList -// Classification of handedness. -// LANDMARKS - NormalizedLandmarkList -// Detected hand landmarks in normalized image coordinates. -// WORLD_LANDMARKS - LandmarkList -// Detected hand landmarks in world coordinates. -// IMAGE_SIZE - std::pair -// The size of image from which the landmarks detected from. -// -// Outputs: -// HAND_GESTURES - ClassificationResult -// Recognized hand gestures with sorted order such that the winning label is -// the first item in the list. -// -// -// Example: -// node { -// calculator: "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph" -// input_stream: "HANDEDNESS:handedness" -// input_stream: "LANDMARKS:landmarks" -// input_stream: "WORLD_LANDMARKS:world_landmarks" -// input_stream: "IMAGE_SIZE:image_size" -// output_stream: "HAND_GESTURES:hand_gestures" -// options { -// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraphOptions.ext] -// { -// base_options { -// model_asset { -// file_name: "hand_gesture.tflite" -// } -// } -// } -// } -// } -class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph { - public: - absl::StatusOr GetConfig( - SubgraphContext* sc) override { - ASSIGN_OR_RETURN( - const auto* model_resources, - CreateModelResources(sc)); - Graph graph; - ASSIGN_OR_RETURN( - auto hand_gestures, - BuildHandGestureRecognizerGraph( - sc->Options(), - *model_resources, graph[Input(kHandednessTag)], - graph[Input(kLandmarksTag)], - graph[Input(kWorldLandmarksTag)], - graph[Input>(kImageSizeTag)], graph)); - hand_gestures >> graph[Output(kHandGesturesTag)]; - return graph.GetConfig(); - } - - private: - absl::StatusOr> BuildHandGestureRecognizerGraph( - const HandGestureRecognizerSubgraphOptions& graph_options, - const core::ModelResources& model_resources, - Source handedness, - Source hand_landmarks, - Source hand_world_landmarks, - Source> image_size, Graph& graph) { - MP_RETURN_IF_ERROR(SanityCheckOptions(graph_options)); - - // Converts the ClassificationList to a matrix. - auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator"); - handedness >> handedness_to_matrix.In(kHandednessTag); - auto handedness_matrix = - handedness_to_matrix[Output(kHandednessMatrixTag)]; - - // Converts the handedness matrix to a tensor for the inference - // calculator. 
- auto handedness_tensors = ConvertMatrixToTensor(handedness_matrix, graph); - - // Converts the screen landmarks to a matrix. - LandmarksToMatrixCalculatorOptions landmarks_options; - landmarks_options.set_object_normalization(true); - landmarks_options.set_object_normalization_origin_offset(0); - auto& hand_landmarks_to_matrix = - graph.AddNode("LandmarksToMatrixCalculator"); - hand_landmarks_to_matrix.GetOptions() = - landmarks_options; - hand_landmarks >> hand_landmarks_to_matrix.In(kLandmarksTag); - image_size >> hand_landmarks_to_matrix.In(kImageSizeTag); - auto hand_landmarks_matrix = - hand_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; - - // Converts the landmarks matrix to a tensor for the inference calculator. - auto hand_landmarks_tensor = - ConvertMatrixToTensor(hand_landmarks_matrix, graph); - - // Converts the world landmarks to a matrix. - auto& hand_world_landmarks_to_matrix = - graph.AddNode("LandmarksToMatrixCalculator"); - hand_world_landmarks_to_matrix - .GetOptions() = landmarks_options; - hand_world_landmarks >> - hand_world_landmarks_to_matrix.In(kWorldLandmarksTag); - image_size >> hand_world_landmarks_to_matrix.In(kImageSizeTag); - auto hand_world_landmarks_matrix = - hand_world_landmarks_to_matrix[Output(kLandmarksMatrixTag)]; - - // Converts the world landmarks matrix to a tensor for the inference - // calculator. - auto hand_world_landmarks_tensor = - ConvertMatrixToTensor(hand_world_landmarks_matrix, graph); - - // Converts a tensor into a vector of tensors for the inference - // calculator. - auto& concatenate_tensor_vector = - graph.AddNode("ConcatenateTensorVectorCalculator"); - hand_landmarks_tensor >> concatenate_tensor_vector.In(0); - handedness_tensors >> concatenate_tensor_vector.In(1); - hand_world_landmarks_tensor >> concatenate_tensor_vector.In(2); - auto concatenated_tensors = concatenate_tensor_vector.Out(""); - - // Inference for static hand gesture recognition. - auto& inference = AddInference( - model_resources, graph_options.base_options().acceleration(), graph); - concatenated_tensors >> inference.In(kTensorsTag); - auto inference_output_tensors = inference.Out(kTensorsTag); - - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, graph_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); - inference_output_tensors >> postprocessing.In(kTensorsTag); - auto classification_result = - postprocessing[Output("CLASSIFICATION_RESULT")]; - - return classification_result; - } -}; - -REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::SingleHandGestureRecognizerSubgraph); - -// A "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" performs multi -// hand gesture recognition. This graph is used as a building block for -// mediapipe.tasks.vision.HandGestureRecognizerGraph. -// -// Inputs: -// HANDEDNESS - std::vector -// A vector of Classification of handedness. -// LANDMARKS - std::vector -// A vector hand landmarks in normalized image coordinates. -// WORLD_LANDMARKS - std::vector -// A vector hand landmarks in world coordinates. -// IMAGE_SIZE - std::pair -// The size of image from which the landmarks detected from. -// HAND_TRACKING_IDS - std::vector -// A vector of the tracking ids of the hands. The tracking id is the vector -// index corresponding to the same hand if the graph runs multiple times. 
-// -// Outputs: -// HAND_GESTURES - std::vector -// A vector of recognized hand gestures. Each vector element is the -// ClassificationResult of the hand in input vector. -// -// -// Example: -// node { -// calculator: "mediapipe.tasks.vision.HandGestureRecognizerSubgraph" -// input_stream: "HANDEDNESS:handedness" -// input_stream: "LANDMARKS:landmarks" -// input_stream: "WORLD_LANDMARKS:world_landmarks" -// input_stream: "IMAGE_SIZE:image_size" -// input_stream: "HAND_TRACKING_IDS:hand_tracking_ids" -// output_stream: "HAND_GESTURES:hand_gestures" -// options { -// [mediapipe.tasks.vision.hand_gesture_recognizer.proto.HandGestureRecognizerSubgraph.ext] -// { -// base_options { -// model_asset { -// file_name: "hand_gesture.tflite" -// } -// } -// } -// } -// } -class HandGestureRecognizerSubgraph : public core::ModelTaskGraph { - public: - absl::StatusOr GetConfig( - SubgraphContext* sc) override { - Graph graph; - ASSIGN_OR_RETURN( - auto multi_hand_gestures, - BuildMultiHandGestureRecognizerSubraph( - sc->Options(), - graph[Input>(kHandednessTag)], - graph[Input>(kLandmarksTag)], - graph[Input>(kWorldLandmarksTag)], - graph[Input>(kImageSizeTag)], - graph[Input>(kHandTrackingIdsTag)], graph)); - multi_hand_gestures >> - graph[Output>(kHandGesturesTag)]; - return graph.GetConfig(); - } - - private: - absl::StatusOr>> - BuildMultiHandGestureRecognizerSubraph( - const HandGestureRecognizerSubgraphOptions& graph_options, - Source> multi_handedness, - Source> multi_hand_landmarks, - Source> multi_hand_world_landmarks, - Source> image_size, - Source> multi_hand_tracking_ids, Graph& graph) { - auto& begin_loop_int = graph.AddNode("BeginLoopIntCalculator"); - image_size >> begin_loop_int.In(kCloneTag)[0]; - multi_handedness >> begin_loop_int.In(kCloneTag)[1]; - multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[2]; - multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[3]; - multi_hand_tracking_ids >> begin_loop_int.In(kIterableTag); - auto image_size_clone = begin_loop_int.Out(kCloneTag)[0]; - auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[1]; - auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[2]; - auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[3]; - auto hand_tracking_id = begin_loop_int.Out(kItemTag); - auto batch_end = begin_loop_int.Out(kBatchEndTag); - - auto& get_handedness_at_index = - graph.AddNode("GetClassificationListVectorItemCalculator"); - multi_handedness_clone >> get_handedness_at_index.In(kVectorTag); - hand_tracking_id >> get_handedness_at_index.In(kIndexTag); - auto handedness = get_handedness_at_index.Out(kItemTag); - - auto& get_landmarks_at_index = - graph.AddNode("GetNormalizedLandmarkListVectorItemCalculator"); - multi_hand_landmarks_clone >> get_landmarks_at_index.In(kVectorTag); - hand_tracking_id >> get_landmarks_at_index.In(kIndexTag); - auto hand_landmarks = get_landmarks_at_index.Out(kItemTag); - - auto& get_world_landmarks_at_index = - graph.AddNode("GetLandmarkListVectorItemCalculator"); - multi_hand_world_landmarks_clone >> - get_world_landmarks_at_index.In(kVectorTag); - hand_tracking_id >> get_world_landmarks_at_index.In(kIndexTag); - auto hand_world_landmarks = get_world_landmarks_at_index.Out(kItemTag); - - auto& hand_gesture_recognizer_subgraph = graph.AddNode( - "mediapipe.tasks.vision.SingleHandGestureRecognizerSubgraph"); - hand_gesture_recognizer_subgraph - .GetOptions() - .CopyFrom(graph_options); - handedness >> hand_gesture_recognizer_subgraph.In(kHandednessTag); - hand_landmarks >> 
hand_gesture_recognizer_subgraph.In(kLandmarksTag); - hand_world_landmarks >> - hand_gesture_recognizer_subgraph.In(kWorldLandmarksTag); - image_size_clone >> hand_gesture_recognizer_subgraph.In(kImageSizeTag); - auto hand_gestures = hand_gesture_recognizer_subgraph.Out(kHandGesturesTag); - - auto& end_loop_classification_results = - graph.AddNode("mediapipe.tasks.EndLoopClassificationResultCalculator"); - batch_end >> end_loop_classification_results.In(kBatchEndTag); - hand_gestures >> end_loop_classification_results.In(kItemTag); - auto multi_hand_gestures = end_loop_classification_results - [Output>(kIterableTag)]; - - return multi_hand_gestures; - } -}; - -REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::HandGestureRecognizerSubgraph); - -} // namespace vision -} // namespace tasks -} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD index 653976b96..9090fc7b3 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/BUILD @@ -19,10 +19,10 @@ package(default_visibility = [ licenses(["notice"]) cc_library( - name = "hand_landmarker_subgraph", - srcs = ["hand_landmarker_subgraph.cc"], + name = "hand_landmarks_detector_graph", + srcs = ["hand_landmarks_detector_graph.cc"], deps = [ - "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_subgraph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "//mediapipe/calculators/core:split_vector_calculator", @@ -51,6 +51,7 @@ cc_library( # TODO: move calculators in modules/hand_landmark/calculators to tasks dir. 
"//mediapipe/modules/hand_landmark/calculators:hand_landmarks_to_rect_calculator", "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", @@ -66,3 +67,47 @@ cc_library( ) # TODO: Enable this test + +cc_library( + name = "hand_landmarker_graph", + srcs = ["hand_landmarker_graph.cc"], + deps = [ + ":hand_landmarks_detector_graph", + "//mediapipe/calculators/core:begin_loop_calculator", + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/core:gate_calculator", + "//mediapipe/calculators/core:gate_calculator_cc_proto", + "//mediapipe/calculators/core:pass_through_calculator", + "//mediapipe/calculators/core:previous_loopback_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator", + "//mediapipe/calculators/util:collection_has_min_size_calculator_cc_proto", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:status", + "//mediapipe/tasks/cc:common", + "//mediapipe/tasks/cc/components/utils:gate", + "//mediapipe/tasks/cc/core:model_asset_bundle_resources", + "//mediapipe/tasks/cc/core:model_resources_cache", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/metadata/utils:zip_utils", + "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_landmarks_deduplication_calculator", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_cc_proto", + ], + alwayslink = 1, +) + +# TODO: Enable this test diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD new file mode 100644 index 000000000..f45681fb3 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -0,0 +1,70 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "hand_association_calculator_proto", + srcs = ["hand_association_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + ], +) + +cc_library( + name = "hand_association_calculator", + srcs = ["hand_association_calculator.cc"], + deps = [ + ":hand_association_calculator_cc_proto", + "//mediapipe/calculators/util:association_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:rectangle", + "//mediapipe/framework/port:status", + "//mediapipe/util:rectangle_util", + ], + alwayslink = 1, +) + +cc_library( + name = "hand_landmarks_deduplication_calculator", + srcs = ["hand_landmarks_deduplication_calculator.cc"], + hdrs = ["hand_landmarks_deduplication_calculator.h"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:rect", + "//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder", + "//mediapipe/tasks/cc/vision/utils:landmarks_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc new file mode 100644 index 000000000..b6df80588 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -0,0 +1,125 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/rectangle.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" +#include "mediapipe/util/rectangle_util.h" + +namespace mediapipe::api2 { + +// HandAssociationCalculator accepts multiple inputs of vectors of +// NormalizedRect. The output is a vector of NormalizedRect that contains +// rects from the input vectors that don't overlap with each other. 
When two
+// rects overlap, the rect that comes in from an earlier input stream is
+// kept in the output. If a rect has no ID (i.e. from detection stream),
+// then a unique rect ID is assigned for it.
+
+// The rects in multiple input streams are effectively flattened to a single
+// list. For example:
+// Stream1 : rect 1, rect 2
+// Stream2: rect 3, rect 4
+// Stream3: rect 5, rect 6
+// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
+// In the flattened list, if a rect with a higher index overlaps with a rect at
+// a lower index, beyond a specified IOU threshold, the rect with the lower
+// index will be in the output, and the rect with the higher index will be
+// discarded.
+// TODO: Upgrade this to latest API for calculators
+class HandAssociationCalculator : public CalculatorBase {
+ public:
+  static absl::Status GetContract(CalculatorContract* cc) {
+    // Initialize input and output streams.
+    for (auto& input_stream : cc->Inputs()) {
+      input_stream.Set<std::vector<NormalizedRect>>();
+    }
+    cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
+
+    return absl::OkStatus();
+  }
+
+  absl::Status Open(CalculatorContext* cc) override {
+    cc->SetOffset(TimestampDiff(0));
+
+    options_ = cc->Options<HandAssociationCalculatorOptions>();
+    CHECK_GT(options_.min_similarity_threshold(), 0.0);
+    CHECK_LE(options_.min_similarity_threshold(), 1.0);
+
+    return absl::OkStatus();
+  }
+
+  absl::Status Process(CalculatorContext* cc) override {
+    ASSIGN_OR_RETURN(auto result, GetNonOverlappingElements(cc));
+
+    auto output =
+        std::make_unique<std::vector<NormalizedRect>>(std::move(result));
+    cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
+
+    return absl::OkStatus();
+  }
+
+ private:
+  HandAssociationCalculatorOptions options_;
+
+  // Returns a list of non-overlapping elements from all input streams, with
+  // decreasing order of priority based on input stream index and indices
+  // within an input stream.
+  absl::StatusOr<std::vector<NormalizedRect>> GetNonOverlappingElements(
+      CalculatorContext* cc) {
+    std::vector<NormalizedRect> result;
+
+    for (const auto& input_stream : cc->Inputs()) {
+      if (input_stream.IsEmpty()) {
+        continue;
+      }
+
+      for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
+        ASSIGN_OR_RETURN(
+            bool is_overlapping,
+            mediapipe::DoesRectOverlap(rect, result,
+                                       options_.min_similarity_threshold()));
+        if (!is_overlapping) {
+          if (!rect.has_rect_id()) {
+            rect.set_rect_id(GetNextRectId());
+          }
+          result.push_back(rect);
+        }
+      }
+    }
+
+    return result;
+  }
+
+ private:
+  // Each NormalizedRect processed by the calculator will be assigned a unique
+  // id, if it does not already have an ID. The starting ID will be 1.
+  // Note: This rect_id_ is local to an instance of this calculator, and it is
+  // expected that the hand tracking graph has only one instance of this
+  // association calculator.
+  int64 rect_id_ = 1;
+
+  inline int GetNextRectId() { return rect_id_++; }
+};
+
+MEDIAPIPE_REGISTER_NODE(HandAssociationCalculator);
+
+}  // namespace mediapipe::api2
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto
new file mode 100644
index 000000000..e7229b4a2
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.proto
@@ -0,0 +1,28 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message HandAssociationCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional HandAssociationCalculatorOptions ext = 408244367; + } + + optional float min_similarity_threshold = 1 [default = 1.0]; +} diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc new file mode 100644 index 000000000..cb3130854 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -0,0 +1,302 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" + +namespace mediapipe { +namespace { + +class HandAssociationCalculatorTest : public testing::Test { + protected: + HandAssociationCalculatorTest() { + // 0.4 ================ + // | | | | + // 0.3 ===================== | NR2 | | + // | | | NR1 | | | NR4 | + // 0.2 | NR0 | =========== ================ + // | | | | | | + // 0.1 =====|=============== | + // | NR3 | | | + // 0.0 ================ | + // | NR5 | + // -0.1 =========== + // 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 + + // NormalizedRect nr_0. + nr_0_.set_x_center(0.2); + nr_0_.set_y_center(0.2); + nr_0_.set_width(0.2); + nr_0_.set_height(0.2); + + // NormalizedRect nr_1. + nr_1_.set_x_center(0.4); + nr_1_.set_y_center(0.2); + nr_1_.set_width(0.2); + nr_1_.set_height(0.2); + + // NormalizedRect nr_2. + nr_2_.set_x_center(1.0); + nr_2_.set_y_center(0.3); + nr_2_.set_width(0.2); + nr_2_.set_height(0.2); + + // NormalizedRect nr_3. + nr_3_.set_x_center(0.35); + nr_3_.set_y_center(0.15); + nr_3_.set_width(0.3); + nr_3_.set_height(0.3); + + // NormalizedRect nr_4. + nr_4_.set_x_center(1.1); + nr_4_.set_y_center(0.3); + nr_4_.set_width(0.2); + nr_4_.set_height(0.2); + + // NormalizedRect nr_5. 
+ nr_5_.set_x_center(0.5); + nr_5_.set_y_center(0.05); + nr_5_.set_width(0.3); + nr_5_.set_height(0.3); + } + + NormalizedRect nr_0_, nr_1_, nr_2_, nr_3_, nr_4_, nr_5_; +}; + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_0, nr_1, nr_2. + auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_0_); + input_vec_0->push_back(nr_1_); + input_vec_0->push_back(nr_2_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_3, nr_4. + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_3_); + input_vec_1->push_back(nr_4_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_5. + auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_5_); + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_0 is added 1st. + // nr_1 is added because it does not overlap with nr_0. + // nr_2 is added because it does not overlap with nr_0 or nr_1. + // nr_3 is NOT added because it overlaps with nr_0. + // nr_4 is NOT added because it overlaps with nr_2. + // nr_5 is NOT added because it overlaps with nr_1. + EXPECT_EQ(3, assoc_rects.size()); + + // Check that IDs are filled in and contents match. + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), 2); + assoc_rects[1].clear_rect_id(); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 3); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_0, nr_1. Tracked hands. + auto input_vec_0 = std::make_unique>(); + // Setting ID to a negative number for test only, since newly generated + // ID by HandAssociationCalculator are positive numbers. + nr_0_.set_rect_id(-2); + input_vec_0->push_back(nr_0_); + nr_1_.set_rect_id(-1); + input_vec_0->push_back(nr_1_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_2, nr_3. Newly detected palms. 
+ auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_2_); + input_vec_1->push_back(nr_3_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_0 is added 1st. + // nr_1 is added because it does not overlap with nr_0. + // nr_2 is added because it does not overlap with nr_0 or nr_1. + // nr_3 is NOT added because it overlaps with nr_0. + EXPECT_EQ(3, assoc_rects.size()); + + // Check that IDs are filled in and contents match. + EXPECT_EQ(assoc_rects[0].rect_id(), -2); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), -1); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 1); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec_0" + input_stream: "input_vec_1" + input_stream: "input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Input Stream 0: nr_5. + auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_5_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_4, nr_3 + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_4_); + input_vec_1->push_back(nr_3_); + runner.MutableInputs()->Index(1).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_2, nr_1, nr_0. + auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_2_); + input_vec_2->push_back(nr_1_); + input_vec_2->push_back(nr_0_); + runner.MutableInputs()->Index(2).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_5 is added 1st. + // nr_4 is added because it does not overlap with nr_5. + // nr_3 is NOT added because it overlaps with nr_5. + // nr_2 is NOT added because it overlaps with nr_4. + // nr_1 is NOT added because it overlaps with nr_5. + // nr_0 is added because it does not overlap with nr_5 or nr_4. + EXPECT_EQ(3, assoc_rects.size()); + + // Outputs are in same order as inputs, and IDs are filled in. 
+ EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_)); + + EXPECT_EQ(assoc_rects[1].rect_id(), 2); + assoc_rects[1].clear_rect_id(); + EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_)); + + EXPECT_EQ(assoc_rects[2].rect_id(), 3); + assoc_rects[2].clear_rect_id(); + EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_)); +} + +TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "input_vec" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); + + // Just one input stream : nr_3, nr_5. + auto input_vec = std::make_unique>(); + input_vec->push_back(nr_3_); + input_vec->push_back(nr_5_); + runner.MutableInputs()->Index(0).packets.push_back( + Adopt(input_vec.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_3 is added 1st. + // nr_5 is NOT added because it overlaps with nr_3. + EXPECT_EQ(1, assoc_rects.size()); + + EXPECT_EQ(assoc_rects[0].rect_id(), 1); + assoc_rects[0].clear_rect_id(); + EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_)); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc new file mode 100644 index 000000000..5a5baa50e --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.cc @@ -0,0 +1,310 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h" +#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" + +namespace mediapipe::api2 { +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::utils::CalculateIOU; +using ::mediapipe::tasks::vision::utils::DuplicatesFinder; + +float Distance(const NormalizedLandmark& lm_a, const NormalizedLandmark& lm_b, + int width, int height) { + return std::sqrt(std::pow((lm_a.x() - lm_b.x()) * width, 2) + + std::pow((lm_a.y() - lm_b.y()) * height, 2)); +} + +absl::StatusOr> Distances(const NormalizedLandmarkList& a, + const NormalizedLandmarkList& b, + int width, int height) { + const int num = a.landmark_size(); + RET_CHECK_EQ(b.landmark_size(), num); + std::vector distances; + distances.reserve(num); + for (int i = 0; i < num; ++i) { + const NormalizedLandmark& lm_a = a.landmark(i); + const NormalizedLandmark& lm_b = b.landmark(i); + distances.push_back(Distance(lm_a, lm_b, width, height)); + } + return distances; +} + +// Calculates a baseline distance of a hand that can be used as a relative +// measure when calculating hand to hand similarity. +// +// Calculated as maximum of distances: 0->5, 5->17, 17->0, where 0, 5, 17 key +// points are depicted below: +// +// /Middle/ +// | +// /Index/ | /Ring/ +// | | | /Pinky/ +// V V V | +// V +// [8] [12] [16] +// | | | [20] +// | | | | +// /Thumb/ | | | | +// | [7] [11] [15] [19] +// V | | | | +// | | | | +// [4] | | | | +// | [6] [10] [14] [18] +// | | | | | +// | | | | | +// [3] | | | | +// | [5]----[9]---[13]---[17] +// . | | +// \ . | +// \ / | +// [2] | +// \ | +// \ | +// \ | +// [1] . +// \ / +// \ / +// ._____[0]_____. +// +// ^ +// | +// /Wrist/ +absl::StatusOr HandBaselineDistance( + const NormalizedLandmarkList& landmarks, int width, int height) { + RET_CHECK_EQ(landmarks.landmark_size(), 21); // Num of hand landmarks. 
+ constexpr int kWrist = 0; + constexpr int kIndexFingerMcp = 5; + constexpr int kPinkyMcp = 17; + float distance = Distance(landmarks.landmark(kWrist), + landmarks.landmark(kIndexFingerMcp), width, height); + distance = std::max(distance, + Distance(landmarks.landmark(kIndexFingerMcp), + landmarks.landmark(kPinkyMcp), width, height)); + distance = + std::max(distance, Distance(landmarks.landmark(kPinkyMcp), + landmarks.landmark(kWrist), width, height)); + return distance; +} + +Rect CalculateBound(const NormalizedLandmarkList& list) { + constexpr float kMinInitialValue = std::numeric_limits::max(); + constexpr float kMaxInitialValue = std::numeric_limits::lowest(); + + // Compute min and max values on landmarks (they will form + // bounding box) + float bounding_box_left = kMinInitialValue; + float bounding_box_top = kMinInitialValue; + float bounding_box_right = kMaxInitialValue; + float bounding_box_bottom = kMaxInitialValue; + for (const auto& landmark : list.landmark()) { + bounding_box_left = std::min(bounding_box_left, landmark.x()); + bounding_box_top = std::min(bounding_box_top, landmark.y()); + bounding_box_right = std::max(bounding_box_right, landmark.x()); + bounding_box_bottom = std::max(bounding_box_bottom, landmark.y()); + } + + // Populate normalized non rotated face bounding box + return {.left = bounding_box_left, + .top = bounding_box_top, + .right = bounding_box_right, + .bottom = bounding_box_bottom}; +} + +// Uses IoU and distance of some corresponding hand landmarks to detect +// duplicate / similar hands. IoU, distance thresholds, number of landmarks to +// match are found experimentally. Evaluated: +// - manually comparing side by side, before and after deduplication applied +// - generating gesture dataset, and checking select frames in baseline and +// "deduplicated" dataset +// - by confirming gesture training is better with use of deduplication using +// selected thresholds +class HandDuplicatesFinder : public DuplicatesFinder { + public: + explicit HandDuplicatesFinder(bool start_from_the_end) + : start_from_the_end_(start_from_the_end) {} + + absl::StatusOr> FindDuplicates( + const std::vector& multi_landmarks, + int input_width, int input_height) override { + absl::flat_hash_set retained_indices; + absl::flat_hash_set suppressed_indices; + + const int num = multi_landmarks.size(); + std::vector baseline_distances; + baseline_distances.reserve(num); + std::vector bounds; + bounds.reserve(num); + for (const NormalizedLandmarkList& list : multi_landmarks) { + ASSIGN_OR_RETURN(const float baseline_distance, + HandBaselineDistance(list, input_width, input_height)); + baseline_distances.push_back(baseline_distance); + bounds.push_back(CalculateBound(list)); + } + + for (int index = 0; index < num; ++index) { + const int i = start_from_the_end_ ? 
num - index - 1 : index; + const float stable_distance_i = baseline_distances[i]; + bool suppressed = false; + for (int j : retained_indices) { + const float stable_distance_j = baseline_distances[j]; + + constexpr float kAllowedBaselineDistanceRatio = 0.2f; + const float distance_threshold = + std::max(stable_distance_i, stable_distance_j) * + kAllowedBaselineDistanceRatio; + + ASSIGN_OR_RETURN(const std::vector distances, + Distances(multi_landmarks[i], multi_landmarks[j], + input_width, input_height)); + const int num_matched_landmarks = absl::c_count_if( + distances, + [&](float distance) { return distance < distance_threshold; }); + + const float iou = CalculateIOU(bounds[i], bounds[j]); + + constexpr int kNumMatchedLandmarksToSuppressHand = 10; // out of 21 + constexpr float kMinIouThresholdToSuppressHand = 0.2f; + if (num_matched_landmarks >= kNumMatchedLandmarksToSuppressHand && + iou > kMinIouThresholdToSuppressHand) { + suppressed = true; + break; + } + } + + if (suppressed) { + suppressed_indices.insert(i); + } else { + retained_indices.insert(i); + } + } + return suppressed_indices; + } + + private: + const bool start_from_the_end_; +}; + +template +absl::StatusOr> +VerifyNumAndMaybeInitOutput(const InputPortT& port, CalculatorContext* cc, + int num_expected_size) { + absl::optional output; + if (port(cc).IsConnected() && !port(cc).IsEmpty()) { + RET_CHECK_EQ(port(cc).Get().size(), num_expected_size); + typename InputPortT::PayloadT result; + return {{result}}; + } + return {absl::nullopt}; +} +} // namespace + +std::unique_ptr CreateHandDuplicatesFinder( + bool start_from_the_end) { + return absl::make_unique(start_from_the_end); +} + +absl::Status HandLandmarksDeduplicationCalculator::Process( + mediapipe::CalculatorContext* cc) { + if (kInLandmarks(cc).IsEmpty()) return absl::OkStatus(); + if (kInSize(cc).IsEmpty()) return absl::OkStatus(); + + const std::vector& in_landmarks = *kInLandmarks(cc); + const std::pair& image_size = *kInSize(cc); + + std::unique_ptr duplicates_finder = + CreateHandDuplicatesFinder(/*start_from_the_end=*/false); + ASSIGN_OR_RETURN(absl::flat_hash_set indices_to_remove, + duplicates_finder->FindDuplicates( + in_landmarks, image_size.first, image_size.second)); + + if (indices_to_remove.empty()) { + kOutLandmarks(cc).Send(kInLandmarks(cc)); + kOutRois(cc).Send(kInRois(cc)); + kOutWorldLandmarks(cc).Send(kInWorldLandmarks(cc)); + kOutClassifications(cc).Send(kInClassifications(cc)); + } else { + std::vector out_landmarks; + const int num = in_landmarks.size(); + + ASSIGN_OR_RETURN(absl::optional> out_rois, + VerifyNumAndMaybeInitOutput(kInRois, cc, num)); + ASSIGN_OR_RETURN( + absl::optional> out_world_landmarks, + VerifyNumAndMaybeInitOutput(kInWorldLandmarks, cc, num)); + ASSIGN_OR_RETURN( + absl::optional> out_classifications, + VerifyNumAndMaybeInitOutput(kInClassifications, cc, num)); + + for (int i = 0; i < num; ++i) { + if (indices_to_remove.find(i) != indices_to_remove.end()) continue; + + out_landmarks.push_back(in_landmarks[i]); + if (out_rois) { + out_rois->push_back(kInRois(cc).Get()[i]); + } + if (out_world_landmarks) { + out_world_landmarks->push_back(kInWorldLandmarks(cc).Get()[i]); + } + if (out_classifications) { + out_classifications->push_back(kInClassifications(cc).Get()[i]); + } + } + + if (!out_landmarks.empty()) { + kOutLandmarks(cc).Send(std::move(out_landmarks)); + } + if (out_rois && !out_rois->empty()) { + kOutRois(cc).Send(std::move(out_rois.value())); + } + if (out_world_landmarks && !out_world_landmarks->empty()) { + 
kOutWorldLandmarks(cc).Send(std::move(out_world_landmarks.value())); + } + if (out_classifications && !out_classifications->empty()) { + kOutClassifications(cc).Send(std::move(out_classifications.value())); + } + } + return absl::OkStatus(); +} +MEDIAPIPE_REGISTER_NODE(HandLandmarksDeduplicationCalculator); + +} // namespace mediapipe::api2 diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h new file mode 100644 index 000000000..d7b435487 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_landmarks_deduplication_calculator.h @@ -0,0 +1,97 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_ +#define MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_ + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h" + +namespace mediapipe::api2 { + +// Create a DuplicatesFinder dedicated for finding hand duplications. +std::unique_ptr +CreateHandDuplicatesFinder(bool start_from_the_end = false); + +// Filter duplicate hand landmarks by finding the overlapped hands. +// Inputs: +// MULTI_LANDMARKS - std::vector +// The hand landmarks to be filtered. +// MULTI_ROIS - std::vector +// The regions where each encloses the landmarks of a single hand. +// MULTI_WORLD_LANDMARKS - std::vector +// The hand landmarks to be filtered in world coordinates. +// MULTI_CLASSIFICATIONS - std::vector +// The handedness of hands. +// IMAGE_SIZE - std::pair +// The size of the image which the hand landmarks are detected on. +// +// Outputs: +// MULTI_LANDMARKS - std::vector +// The hand landmarks with duplication removed. +// MULTI_ROIS - std::vector +// The regions where each encloses the landmarks of a single hand with +// duplicate hands removed. +// MULTI_WORLD_LANDMARKS - std::vector +// The hand landmarks with duplication removed in world coordinates. +// MULTI_CLASSIFICATIONS - std::vector +// The handedness of hands with duplicate hands removed. 
+// +// Example: +// node { +// calculator: "HandLandmarksDeduplicationCalculator" +// input_stream: "MULTI_LANDMARKS:landmarks_in" +// input_stream: "MULTI_ROIS:rois_in" +// input_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_in" +// input_stream: "MULTI_CLASSIFICATIONS:handedness_in" +// input_stream: "IMAGE_SIZE:image_size" +// output_stream: "MULTI_LANDMARKS:landmarks_out" +// output_stream: "MULTI_ROIS:rois_out" +// output_stream: "MULTI_WORLD_LANDMARKS:world_landmarks_out" +// output_stream: "MULTI_CLASSIFICATIONS:handedness_out" +// } +class HandLandmarksDeduplicationCalculator : public Node { + public: + constexpr static Input> + kInLandmarks{"MULTI_LANDMARKS"}; + constexpr static Input>::Optional + kInRois{"MULTI_ROIS"}; + constexpr static Input>::Optional + kInWorldLandmarks{"MULTI_WORLD_LANDMARKS"}; + constexpr static Input>::Optional + kInClassifications{"MULTI_CLASSIFICATIONS"}; + constexpr static Input> kInSize{"IMAGE_SIZE"}; + + constexpr static Output> + kOutLandmarks{"MULTI_LANDMARKS"}; + constexpr static Output>::Optional + kOutRois{"MULTI_ROIS"}; + constexpr static Output>::Optional + kOutWorldLandmarks{"MULTI_WORLD_LANDMARKS"}; + constexpr static Output>::Optional + kOutClassifications{"MULTI_CLASSIFICATIONS"}; + MEDIAPIPE_NODE_CONTRACT(kInLandmarks, kInRois, kInWorldLandmarks, + kInClassifications, kInSize, kOutLandmarks, kOutRois, + kOutWorldLandmarks, kOutClassifications); + absl::Status Process(mediapipe::CalculatorContext* cc) override; +}; + +} // namespace mediapipe::api2 + +#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_LANDMARKER_CALCULATORS_HAND_LANDMARKS_DEDUPLICATION_CALCULATOR_H_ diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc new file mode 100644 index 000000000..e610a412e --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -0,0 +1,375 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/core/gate_calculator.pb.h" +#include "mediapipe/calculators/util/collection_has_min_size_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" +#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h" +#include "mediapipe/tasks/cc/core/model_resources_cache.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::DisallowIf; +using ::mediapipe::tasks::core::ModelAssetBundleResources; +using ::mediapipe::tasks::metadata::SetExternalFile; +using ::mediapipe::tasks::vision::hand_detector::proto:: + HandDetectorGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarksDetectorGraphOptions; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS"; +constexpr char kPalmRectsTag[] = "PALM_RECTS"; +constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator"; +constexpr char kHandDetectorTFLiteName[] = "hand_detector.tflite"; +constexpr char kHandLandmarksDetectorTFLiteName[] = + "hand_landmarks_detector.tflite"; + +struct HandLandmarkerOutputs { + Source> landmark_lists; + Source> world_landmark_lists; + Source> hand_rects_next_frame; + Source> handednesses; + Source> palm_rects; + Source> palm_detections; + Source image; +}; + +// Sets the base options in the sub tasks. 
+absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, + HandLandmarkerGraphOptions* options, + bool is_copy) { + ASSIGN_OR_RETURN(const auto hand_detector_file, + resources.GetModelFile(kHandDetectorTFLiteName)); + auto* hand_detector_graph_options = + options->mutable_hand_detector_graph_options(); + SetExternalFile(hand_detector_file, + hand_detector_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_detector_graph_options->mutable_base_options()->set_use_stream_mode( + options->base_options().use_stream_mode()); + ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, + resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + auto* hand_landmarks_detector_graph_options = + options->mutable_hand_landmarks_detector_graph_options(); + SetExternalFile(hand_landmarks_detector_file, + hand_landmarks_detector_graph_options->mutable_base_options() + ->mutable_model_asset(), + is_copy); + hand_landmarks_detector_graph_options->mutable_base_options() + ->mutable_acceleration() + ->CopyFrom(options->base_options().acceleration()); + hand_landmarks_detector_graph_options->mutable_base_options() + ->set_use_stream_mode(options->base_options().use_stream_mode()); + return absl::OkStatus(); +} + +} // namespace + +// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand +// landmarks detection. The HandLandmarkerGraph consists of two subgraphs: +// HandDetectorGraph and MultipleHandLandmarksDetectorGraph. +// MultipleHandLandmarksDetectorGraph detects landmarks from bounding boxes +// produced by HandDetectorGraph. HandLandmarkerGraph tracks the landmarks over +// time, and skips the HandDetectorGraph. If the tracking is lost or the detectd +// hands are less than configured max number hands, HandDetectorGraph would be +// triggered to detect hands. +// +// Accepts CPU input images and outputs Landmarks on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform hand landmarks detection on. +// NORM_RECT - NormalizedRect +// Describes image rotation and region of image to perform landmarks +// detection on. +// +// Outputs: +// LANDMARKS: - std::vector +// Vector of detected hand landmarks. +// WORLD_LANDMARKS - std::vector +// Vector of detected hand landmarks in world coordinates. +// HAND_RECT_NEXT_FRAME - std::vector +// Vector of the predicted rects enclosing the same hand RoI for landmark +// detection on the next frame. +// HANDEDNESS - std::vector +// Vector of classification of handedness. +// PALM_RECTS - std::vector +// Detected palm bounding boxes in normalized coordinates. +// PALM_DETECTIONS - std::vector +// Detected palms with maximum `num_hands` specified in options. +// IMAGE - Image +// The input image that the hand landmarker runs on and has the pixel data +// stored on the target storage (CPU vs GPU). +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. 
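+// The maximum number of hands tracked at a time is taken from
+// hand_detector_graph_options.num_hands in HandLandmarkerGraphOptions; hand
+// detection is only re-run on frames where fewer than that many hands are
+// currently being tracked.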
+// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" +// input_stream: "IMAGE:image_in" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "LANDMARKS:hand_landmarks" +// output_stream: "WORLD_LANDMARKS:world_hand_landmarks" +// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" +// output_stream: "HANDEDNESS:handedness" +// output_stream: "PALM_RECTS:palm_rects" +// output_stream: "PALM_DETECTIONS:palm_detections" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.hand_landmarker.proto.HandLandmarkerGraphOptions.ext] { +// base_options { +// model_asset { +// file_name: "hand_landmarker.task" +// } +// } +// hand_detector_graph_options { +// base_options { +// model_asset { +// file_name: "palm_detection.tflite" +// } +// } +// min_detection_confidence: 0.5 +// num_hands: 2 +// } +// hand_landmarks_detector_graph_options { +// base_options { +// model_asset { +// file_name: "hand_landmark_lite.tflite" +// } +// } +// min_detection_confidence: 0.5 +// } +// } +// } +// } +class HandLandmarkerGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + Graph graph; + if (sc->Options() + .base_options() + .has_model_asset()) { + ASSIGN_OR_RETURN( + const auto* model_asset_bundle_resources, + CreateModelAssetBundleResources(sc)); + // Copies the file content instead of passing the pointer of file in + // memory if the subgraph model resource service is not available. + MP_RETURN_IF_ERROR(SetSubTaskBaseOptions( + *model_asset_bundle_resources, + sc->MutableOptions(), + !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) + .IsAvailable())); + } + ASSIGN_OR_RETURN(auto hand_landmarker_outputs, + BuildHandLandmarkerGraph( + sc->Options(), + graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); + hand_landmarker_outputs.landmark_lists >> + graph[Output>(kLandmarksTag)]; + hand_landmarker_outputs.world_landmark_lists >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmarker_outputs.hand_rects_next_frame >> + graph[Output>(kHandRectNextFrameTag)]; + hand_landmarker_outputs.handednesses >> + graph[Output>(kHandednessTag)]; + hand_landmarker_outputs.palm_rects >> + graph[Output>(kPalmRectsTag)]; + hand_landmarker_outputs.palm_detections >> + graph[Output>(kPalmDetectionsTag)]; + hand_landmarker_outputs.image >> graph[Output(kImageTag)]; + + // TODO remove when support is fixed. + // As mediapipe GraphBuilder currently doesn't support configuring + // InputStreamInfo, modifying the CalculatorGraphConfig proto directly. + CalculatorGraphConfig config = graph.GetConfig(); + for (int i = 0; i < config.node_size(); ++i) { + if (config.node(i).calculator() == kPreviousLoopbackCalculatorName) { + auto* info = config.mutable_node(i)->add_input_stream_info(); + info->set_tag_index("LOOP"); + info->set_back_edge(true); + break; + } + } + return config; + } + + private: + // Adds a mediapipe hand landmark detection graph into the provided + // builder::Graph instance. + // + // tasks_options: the mediapipe tasks module HandLandmarkerGraphOptions. + // image_in: (mediapipe::Image) stream to run hand landmark detection on. + // graph: the mediapipe graph instance to be updated. 
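+  // norm_rect_in: (NormalizedRect) stream describing the rotation and region
+  //   of the input image to run hand landmark detection on.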
+ absl::StatusOr BuildHandLandmarkerGraph( + const HandLandmarkerGraphOptions& tasks_options, Source image_in, + Source norm_rect_in, Graph& graph) { + const int max_num_hands = + tasks_options.hand_detector_graph_options().num_hands(); + + auto& previous_loopback = graph.AddNode(kPreviousLoopbackCalculatorName); + image_in >> previous_loopback.In("MAIN"); + auto prev_hand_rects_from_landmarks = + previous_loopback[Output>("PREV_LOOP")]; + + auto& min_size_node = + graph.AddNode("NormalizedRectVectorHasMinSizeCalculator"); + prev_hand_rects_from_landmarks >> min_size_node.In("ITERABLE"); + min_size_node.GetOptions() + .set_min_size(max_num_hands); + auto has_enough_hands = min_size_node.Out("").Cast(); + + auto image_for_hand_detector = + DisallowIf(image_in, has_enough_hands, graph); + auto norm_rect_in_for_hand_detector = + DisallowIf(norm_rect_in, has_enough_hands, graph); + + auto& hand_detector = + graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph"); + hand_detector.GetOptions().CopyFrom( + tasks_options.hand_detector_graph_options()); + image_for_hand_detector >> hand_detector.In("IMAGE"); + norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT"); + auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS"); + + auto& hand_association = graph.AddNode("HandAssociationCalculator"); + hand_association.GetOptions() + .set_min_similarity_threshold(tasks_options.min_tracking_confidence()); + prev_hand_rects_from_landmarks >> + hand_association[Input>::Multiple("")][0]; + hand_rects_from_hand_detector >> + hand_association[Input>::Multiple("")][1]; + auto hand_rects = hand_association.Out(""); + + auto& clip_hand_rects = + graph.AddNode("ClipNormalizedRectVectorSizeCalculator"); + clip_hand_rects.GetOptions() + .set_max_vec_size(max_num_hands); + hand_rects >> clip_hand_rects.In(""); + auto clipped_hand_rects = clip_hand_rects.Out(""); + + auto& hand_landmarks_detector_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." + "MultipleHandLandmarksDetectorGraph"); + hand_landmarks_detector_graph + .GetOptions() + .CopyFrom(tasks_options.hand_landmarks_detector_graph_options()); + image_in >> hand_landmarks_detector_graph.In("IMAGE"); + clipped_hand_rects >> hand_landmarks_detector_graph.In("HAND_RECT"); + + auto landmarks = hand_landmarks_detector_graph.Out(kLandmarksTag); + auto world_landmarks = + hand_landmarks_detector_graph.Out(kWorldLandmarksTag); + auto hand_rects_for_next_frame = + hand_landmarks_detector_graph.Out(kHandRectNextFrameTag); + auto handedness = hand_landmarks_detector_graph.Out(kHandednessTag); + + auto& image_property = graph.AddNode("ImagePropertiesCalculator"); + image_in >> image_property.In("IMAGE"); + auto image_size = image_property.Out("SIZE"); + + auto& deduplicate = graph.AddNode("HandLandmarksDeduplicationCalculator"); + landmarks >> deduplicate.In("MULTI_LANDMARKS"); + world_landmarks >> deduplicate.In("MULTI_WORLD_LANDMARKS"); + hand_rects_for_next_frame >> deduplicate.In("MULTI_ROIS"); + handedness >> deduplicate.In("MULTI_CLASSIFICATIONS"); + image_size >> deduplicate.In("IMAGE_SIZE"); + + auto filtered_landmarks = + deduplicate[Output>( + "MULTI_LANDMARKS")]; + auto filtered_world_landmarks = + deduplicate[Output>("MULTI_WORLD_LANDMARKS")]; + auto filtered_hand_rects_for_next_frame = + deduplicate[Output>("MULTI_ROIS")]; + auto filtered_handedness = + deduplicate[Output>( + "MULTI_CLASSIFICATIONS")]; + + // Back edge. 
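+    // Feeding the deduplicated hand rects back into PreviousLoopbackCalculator
+    // closes the tracking loop: on the next frame they are emitted as
+    // PREV_LOOP and reused as the hand RoIs, so HandDetectorGraph only runs
+    // when fewer than the configured number of hands are being tracked.
+    // GetConfig() later marks this LOOP input with
+    // input_stream_info { tag_index: "LOOP" back_edge: true }, since the graph
+    // builder cannot express back edges directly.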
+ filtered_hand_rects_for_next_frame >> previous_loopback.In("LOOP"); + + // TODO: Replace PassThroughCalculator with a calculator that + // converts the pixel data to be stored on the target storage (CPU vs GPU). + auto& pass_through = graph.AddNode("PassThroughCalculator"); + image_in >> pass_through.In(""); + + return {{ + /* landmark_lists= */ filtered_landmarks, + /* world_landmark_lists= */ filtered_world_landmarks, + /* hand_rects_next_frame= */ filtered_hand_rects_for_next_frame, + /* handedness= */ filtered_handedness, + /* palm_rects= */ + hand_detector[Output>(kPalmRectsTag)], + /* palm_detections */ + hand_detector[Output>(kPalmDetectionsTag)], + /* image */ + pass_through[Output("")], + }}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::HandLandmarkerGraph); + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc new file mode 100644 index 000000000..f275486f5 --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -0,0 +1,209 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace hand_landmarker { + +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::hand_landmarker::proto:: + HandLandmarkerGraphOptions; +using ::testing::EqualsProto; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kHandLandmarkerModelBundle[] = "hand_landmarker.task"; +constexpr char kLeftHandsImage[] = "left_hands.jpg"; +constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect_in"; +constexpr char kLandmarksTag[] = "LANDMARKS"; +constexpr char kLandmarksName[] = "landmarks"; +constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS"; +constexpr char kWorldLandmarksName[] = "world_landmarks"; +constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME"; +constexpr char kHandRectNextFrameName[] = "hand_rect_next_frame"; +constexpr char kHandednessTag[] = "HANDEDNESS"; +constexpr char kHandednessName[] = "handedness"; + +// Expected hand landmarks positions, in text proto format. +constexpr char kExpectedLeftUpHandLandmarksFilename[] = + "expected_left_up_hand_landmarks.prototxt"; +constexpr char kExpectedLeftDownHandLandmarksFilename[] = + "expected_left_down_hand_landmarks.prototxt"; +// Same but for the rotated image. 
+constexpr char kExpectedLeftUpHandRotatedLandmarksFilename[] = + "expected_left_up_hand_rotated_landmarks.prototxt"; +constexpr char kExpectedLeftDownHandRotatedLandmarksFilename[] = + "expected_left_down_hand_rotated_landmarks.prototxt"; + +constexpr float kFullModelFractionDiff = 0.03; // percentage +constexpr float kAbsMargin = 0.03; +constexpr int kMaxNumHands = 2; +constexpr float kMinTrackingConfidence = 0.5; + +NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { + NormalizedLandmarkList expected_landmark_list; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &expected_landmark_list, Defaults())); + return expected_landmark_list; +} + +// Helper function to create a Hand Landmarker TaskRunner. +absl::StatusOr> CreateTaskRunner() { + Graph graph; + auto& hand_landmarker_graph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"); + auto& options = + hand_landmarker_graph.GetOptions(); + options.mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, kHandLandmarkerModelBundle)); + options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands); + options.set_min_tracking_confidence(kMinTrackingConfidence); + + graph[Input(kImageTag)].SetName(kImageName) >> + hand_landmarker_graph.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + hand_landmarker_graph.In(kNormRectTag); + hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >> + graph[Output>(kLandmarksTag)]; + hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >> + graph[Output>(kWorldLandmarksTag)]; + hand_landmarker_graph.Out(kHandednessTag).SetName(kHandednessName) >> + graph[Output>(kHandednessTag)]; + hand_landmarker_graph.Out(kHandRectNextFrameTag) + .SetName(kHandRectNextFrameName) >> + graph[Output>(kHandRectNextFrameTag)]; + return TaskRunner::Create( + graph.GetConfig(), absl::make_unique()); +} + +class HandLandmarkerTest : public tflite_shims::testing::Test {}; + +TEST_F(HandLandmarkerTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + const auto& landmarks = (*output_packets)[kLandmarksName] + .Get>(); + ASSERT_EQ(landmarks.size(), kMaxNumHands); + std::vector expected_landmarks = { + GetExpectedLandmarkList(kExpectedLeftUpHandLandmarksFilename), + GetExpectedLandmarkList(kExpectedLeftDownHandLandmarksFilename)}; + + EXPECT_THAT(landmarks[0], + Approximately(Partially(EqualsProto(expected_landmarks[0])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); + EXPECT_THAT(landmarks[1], + Approximately(Partially(EqualsProto(expected_landmarks[1])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); +} + +TEST_F(HandLandmarkerTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + kLeftHandsRotatedImage))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + 
input_norm_rect.set_height(1.0); + input_norm_rect.set_rotation(M_PI / 2.0); + MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner()); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + const auto& landmarks = (*output_packets)[kLandmarksName] + .Get>(); + ASSERT_EQ(landmarks.size(), kMaxNumHands); + std::vector expected_landmarks = { + GetExpectedLandmarkList(kExpectedLeftUpHandRotatedLandmarksFilename), + GetExpectedLandmarkList(kExpectedLeftDownHandRotatedLandmarksFilename)}; + + EXPECT_THAT(landmarks[0], + Approximately(Partially(EqualsProto(expected_landmarks[0])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); + EXPECT_THAT(landmarks[1], + Approximately(Partially(EqualsProto(expected_landmarks[1])), + /*margin=*/kAbsMargin, + /*fraction=*/kFullModelFractionDiff)); +} + +} // namespace + +} // namespace hand_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc similarity index 89% rename from mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index fff4ae0d4..23521790d 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -34,12 +34,13 @@ limitations under the License. #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" +#include "mediapipe/tasks/cc/components/utils/gate.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" #include "mediapipe/util/label_map.pb.h" @@ -48,6 +49,7 @@ limitations under the License. 
namespace mediapipe { namespace tasks { namespace vision { +namespace hand_landmarker { namespace { @@ -55,9 +57,10 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::utils::AllowIf; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; using LabelItems = mediapipe::proto_ns::Map; constexpr char kImageTag[] = "IMAGE"; @@ -82,7 +85,6 @@ struct SingleHandLandmarkerOutputs { Source hand_presence; Source hand_presence_score; Source handedness; - Source> image_size; }; struct HandLandmarkerOutputs { @@ -92,10 +94,10 @@ struct HandLandmarkerOutputs { Source> presences; Source> presence_scores; Source> handednesses; - Source> image_size; }; -absl::Status SanityCheckOptions(const HandLandmarkerSubgraphOptions& options) { +absl::Status SanityCheckOptions( + const HandLandmarksDetectorGraphOptions& options) { if (options.min_detection_confidence() < 0 || options.min_detection_confidence() > 1) { return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument, @@ -182,8 +184,8 @@ void ConfigureHandRectTransformationCalculator( } // namespace -// A "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" performs hand -// landmark detection. +// A "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph" +// performs hand landmarks detection. // - Accepts CPU input images and outputs Landmark on CPU. // // Inputs: @@ -208,12 +210,11 @@ void ConfigureHandRectTransformationCalculator( // Float value indicates the probability that the hand is present. // HANDEDNESS - ClassificationList // Classification of handedness. -// IMAGE_SIZE - std::vector -// The size of input image. 
// // Example: // node { -// calculator: "mediapipe.tasks.vision.SingleHandLandmarkerSubgraph" +// calculator: +// "mediapipe.tasks.vision.hand_landmarker.SingleHandLandmarksDetectorGraph" // input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" @@ -221,10 +222,8 @@ void ConfigureHandRectTransformationCalculator( // output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame" // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" -// output_stream: "HANDEDNESS:handedness" -// output_stream: "IMAGE_SIZE:image_size" // options { -// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext] // { // base_options { // model_asset { @@ -235,16 +234,17 @@ void ConfigureHandRectTransformationCalculator( // } // } // } -class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { +class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { - ASSIGN_OR_RETURN(const auto* model_resources, - CreateModelResources(sc)); + ASSIGN_OR_RETURN( + const auto* model_resources, + CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN(auto hand_landmark_detection_outs, - BuildSingleHandLandmarkerSubgraph( - sc->Options(), + BuildSingleHandLandmarksDetectorGraph( + sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input(kHandRectTag)], graph)); hand_landmark_detection_outs.hand_landmarks >> @@ -259,8 +259,6 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { graph[Output(kPresenceScoreTag)]; hand_landmark_detection_outs.handedness >> graph[Output(kHandednessTag)]; - hand_landmark_detection_outs.image_size >> - graph[Output>(kImageSizeTag)]; return graph.GetConfig(); } @@ -269,14 +267,16 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { // Adds a mediapipe hand landmark detection graph into the provided // builder::Graph instance. // - // subgraph_options: the mediapipe tasks module HandLandmarkerSubgraphOptions. - // model_resources: the ModelSources object initialized from a hand landmark + // subgraph_options: the mediapipe tasks module + // HandLandmarksDetectorGraphOptions. model_resources: the ModelSources object + // initialized from a hand landmark // detection model file with model metadata. // image_in: (mediapipe::Image) stream to run hand landmark detection on. // rect: (NormalizedRect) stream to run on the RoI of image. // graph: the mediapipe graph instance to be updated. - absl::StatusOr BuildSingleHandLandmarkerSubgraph( - const HandLandmarkerSubgraphOptions& subgraph_options, + absl::StatusOr + BuildSingleHandLandmarksDetectorGraph( + const HandLandmarksDetectorGraphOptions& subgraph_options, const core::ModelResources& model_resources, Source image_in, Source hand_rect, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); @@ -332,18 +332,7 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { // score of hand presence. auto& tensors_to_hand_presence = graph.AddNode("TensorsToFloatsCalculator"); hand_flag_tensors >> tensors_to_hand_presence.In("TENSORS"); - - // Converts the handedness tensor into a float that represents the - // classification score of handedness. 
- auto& tensors_to_handedness = - graph.AddNode("TensorsToClassificationCalculator"); - ConfigureTensorsToHandednessCalculator( - &tensors_to_handedness.GetOptions< - mediapipe::TensorsToClassificationCalculatorOptions>()); - handedness_tensors >> tensors_to_handedness.In("TENSORS"); auto hand_presence_score = tensors_to_hand_presence[Output("FLOAT")]; - auto handedness = - tensors_to_handedness[Output("CLASSIFICATIONS")]; // Applies a threshold to the confidence score to determine whether a // hand is present. @@ -354,6 +343,18 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { hand_presence_score >> hand_presence_thresholding.In("FLOAT"); auto hand_presence = hand_presence_thresholding[Output("FLAG")]; + // Converts the handedness tensor into a float that represents the + // classification score of handedness. + auto& tensors_to_handedness = + graph.AddNode("TensorsToClassificationCalculator"); + ConfigureTensorsToHandednessCalculator( + &tensors_to_handedness.GetOptions< + mediapipe::TensorsToClassificationCalculatorOptions>()); + handedness_tensors >> tensors_to_handedness.In("TENSORS"); + auto handedness = AllowIf( + tensors_to_handedness[Output("CLASSIFICATIONS")], + hand_presence, graph); + // Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed // hand image (after image transformation with the FIT scale mode) to the // corresponding locations on the same image with the letterbox removed @@ -371,8 +372,9 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { landmark_letterbox_removal.Out("LANDMARKS") >> landmark_projection.In("NORM_LANDMARKS"); hand_rect >> landmark_projection.In("NORM_RECT"); - auto projected_landmarks = - landmark_projection[Output("NORM_LANDMARKS")]; + auto projected_landmarks = AllowIf( + landmark_projection[Output("NORM_LANDMARKS")], + hand_presence, graph); // Projects the world landmarks from the cropped hand image to the // corresponding locations on the full image before cropping (input to the @@ -383,7 +385,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { world_landmark_projection.In("LANDMARKS"); hand_rect >> world_landmark_projection.In("NORM_RECT"); auto projected_world_landmarks = - world_landmark_projection[Output("LANDMARKS")]; + AllowIf(world_landmark_projection[Output("LANDMARKS")], + hand_presence, graph); // Converts the hand landmarks into a rectangle (normalized by image size) // that encloses the hand. @@ -403,7 +406,8 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { hand_landmarks_to_rect.Out("NORM_RECT") >> hand_rect_transformation.In("NORM_RECT"); auto hand_rect_next_frame = - hand_rect_transformation[Output("")]; + AllowIf(hand_rect_transformation[Output("")], + hand_presence, graph); return {{ /* hand_landmarks= */ projected_landmarks, @@ -412,16 +416,17 @@ class SingleHandLandmarkerSubgraph : public core::ModelTaskGraph { /* hand_presence= */ hand_presence, /* hand_presence_score= */ hand_presence_score, /* handedness= */ handedness, - /* image_size= */ image_size, }}; } }; +// clang-format off REGISTER_MEDIAPIPE_GRAPH( - ::mediapipe::tasks::vision::SingleHandLandmarkerSubgraph); + ::mediapipe::tasks::vision::hand_landmarker::SingleHandLandmarksDetectorGraph); // NOLINT +// clang-format on -// A "mediapipe.tasks.vision.HandLandmarkerSubgraph" performs multi -// hand landmark detection. +// A "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph" +// performs multi hand landmark detection. 
// - Accepts CPU input image and a vector of hand rect RoIs to detect the // multiple hands landmarks enclosed by the RoIs. Output vectors of // hand landmarks related results, where each element in the vectors @@ -449,12 +454,11 @@ REGISTER_MEDIAPIPE_GRAPH( // Vector of float value indicates the probability that the hand is present. // HANDEDNESS - std::vector // Vector of classification of handedness. -// IMAGE_SIZE - std::vector -// The size of input image. // // Example: // node { -// calculator: "mediapipe.tasks.vision.HandLandmarkerSubgraph" +// calculator: +// "mediapipe.tasks.vision.hand_landmarker.MultipleHandLandmarksDetectorGraph" // input_stream: "IMAGE:input_image" // input_stream: "HAND_RECT:hand_rect" // output_stream: "LANDMARKS:hand_landmarks" @@ -463,9 +467,8 @@ REGISTER_MEDIAPIPE_GRAPH( // output_stream: "PRESENCE:hand_presence" // output_stream: "PRESENCE_SCORE:hand_presence_score" // output_stream: "HANDEDNESS:handedness" -// output_stream: "IMAGE_SIZE:image_size" // options { -// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarkerSubgraphOptions.ext] +// [mediapipe.tasks.vision.hand_landmarker.proto.HandLandmarksDetectorGraphOptions.ext] // { // base_options { // model_asset { @@ -476,15 +479,15 @@ REGISTER_MEDIAPIPE_GRAPH( // } // } // } -class HandLandmarkerSubgraph : public core::ModelTaskGraph { +class MultipleHandLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; ASSIGN_OR_RETURN( auto hand_landmark_detection_outputs, - BuildHandLandmarkerSubgraph( - sc->Options(), + BuildHandLandmarksDetectorGraph( + sc->Options(), graph[Input(kImageTag)], graph[Input>(kHandRectTag)], graph)); hand_landmark_detection_outputs.landmark_lists >> @@ -499,21 +502,20 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { graph[Output>(kPresenceScoreTag)]; hand_landmark_detection_outputs.handednesses >> graph[Output>(kHandednessTag)]; - hand_landmark_detection_outputs.image_size >> - graph[Output>(kImageSizeTag)]; return graph.GetConfig(); } private: - absl::StatusOr BuildHandLandmarkerSubgraph( - const HandLandmarkerSubgraphOptions& subgraph_options, + absl::StatusOr BuildHandLandmarksDetectorGraph( + const HandLandmarksDetectorGraphOptions& subgraph_options, Source image_in, Source> multi_hand_rects, Graph& graph) { - auto& hand_landmark_subgraph = - graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); - hand_landmark_subgraph.GetOptions().CopyFrom( - subgraph_options); + auto& hand_landmark_subgraph = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); + hand_landmark_subgraph.GetOptions() + .CopyFrom(subgraph_options); auto& begin_loop_multi_hand_rects = graph.AddNode("BeginLoopNormalizedRectCalculator"); @@ -533,8 +535,6 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { hand_landmark_subgraph.Out("HAND_RECT_NEXT_FRAME"); auto landmarks = hand_landmark_subgraph.Out("LANDMARKS"); auto world_landmarks = hand_landmark_subgraph.Out("WORLD_LANDMARKS"); - auto image_size = - hand_landmark_subgraph[Output>("IMAGE_SIZE")]; auto& end_loop_handedness = graph.AddNode("EndLoopClassificationListCalculator"); @@ -585,13 +585,16 @@ class HandLandmarkerSubgraph : public core::ModelTaskGraph { /* presences= */ presences, /* presence_scores= */ presence_scores, /* handednesses= */ handednesses, - /* image_size= */ image_size, }}; } }; -REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::HandLandmarkerSubgraph); +// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::hand_landmarker::MultipleHandLandmarksDetectorGraph); // NOLINT +// clang-format on +} // namespace hand_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc similarity index 96% rename from mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc rename to mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index 1c2bc6da7..d1e928ce7 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_subgraph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -39,12 +39,13 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" -#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.pb.h" +#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" namespace mediapipe { namespace tasks { namespace vision { +namespace hand_landmarker { namespace { using ::file::Defaults; @@ -57,7 +58,7 @@ using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::core::TaskRunner; using ::mediapipe::tasks::vision::DecodeImageFromFile; using ::mediapipe::tasks::vision::hand_landmarker::proto:: - HandLandmarkerSubgraphOptions; + HandLandmarksDetectorGraphOptions; using ::testing::ElementsAreArray; using ::testing::EqualsProto; using ::testing::Pointwise; @@ -112,13 +113,14 @@ absl::StatusOr> CreateSingleHandTaskRunner( absl::string_view model_name) { Graph graph; - auto& hand_landmark_detection = - graph.AddNode("mediapipe.tasks.vision.SingleHandLandmarkerSubgraph"); + auto& hand_landmark_detection = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." 
+ "SingleHandLandmarksDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); - hand_landmark_detection.GetOptions().Swap( + hand_landmark_detection.GetOptions().Swap( options.get()); graph[Input(kImageTag)].SetName(kImageName) >> @@ -151,13 +153,14 @@ absl::StatusOr> CreateMultiHandTaskRunner( absl::string_view model_name) { Graph graph; - auto& multi_hand_landmark_detection = - graph.AddNode("mediapipe.tasks.vision.HandLandmarkerSubgraph"); + auto& multi_hand_landmark_detection = graph.AddNode( + "mediapipe.tasks.vision.hand_landmarker." + "MultipleHandLandmarksDetectorGraph"); - auto options = std::make_unique(); + auto options = std::make_unique(); options->mutable_base_options()->mutable_model_asset()->set_file_name( JoinPath("./", kTestDataDirectory, model_name)); - multi_hand_landmark_detection.GetOptions() + multi_hand_landmark_detection.GetOptions() .Swap(options.get()); graph[Input(kImageTag)].SetName(kImageName) >> @@ -462,6 +465,7 @@ INSTANTIATE_TEST_SUITE_P( }); } // namespace +} // namespace hand_landmarker } // namespace vision } // namespace tasks } // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD index 8cc984c47..945b12f3e 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/BUILD @@ -21,8 +21,8 @@ package(default_visibility = [ licenses(["notice"]) mediapipe_proto_library( - name = "hand_landmarker_subgraph_options_proto", - srcs = ["hand_landmarker_subgraph_options.proto"], + name = "hand_landmarks_detector_graph_options_proto", + srcs = ["hand_landmarks_detector_graph_options.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -31,13 +31,13 @@ mediapipe_proto_library( ) mediapipe_proto_library( - name = "hand_landmarker_options_proto", - srcs = ["hand_landmarker_options.proto"], + name = "hand_landmarker_graph_options_proto", + srcs = ["hand_landmarker_graph_options.proto"], deps = [ - ":hand_landmarker_subgraph_options_proto", + ":hand_landmarks_detector_graph_options_proto", "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", - "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_options_proto", + "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto new file mode 100644 index 000000000..51e4e129a --- /dev/null +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.proto @@ -0,0 +1,47 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.hand_landmarker.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; +import "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; +option java_outer_classname = "HandLandmarkerGraphOptionsProto"; + +message HandLandmarkerGraphOptions { + extend mediapipe.CalculatorOptions { + optional HandLandmarkerGraphOptions ext = 462713202; + } + // Base options for configuring MediaPipe Tasks, such as specifying the model + // asset bundle file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Options for hand detector graph. + optional hand_detector.proto.HandDetectorGraphOptions + hand_detector_graph_options = 2; + + // Options for hand landmarker subgraph. + optional HandLandmarksDetectorGraphOptions + hand_landmarks_detector_graph_options = 3; + + // Minimum confidence for hand landmarks tracking to be considered + // successfully. + optional float min_tracking_confidence = 4 [default = 0.5]; +} diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto similarity index 78% rename from mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto rename to mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto index 9e93384d6..195f6e5cc 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_subgraph_options.proto +++ b/mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.proto @@ -20,19 +20,18 @@ package mediapipe.tasks.vision.hand_landmarker.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -message HandLandmarkerSubgraphOptions { +option java_package = "com.google.mediapipe.tasks.vision.handlandmarker.proto"; +option java_outer_classname = "HandLandmarksDetectorGraphOptionsProto"; + +message HandLandmarksDetectorGraphOptions { extend mediapipe.CalculatorOptions { - optional HandLandmarkerSubgraphOptions ext = 474472470; + optional HandLandmarksDetectorGraphOptions ext = 474472470; } // Base options for configuring MediaPipe Tasks, such as specifying the TfLite // model file with metadata, accelerator options, etc. optional core.proto.BaseOptions base_options = 1; - // The locale to use for display names specified through the TFLite Model - // Metadata, if any. Defaults to English. - optional string display_names_locale = 2 [default = "en"]; - // Minimum confidence value ([0.0, 1.0]) for hand presence score to be // considered successfully detecting a hand in the image. 
- optional float min_detection_confidence = 3 [default = 0.5]; + optional float min_detection_confidence = 2 [default = 0.5]; } diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index e7c8a6586..3d655cd50 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -26,11 +26,11 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classification_postprocessing", - "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph", + "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto", "//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", @@ -50,15 +50,16 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", - "//mediapipe/tasks/cc/components:classifier_options", - "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/processors:classifier_options", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 1e092e85a..8a32758f4 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -26,14 +26,15 @@ limitations under the License. 
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/timestamp.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" -#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -56,18 +57,9 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; -// Builds a NormalizedRect covering the entire image. -NormalizedRect BuildFullImageNormRect() { - NormalizedRect norm_rect; - norm_rect.set_x_center(0.5); - norm_rect.set_y_center(0.5); - norm_rect.set_width(1); - norm_rect.set_height(1); - return norm_rect; -} - // Creates a MediaPipe graph config that contains a subgraph node of // type "ImageClassifierGraph". If the task is running in the live stream mode, // a "FlowLimiterCalculator" will be added to limit the number of frames in @@ -107,8 +99,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { options_proto->mutable_base_options()->set_use_stream_mode( options->running_mode != core::RunningMode::IMAGE); auto classifier_options_proto = - std::make_unique( - components::ConvertClassifierOptionsToProto( + std::make_unique( + components::processors::ConvertClassifierOptionsToProto( &(options->classifier_options))); options_proto->mutable_classifier_options()->Swap( classifier_options_proto.get()); @@ -153,15 +145,16 @@ absl::StatusOr> ImageClassifier::Create( } absl::StatusOr ImageClassifier::Classify( - Image image, std::optional roi) { + Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -172,15 +165,16 @@ absl::StatusOr ImageClassifier::Classify( } absl::StatusOr ImageClassifier::ClassifyForVideo( - Image image, int64 timestamp_ms, std::optional roi) { + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? 
roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -194,16 +188,17 @@ absl::StatusOr ImageClassifier::ClassifyForVideo( .Get(); } -absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms, - std::optional roi) { +absl::Status ImageClassifier::ClassifyAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index 8ff11413e..de69b7994 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -22,11 +22,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classifier_options.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -51,12 +51,14 @@ struct ImageClassifierOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - components::ClassifierOptions classifier_options; + components::processors::ClassifierOptions classifier_options; // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function, const Image&, int64)> + std::function, + const Image&, int64)> result_callback = nullptr; }; @@ -103,9 +105,16 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Performs image classification on the provided single image. Classification - // is performed on the region of interest specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs image classification on the provided single image. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'region_of_interest' field. If not specified, the full image is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. 
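+  // For example (illustrative values only), to classify a region of interest
+  // that additionally needs a 90° anti-clockwise rotation:
+  //   Rect roi{/*left=*/0.1, /*top=*/0.1, /*right=*/0.9, /*bottom=*/0.9};
+  //   ImageProcessingOptions options{roi, /*rotation_degrees=*/-90};
+  //   auto result = classifier->Classify(image, options);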
// // Only use this method when the ImageClassifier is created with the image // running mode. @@ -113,13 +122,21 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. // TODO: describe exact preprocessing steps once // YUVToImageCalculator is integrated. - absl::StatusOr Classify( + absl::StatusOr Classify( mediapipe::Image image, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); - // Performs image classification on the provided video frame. Classification - // is performed on the region of interested specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs image classification on the provided video frame. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'region_of_interest' field. If not specified, the full image is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the video // running mode. @@ -127,14 +144,22 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - absl::StatusOr ClassifyForVideo( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::StatusOr + ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. - // Classification is performed on the region of interested specified by the - // `roi` argument if provided, or on the entire image otherwise. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing classification, by + // setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform classification, by setting its + // 'region_of_interest' field. If not specified, the full image is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageClassifier is created with the live // stream running mode. @@ -144,19 +169,16 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // sent to the object detector. The input timestamps must be monotonically // increasing. // - // The "result_callback" prvoides + // The "result_callback" provides: // - The classification results as a ClassificationResult object. // - The const reference to the corresponding input image that the image // classifier runs on. Note that the const reference to the image will no // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. 
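+  // A minimal callback sketch (for illustration only):
+  //   options->result_callback =
+  //       [](absl::StatusOr<ClassificationResult> result, const Image& image,
+  //          int64 timestamp_ms) { /* consume result */ };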
- absl::Status ClassifyAsync( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); - - // TODO: add Classify() variants taking a region of interest as - // additional argument. + absl::Status ClassifyAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ImageClassifier when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 0d7b60c99..9a0078c5c 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -22,11 +22,11 @@ limitations under the License. #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing.h" -#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" +#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h" +#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" @@ -43,6 +43,7 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); @@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Adds postprocessing calculators and connects them to the graph output. auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( - model_resources, task_options.classifier_options(), - &postprocessing.GetOptions< - tasks::components::ClassificationPostprocessingOptions>())); + "mediapipe.tasks.components.processors." + "ClassificationPostprocessingGraph"); + MP_RETURN_IF_ERROR( + components::processors::ConfigureClassificationPostprocessingGraph( + model_resources, task_options.classifier_options(), + &postprocessing + .GetOptions())); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); // Outputs the aggregated classification result as the subgraph output diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index edbb851c0..0c45122c0 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h" +#include #include #include #include @@ -26,14 +27,15 @@ limitations under the License. #include "absl/strings/str_format.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" @@ -48,6 +50,11 @@ namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; +using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::proto::Classifications; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -205,7 +212,7 @@ TEST_F(CreateTest, FailsWithMissingModel) { EXPECT_THAT( image_classifier.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(image_classifier.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); @@ -543,18 +550,159 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Region-of-interest around the soccer ball. + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; - MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image, roi)); + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 3; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + // Specify a 90° anti-clockwise rotation. 
+ ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); + + // Results differ slightly from the non-rotated image, but that's expected + // as models are very sensitive to the slightest numerical differences + // introduced by the rotation and JPG encoding. + ExpectApproximatelyEqual(results, ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 934 + score: 0.6371766 + category_name: "cheeseburger" + } + categories { + index: 963 + score: 0.049443405 + category_name: "meat loaf" + } + categories { + index: 925 + score: 0.047918003 + category_name: "guacamole" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + options->classifier_options.max_results = 1; + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + // Region-of-interest around the chair, with 90° anti-clockwise rotation. + Rect roi{/*left=*/0.006, /*top=*/0.1763, /*right=*/0.5702, /*bottom=*/0.3049}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/-90}; + + MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( + image, image_processing_options)); + + ExpectApproximatelyEqual(results, + ParseTextProtoOrDie( + R"pb(classifications { + entries { + categories { + index: 560 + score: 0.6522213 + category_name: "folding chair" + } + timestamp_ms: 0 + } + head_index: 0 + head_name: "probability" + })pb")); +} + +// Testing all these once with ImageClassifier. +TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "multi_objects.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, + ImageClassifier::Create(std::move(options))); + + // Invalid: left > right. + Rect roi{/*left=*/0.9, /*top=*/0, /*right=*/0.1, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/0}; + auto results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect with left < right and top < bottom")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: top > bottom. 
+ roi = {/*left=*/0, /*top=*/0.9, /*right=*/1, /*bottom=*/0.1}; + image_processing_options = {roi, + /*rotation_degrees=*/0}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect with left < right and top < bottom")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: coordinates out of [0,1] range. + roi = {/*left=*/-0.1, /*top=*/0, /*right=*/1, /*bottom=*/1}; + image_processing_options = {roi, + /*rotation_degrees=*/0}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected Rect values to be in [0,1]")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); + + // Invalid: rotation not a multiple of 90°. + image_processing_options = {/*region_of_interest=*/std::nullopt, + /*rotation_degrees=*/1}; + results = image_classifier->Classify(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("Expected rotation to be a multiple of 90°")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { @@ -643,16 +791,15 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { options->classifier_options.max_results = 1; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. + // Region-of-interest around the soccer ball. + Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK_AND_ASSIGN(auto results, - image_classifier->ClassifyForVideo(image, i, roi)); + MP_ASSERT_OK_AND_ASSIGN( + auto results, + image_classifier->ClassifyForVideo(image, i, image_processing_options)); ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); } MP_ASSERT_OK(image_classifier->Close()); @@ -787,15 +934,13 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { }; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); - // NormalizedRect around the soccer ball. - NormalizedRect roi; - roi.set_x_center(0.532); - roi.set_y_center(0.521); - roi.set_width(0.164); - roi.set_height(0.427); + // Crop around the soccer ball. 
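The edge coordinates in these migrated region-of-interest tests are the former `NormalizedRect` center/size values re-expressed as `[left, top, right, bottom]`; spelled out for the soccer-ball ROI used just above and below:

```
// Former NormalizedRect (center/size) values for the soccer-ball ROI.
constexpr float kXCenter = 0.532f, kWidth = 0.164f;
constexpr float kYCenter = 0.521f, kHeight = 0.427f;
// Equivalent edge coordinates, as passed to the new Rect container.
constexpr float kLeft = kXCenter - kWidth / 2;     // 0.45
constexpr float kRight = kXCenter + kWidth / 2;    // 0.614
constexpr float kTop = kYCenter - kHeight / 2;     // 0.3075
constexpr float kBottom = kYCenter + kHeight / 2;  // 0.7345
```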
+ Rect roi{/*left=*/0.45, /*top=*/0.3075, /*right=*/0.614, /*bottom=*/0.7345}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; for (int i = 0; i < iterations; ++i) { - MP_ASSERT_OK(image_classifier->ClassifyAsync(image, i, roi)); + MP_ASSERT_OK( + image_classifier->ClassifyAsync(image, i, image_processing_options)); } MP_ASSERT_OK(image_classifier->Close()); diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD index a6f5791e3..29638bebd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/BUILD @@ -24,7 +24,7 @@ mediapipe_proto_library( deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", - "//mediapipe/tasks/cc/components/proto:classifier_options_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto index 3da047110..76315e230 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto +++ b/mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.proto @@ -18,9 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.vision.image_classifier.proto; import "mediapipe/framework/calculator.proto"; -import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; +import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; +option java_package = "com.google.mediapipe.tasks.vision.imageclassifier.proto"; +option java_outer_classname = "ImageClassifierGraphOptionsProto"; + message ImageClassifierGraphOptions { extend mediapipe.CalculatorOptions { optional ImageClassifierGraphOptions ext = 456383383; @@ -31,5 +34,5 @@ message ImageClassifierGraphOptions { // Options for configuring the classifier behavior, such as score threshold, // number of results, etc. - optional components.proto.ClassifierOptions classifier_options = 2; + optional components.processors.proto.ClassifierOptions classifier_options = 2; } diff --git a/mediapipe/tasks/cc/vision/image_embedder/BUILD b/mediapipe/tasks/cc/vision/image_embedder/BUILD index e619b8d1b..0f63f87e4 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/BUILD +++ b/mediapipe/tasks/cc/vision/image_embedder/BUILD @@ -58,6 +58,7 @@ cc_library( "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc index 24fd2862c..1dc316305 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/image_embedder/proto/image_embedder_graph_options.pb.h" @@ -58,16 +59,6 @@ using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::vision::image_embedder::proto:: ImageEmbedderGraphOptions; -// Builds a NormalizedRect covering the entire image. -NormalizedRect BuildFullImageNormRect() { - NormalizedRect norm_rect; - norm_rect.set_x_center(0.5); - norm_rect.set_y_center(0.5); - norm_rect.set_width(1); - norm_rect.set_height(1); - return norm_rect; -} - // Creates a MediaPipe graph config that contains a single node of type // "mediapipe.tasks.vision.image_embedder.ImageEmbedderGraph". If the task is // running in the live stream mode, a "FlowLimiterCalculator" will be added to @@ -148,15 +139,16 @@ absl::StatusOr> ImageEmbedder::Create( } absl::StatusOr ImageEmbedder::Embed( - Image image, std::optional roi) { + Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessImageData( @@ -167,15 +159,16 @@ absl::StatusOr ImageEmbedder::Embed( } absl::StatusOr ImageEmbedder::EmbedForVideo( - Image image, int64 timestamp_ms, std::optional roi) { + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( @@ -188,16 +181,17 @@ absl::StatusOr ImageEmbedder::EmbedForVideo( return output_packets[kEmbeddingResultStreamName].Get(); } -absl::Status ImageEmbedder::EmbedAsync(Image image, int64 timestamp_ms, - std::optional roi) { +absl::Status ImageEmbedder::EmbedAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, "GPU input images are currently not supported.", MediaPipeTasksStatus::kRunnerUnexpectedInputError); } - NormalizedRect norm_rect = - roi.has_value() ? roi.value() : BuildFullImageNormRect(); + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h index 13f4702d1..3a2a1dbee 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder.h @@ -21,11 +21,11 @@ limitations under the License. 
#include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" #include "mediapipe/tasks/cc/components/embedder_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -88,9 +88,17 @@ class ImageEmbedder : core::BaseVisionTaskApi { static absl::StatusOr> Create( std::unique_ptr options); - // Performs embedding extraction on the provided single image. Extraction - // is performed on the region of interest specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs embedding extraction on the provided single image. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the image // running mode. @@ -98,11 +106,20 @@ class ImageEmbedder : core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. absl::StatusOr Embed( mediapipe::Image image, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); - // Performs embedding extraction on the provided video frame. Extraction - // is performed on the region of interested specified by the `roi` argument if - // provided, or on the entire image otherwise. + // Performs embedding extraction on the provided video frame. + // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the video // running mode. @@ -112,12 +129,21 @@ class ImageEmbedder : core::BaseVisionTaskApi { // must be monotonically increasing. absl::StatusOr EmbedForVideo( mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + std::optional image_processing_options = + std::nullopt); // Sends live image data to embedder, and the results will be available via - // the "result_callback" provided in the ImageEmbedderOptions. Embedding - // extraction is performed on the region of interested specified by the `roi` - // argument if provided, or on the entire image otherwise. + // the "result_callback" provided in the ImageEmbedderOptions. 
+ // + // The optional 'image_processing_options' parameter can be used to specify: + // - the rotation to apply to the image before performing embedding + // extraction, by setting its 'rotation_degrees' field. + // and/or + // - the region-of-interest on which to perform embedding extraction, by + // setting its 'region_of_interest' field. If not specified, the full image + // is used. + // If both are specified, the crop around the region-of-interest is extracted + // first, then the specified rotation is applied to the crop. // // Only use this method when the ImageEmbedder is created with the live // stream running mode. @@ -135,9 +161,9 @@ class ImageEmbedder : core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status EmbedAsync( - mediapipe::Image image, int64 timestamp_ms, - std::optional roi = std::nullopt); + absl::Status EmbedAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ImageEmbedder when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 08a0d6a25..386b6c8eb 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" @@ -42,7 +41,9 @@ namespace image_embedder { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -140,7 +141,7 @@ TEST_F(CreateTest, FailsWithMissingModel) { EXPECT_THAT( image_embedder.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(image_embedder.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); @@ -326,16 +327,14 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN( Image crop, DecodeImageFromFile( JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); - // Bounding box in "burger.jpg" corresponding to "burger_crop.jpg". - NormalizedRect roi; - roi.set_x_center(200.0 / 480); - roi.set_y_center(0.5); - roi.set_width(400.0 / 480); - roi.set_height(1.0f); + // Region-of-interest in "burger.jpg" corresponding to "burger_crop.jpg". + Rect roi{/*left=*/0, /*top=*/0, /*right=*/0.833333, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; // Extract both embeddings. 
- MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, - image_embedder->Embed(image, roi)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& image_result, + image_embedder->Embed(image, image_processing_options)); MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, image_embedder->Embed(crop)); @@ -351,6 +350,77 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + // Load images: one is a rotated version of the other. + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "burger.jpg"))); + MP_ASSERT_OK_AND_ASSIGN(Image rotated, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& image_result, + image_embedder->Embed(image)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& rotated_result, + image_embedder->Embed(rotated, image_processing_options)); + + // Check results. + CheckMobileNetV3Result(image_result, false); + CheckMobileNetV3Result(rotated_result, false); + // CheckCosineSimilarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(image_result.embeddings(0).entries(0), + rotated_result.embeddings(0).entries(0))); + double expected_similarity = 0.572265; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + +TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileNetV3Embedder); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_embedder, + ImageEmbedder::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + Image crop, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "burger_crop.jpg"))); + MP_ASSERT_OK_AND_ASSIGN(Image rotated, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "burger_rotated.jpg"))); + // Region-of-interest corresponding to burger_crop.jpg. + Rect roi{/*left=*/0, /*top=*/0, /*right=*/1, /*bottom=*/0.8333333}; + ImageProcessingOptions image_processing_options{roi, + /*rotation_degrees=*/-90}; + + // Extract both embeddings. + MP_ASSERT_OK_AND_ASSIGN(const EmbeddingResult& crop_result, + image_embedder->Embed(crop)); + MP_ASSERT_OK_AND_ASSIGN( + const EmbeddingResult& rotated_result, + image_embedder->Embed(rotated, image_processing_options)); + + // Check results. + CheckMobileNetV3Result(crop_result, false); + CheckMobileNetV3Result(rotated_result, false); + // CheckCosineSimilarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, + ImageEmbedder::CosineSimilarity(crop_result.embeddings(0).entries(0), + rotated_result.embeddings(0).entries(0))); + double expected_similarity = 0.62838; + EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index 6af733657..81cd43e34 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -24,16 +24,17 @@ cc_library( ":image_segmenter_graph", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) @@ -49,6 +50,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", @@ -73,19 +75,4 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "image_segmenter_op_resolvers", - srcs = ["image_segmenter_op_resolvers.cc"], - hdrs = ["image_segmenter_op_resolvers.h"], - deps = [ - "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", - "//mediapipe/util/tflite/operations:max_pool_argmax", - "//mediapipe/util/tflite/operations:max_unpooling", - "//mediapipe/util/tflite/operations:transform_landmarks", - "//mediapipe/util/tflite/operations:transform_tensor_bilinear", - "//mediapipe/util/tflite/operations:transpose_conv_bias", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - ], -) - # TODO: This test fails in OSS diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 84ceea88a..209ee0df3 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -17,8 +17,10 @@ limitations under the License. 
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" @@ -32,6 +34,8 @@ constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ImageSegmenterGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -51,15 +55,18 @@ CalculatorGraphConfig CreateGraphConfig( auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap(options.get()); graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> graph.Out(kGroupedSegmentationTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator( - graph, task_subgraph, {kImageTag}, kGroupedSegmentationTag); + return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, + {kImageTag, kNormRectTag}, + kGroupedSegmentationTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); return graph.GetConfig(); } @@ -139,47 +146,68 @@ absl::StatusOr> ImageSegmenter::Create( } absl::StatusOr> ImageSegmenter::Segment( - mediapipe::Image image) { + mediapipe::Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, - ProcessImageData({{kImageInStreamName, - mediapipe::MakePacket(std::move(image))}})); + ProcessImageData( + {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); return output_packets[kSegmentationStreamName].Get>(); } absl::StatusOr> ImageSegmenter::SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms) { + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, ProcessVideoData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); return output_packets[kSegmentationStreamName].Get>(); } -absl::Status 
ImageSegmenter::SegmentAsync(Image image, int64 timestamp_ms) { +absl::Status ImageSegmenter::SegmentAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h index ce9cb104c..54269ec0e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h @@ -25,8 +25,8 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/kernels/register.h" namespace mediapipe { @@ -117,14 +117,21 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // running mode. // // The image can be of any size with format RGB or RGBA. - // TODO: Describes how the input image will be preprocessed - // after the yuv support is implemented. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. // // If the output_type is CATEGORY_MASK, the returned vector of images is // per-category segmented image mask. // If the output_type is CONFIDENCE_MASK, the returned vector of images // contains only one confidence image mask. - absl::StatusOr> Segment(mediapipe::Image image); + absl::StatusOr> Segment( + mediapipe::Image image, + std::optional image_processing_options = + std::nullopt); // Performs image segmentation on the provided video frame. // Only use this method when the ImageSegmenter is created with the video @@ -134,12 +141,20 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // If the output_type is CATEGORY_MASK, the returned vector of images is // per-category segmented image mask. // If the output_type is CONFIDENCE_MASK, the returned vector of images // contains only one confidence image mask. 
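As with the classifier, a short sketch of the segmenter's new signature (illustration only; the model path is a placeholder): a rotation is passed through `image_processing_options`, while a region-of-interest is rejected with an invalid-argument error as documented above.

```
// Illustrative sketch; rotation is supported, region-of-interest is not.
#include <memory>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

absl::Status SegmentRotatedFrame(mediapipe::Image image) {
  using ::mediapipe::tasks::vision::ImageSegmenter;
  using ::mediapipe::tasks::vision::ImageSegmenterOptions;
  using ::mediapipe::tasks::vision::core::ImageProcessingOptions;

  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path = "/path/to/segmenter.tflite";
  options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
  options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
  auto segmenter_or = ImageSegmenter::Create(std::move(options));
  if (!segmenter_or.ok()) return segmenter_or.status();
  std::unique_ptr<ImageSegmenter> segmenter = std::move(*segmenter_or);

  // Rotate 90 degrees anti-clockwise before segmenting; no region-of-interest.
  ImageProcessingOptions processing_options;
  processing_options.rotation_degrees = -90;
  auto masks_or = segmenter->Segment(std::move(image), processing_options);
  if (!masks_or.ok()) return masks_or.status();
  // masks_or holds the vector of mask images described in the comments above.
  return segmenter->Close();
}
```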
absl::StatusOr> SegmentForVideo( - mediapipe::Image image, int64 timestamp_ms); + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options = + std::nullopt); // Sends live image data to perform image segmentation, and the results will // be available via the "result_callback" provided in the @@ -151,6 +166,12 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // sent to the image segmenter. The input timestamps must be monotonically // increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // The "result_callback" prvoides // - A vector of segmented image masks. // If the output_type is CATEGORY_MASK, the returned vector of images is @@ -162,7 +183,9 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi { // no longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms); + absl::Status SegmentAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ImageSegmenter when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc index 1678dd083..629b940aa 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" @@ -62,6 +63,7 @@ using LabelItems = mediapipe::proto_ns::Map; constexpr char kSegmentationTag[] = "SEGMENTATION"; constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; constexpr char kOutputSizeTag[] = "OUTPUT_SIZE"; @@ -159,6 +161,10 @@ absl::StatusOr GetOutputTensor( // Inputs: // IMAGE - Image // Image to perform segmentation on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection +// on. +// @Optional: rect covering the whole image is used if not specified. 
// // Outputs: // SEGMENTATION - mediapipe::Image @Multiple @@ -196,10 +202,12 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { ASSIGN_OR_RETURN(const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN(auto output_streams, - BuildSegmentationTask( - sc->Options(), *model_resources, - graph[Input(kImageTag)], graph)); + ASSIGN_OR_RETURN( + auto output_streams, + BuildSegmentationTask( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); auto& merge_images_to_vector = graph.AddNode("MergeImagesToVectorCalculator"); @@ -228,7 +236,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { absl::StatusOr BuildSegmentationTask( const ImageSegmenterOptions& task_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Adds preprocessing calculators and connects them to the graph input image @@ -240,6 +248,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 2f1c26a79..07235563b 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -29,9 +29,10 @@ limitations under the License. #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -45,6 +46,8 @@ namespace { using ::mediapipe::Image; using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -192,7 +195,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { EXPECT_THAT( segmenter_or.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); @@ -238,7 +241,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); - MP_ASSERT_OK_AND_ASSIGN(auto results, segmenter->Segment(image)); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); 
EXPECT_EQ(confidence_masks.size(), 21); @@ -254,14 +256,67 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile( + JoinPath("./", kTestDataDirectory, "cat_rotated.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::SOFTMAX; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); + EXPECT_EQ(confidence_masks.size(), 21); + + cv::Mat expected_mask = + cv::imread(JoinPath("./", kTestDataDirectory, "cat_rotated_mask.jpg"), + cv::IMREAD_GRAYSCALE); + cv::Mat expected_mask_float; + expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); + + // Cat category index 8. + cv::Mat cat_mask = mediapipe::formats::MatView( + confidence_masks[8].GetImageFrameSharedPtr().get()); + EXPECT_THAT(cat_mask, + SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg"))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + options->activation = ImageSegmenterOptions::Activation::SOFTMAX; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = segmenter->Segment(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { Image image = GetSRGBImage(JoinPath("./", kTestDataDirectory, "mozart_square.jpg")); auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::SOFTMAX; @@ -290,8 +345,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { auto options = std::make_unique(); options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); - options->base_options.op_resolver = - absl::make_unique(); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, diff --git 
a/mediapipe/tasks/cc/vision/object_detector/BUILD b/mediapipe/tasks/cc/vision/object_detector/BUILD index 6a9a25fc1..8220d8b7f 100644 --- a/mediapipe/tasks/cc/vision/object_detector/BUILD +++ b/mediapipe/tasks/cc/vision/object_detector/BUILD @@ -33,6 +33,7 @@ cc_library( "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components:image_preprocessing", @@ -66,6 +67,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", "//mediapipe/tasks/cc/core:base_options", @@ -73,6 +75,7 @@ cc_library( "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc index 8b7473d48..dd19237ff 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector.cc @@ -28,11 +28,13 @@ limitations under the License. #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.pb.h" @@ -48,6 +50,8 @@ constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectName[] = "norm_rect_in"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.ObjectDetectorGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; @@ -64,6 +68,7 @@ CalculatorGraphConfig CreateGraphConfig( bool enable_flow_limiting) { api2::builder::Graph graph; graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap( options_proto.get()); @@ -72,10 +77,11 @@ CalculatorGraphConfig CreateGraphConfig( task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, - {kImageTag}, kDetectionsTag); + return 
tasks::core::AddFlowLimiterCalculator( + graph, task_subgraph, {kImageTag, kNormRectTag}, kDetectionsTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); return graph.GetConfig(); } @@ -139,46 +145,67 @@ absl::StatusOr> ObjectDetector::Create( } absl::StatusOr> ObjectDetector::Detect( - mediapipe::Image image) { - if (image.UsesGpu()) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat("GPU input images are currently not supported."), - MediaPipeTasksStatus::kRunnerUnexpectedInputError); - } - ASSIGN_OR_RETURN(auto output_packets, - ProcessImageData({{kImageInStreamName, - MakePacket(std::move(image))}})); - return output_packets[kDetectionsOutStreamName].Get>(); -} - -absl::StatusOr> ObjectDetector::DetectForVideo( - mediapipe::Image image, int64 timestamp_ms) { + mediapipe::Image image, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); ASSIGN_OR_RETURN( auto output_packets, - ProcessVideoData( - {{kImageInStreamName, - MakePacket(std::move(image)) - .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectName, MakePacket(std::move(norm_rect))}})); return output_packets[kDetectionsOutStreamName].Get>(); } -absl::Status ObjectDetector::DetectAsync(Image image, int64 timestamp_ms) { +absl::StatusOr> ObjectDetector::DetectForVideo( + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options) { if (image.UsesGpu()) { return CreateStatusWithPayload( absl::StatusCode::kInvalidArgument, absl::StrCat("GPU input images are currently not supported."), MediaPipeTasksStatus::kRunnerUnexpectedInputError); } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + return output_packets[kDetectionsOutStreamName].Get>(); +} + +absl::Status ObjectDetector::DetectAsync( + Image image, int64 timestamp_ms, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); return SendLiveStreamData( {{kImageInStreamName, MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectName, + MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); } diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector.h b/mediapipe/tasks/cc/vision/object_detector/object_detector.h index 0fa1b087b..44ce68ed9 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector.h +++ 
b/mediapipe/tasks/cc/vision/object_detector/object_detector.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "mediapipe/framework/formats/image.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" namespace mediapipe { @@ -151,6 +153,12 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // TODO: Describes how the input image will be preprocessed // after the yuv support is implemented. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // For CPU images, the returned bounding boxes are expressed in the // unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the @@ -158,7 +166,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // TODO: Describes the output bounding boxes for gpu input // images after enabling the gpu support in MediaPipe Tasks. absl::StatusOr> Detect( - mediapipe::Image image); + mediapipe::Image image, + std::optional image_processing_options = + std::nullopt); // Performs object detection on the provided video frame. // Only use this method when the ObjectDetector is created with the video @@ -168,12 +178,20 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // // For CPU images, the returned bounding boxes are expressed in the // unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the // underlying image data. absl::StatusOr> DetectForVideo( - mediapipe::Image image, int64 timestamp_ms); + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options = + std::nullopt); // Sends live image data to perform object detection, and the results will be // available via the "result_callback" provided in the ObjectDetectorOptions. @@ -185,7 +203,13 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // sent to the object detector. The input timestamps must be monotonically // increasing. // - // The "result_callback" prvoides + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. 
+ // + // The "result_callback" provides // - A vector of detections, each has a bounding box that is expressed in // the unrotated input frame of reference coordinates system, i.e. in `[0, // image_width) x [0, image_height)`, which are the dimensions of the @@ -195,7 +219,9 @@ class ObjectDetector : tasks::vision::core::BaseVisionTaskApi { // longer be valid when the callback returns. To access the image data // outside of the callback, callers need to make a copy of the image. // - The input timestamp in milliseconds. - absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms); + absl::Status DetectAsync(mediapipe::Image image, int64 timestamp_ms, + std::optional + image_processing_options = std::nullopt); // Shuts down the ObjectDetector when all works are done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc index b0533e469..07e912cfc 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_graph.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" @@ -87,6 +88,7 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; constexpr char kImageTag[] = "IMAGE"; constexpr char kIndicesTag[] = "INDICES"; constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; constexpr char kScoresTag[] = "SCORES"; @@ -457,12 +459,18 @@ void ConfigureTensorsToDetectionsCalculator( // Inputs: // IMAGE - Image // Image to perform detection on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection +// on. +// @Optional: rect covering the whole image is used if not specified. // // Outputs: // DETECTIONS - std::vector // Detected objects with bounding box in pixel units. // IMAGE - mediapipe::Image // The image that object detection runs on. +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. // // Example: // node { @@ -494,9 +502,10 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { Graph graph; ASSIGN_OR_RETURN( auto output_streams, - BuildObjectDetectionTask(sc->Options(), - *model_resources, - graph[Input(kImageTag)], graph)); + BuildObjectDetectionTask( + sc->Options(), *model_resources, + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], graph)); output_streams.detections >> graph[Output>(kDetectionsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -519,7 +528,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { absl::StatusOr BuildObjectDetectionTask( const ObjectDetectorOptionsProto& task_options, const core::ModelResources& model_resources, Source image_in, - Graph& graph) { + Source norm_rect_in, Graph& graph) { MP_RETURN_IF_ERROR(SanityCheckOptions(task_options)); // Checks that the model has 4 outputs. 
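The `ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)` calls above come from the shared vision task utilities, which are not part of this diff. The following is only an illustrative sketch of the behavior the object detector relies on, with assumed field names and sign conventions, not the actual implementation:

```
#include <cmath>
#include <optional>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"

namespace sketch {

// Illustrative only: builds the NORM_RECT packet fed to the graph. A rect
// covering the whole image is used when no options are given; rotation is
// carried in the rect; a region of interest is rejected when not allowed.
absl::StatusOr<mediapipe::NormalizedRect> ConvertToNormalizedRect(
    std::optional<mediapipe::tasks::vision::core::ImageProcessingOptions>
        options,
    bool roi_allowed) {
  constexpr double kPi = 3.14159265358979323846;
  mediapipe::NormalizedRect norm_rect;
  norm_rect.set_x_center(0.5);
  norm_rect.set_y_center(0.5);
  norm_rect.set_width(1.0);
  norm_rect.set_height(1.0);
  if (!options.has_value()) return norm_rect;
  if (options->region_of_interest.has_value() && !roi_allowed) {
    return absl::InvalidArgumentError(
        "This task doesn't support region-of-interest.");
  }
  // Assumed convention: clockwise degrees converted to counter-clockwise
  // radians, matching how the preprocessing step consumes the rect.
  norm_rect.set_rotation(-options->rotation_degrees * kPi / 180.0);
  return norm_rect;
}

}  // namespace sketch
```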
auto& model = *model_resources.GetTfLiteModel(); @@ -559,6 +568,7 @@ class ObjectDetectorGraph : public core::ModelTaskGraph { &preprocessing .GetOptions())); image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 463c92566..1747685dd 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/object_detector/object_detector.h" +#include #include #include #include @@ -34,6 +35,8 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/c/common.h" @@ -62,6 +65,8 @@ namespace vision { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Rect; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -195,7 +200,7 @@ TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { // interpreter errors (e.g., "Encountered unresolved custom op"). EXPECT_EQ(object_detector.status().code(), absl::StatusCode::kInternal); EXPECT_THAT(object_detector.status().message(), - HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); + HasSubstr("interpreter->AllocateTensors() == kTfLiteOk")); } TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { @@ -208,7 +213,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { EXPECT_THAT( object_detector.status().message(), HasSubstr("ExternalFile must specify at least one of 'file_content', " - "'file_name' or 'file_descriptor_meta'.")); + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); EXPECT_THAT(object_detector.status().GetPayload(kMediaPipeTasksPayload), Optional(absl::Cord(absl::StrCat( MediaPipeTasksStatus::kRunnerInitializationError)))); @@ -519,6 +524,55 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { ExpectApproximatelyEqual(results, {full_expected_results[3]}); } +TEST_F(ImageModeTest, SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "cats_and_dogs_rotated.jpg"))); + auto options = std::make_unique(); + options->max_results = 1; + options->category_allowlist.push_back("cat"); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, + ObjectDetector::Create(std::move(options))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + MP_ASSERT_OK_AND_ASSIGN( + auto results, object_detector->Detect(image, image_processing_options)); + MP_ASSERT_OK(object_detector->Close()); + ExpectApproximatelyEqual( + results, {ParseTextProtoOrDie(R"pb( + label: "cat" + score: 0.7109375 + location_data { + format: 
BOUNDING_BOX + bounding_box { xmin: 0 ymin: 622 width: 436 height: 276 } + })pb")}); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN(Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + "cats_and_dogs.jpg"))); + auto options = std::make_unique(); + options->max_results = 1; + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, + ObjectDetector::Create(std::move(options))); + Rect roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = object_detector->Detect(image, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + class VideoModeTest : public tflite_shims::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { diff --git a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto index 37edab1d9..cba58ace8 100644 --- a/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto +++ b/mediapipe/tasks/cc/vision/object_detector/proto/object_detector_options.proto @@ -20,7 +20,7 @@ package mediapipe.tasks.vision.object_detector.proto; import "mediapipe/framework/calculator.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; -option java_package = "com.google.mediapipe.tasks.vision.objectdetector"; +option java_package = "com.google.mediapipe.tasks.vision.objectdetector.proto"; option java_outer_classname = "ObjectDetectorOptionsProto"; message ObjectDetectorOptions { diff --git a/mediapipe/tasks/cc/vision/utils/BUILD b/mediapipe/tasks/cc/vision/utils/BUILD index 3e5cfd2e9..fda33bea5 100644 --- a/mediapipe/tasks/cc/vision/utils/BUILD +++ b/mediapipe/tasks/cc/vision/utils/BUILD @@ -79,3 +79,30 @@ cc_library( "@stblib//:stb_image", ], ) + +cc_library( + name = "landmarks_duplicates_finder", + hdrs = ["landmarks_duplicates_finder.h"], + deps = [ + "//mediapipe/framework/formats:landmark_cc_proto", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "landmarks_utils", + srcs = ["landmarks_utils.cc"], + hdrs = ["landmarks_utils.h"], + deps = ["//mediapipe/tasks/cc/components/containers:rect"], +) + +cc_test( + name = "landmarks_utils_test", + srcs = ["landmarks_utils_test.cc"], + deps = [ + ":landmarks_utils", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/components/containers:rect", + ], +) diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h b/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h new file mode 100644 index 000000000..e1632e6f0 --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
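As a counterpart to the image-mode sketch earlier, here is a hypothetical caller for the extended video-mode signature. The decoded frame vector and frame interval are assumptions; timestamps must be monotonically increasing, as the header comments require.

```
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
#include "mediapipe/tasks/cc/vision/object_detector/object_detector.h"

absl::Status DetectOnDecodedFrames(
    mediapipe::tasks::vision::ObjectDetector& detector,
    const std::vector<mediapipe::Image>& frames, int frame_interval_ms) {
  mediapipe::tasks::vision::core::ImageProcessingOptions options;
  options.rotation_degrees = 90;  // Applied identically to every frame.
  for (int i = 0; i < static_cast<int>(frames.size()); ++i) {
    // Timestamps must be monotonically increasing across calls.
    auto detections_or = detector.DetectForVideo(
        frames[i], /*timestamp_ms=*/i * frame_interval_ms, options);
    if (!detections_or.ok()) return detections_or.status();
    // Each returned mediapipe::Detection carries a bounding box expressed in
    // the unrotated input frame coordinates.
  }
  return absl::OkStatus();
}
```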
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
+#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
+
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/framework/formats/landmark.pb.h"
+
+namespace mediapipe::tasks::vision::utils {
+
+class DuplicatesFinder {
+ public:
+  virtual ~DuplicatesFinder() = default;
+  // Returns indices of landmark lists to remove so that @multi_landmarks
+  // contains only landmark lists that are different enough (depending on the
+  // implementation).
+  virtual absl::StatusOr<absl::flat_hash_set<int>> FindDuplicates(
+      const std::vector<mediapipe::NormalizedLandmarkList>& multi_landmarks,
+      int input_width, int input_height) = 0;
+};
+
+}  // namespace mediapipe::tasks::vision::utils
+
+#endif  // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_DUPLICATES_FINDER_H_
diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
new file mode 100644
index 000000000..2ce9e2454
--- /dev/null
+++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.cc
@@ -0,0 +1,50 @@
+/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" + +#include +#include + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::vision::utils { + +using ::mediapipe::tasks::components::containers::Rect; + +float CalculateArea(const Rect& rect) { + return (rect.right - rect.left) * (rect.bottom - rect.top); +} + +float CalculateIntersectionArea(const Rect& a, const Rect& b) { + const float intersection_left = std::max(a.left, b.left); + const float intersection_top = std::max(a.top, b.top); + const float intersection_right = std::min(a.right, b.right); + const float intersection_bottom = std::min(a.bottom, b.bottom); + + return std::max(intersection_bottom - intersection_top, 0.0) * + std::max(intersection_right - intersection_left, 0.0); +} + +float CalculateIOU(const Rect& a, const Rect& b) { + const float area_a = CalculateArea(a); + const float area_b = CalculateArea(b); + if (area_a <= 0 || area_b <= 0) return 0.0; + + const float intersection_area = CalculateIntersectionArea(a, b); + return intersection_area / (area_a + area_b - intersection_area); +} + +} // namespace mediapipe::tasks::vision::utils diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils.h b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h new file mode 100644 index 000000000..73114d2ef --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils.h @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ + +#include +#include +#include +#include +#include + +#include "mediapipe/tasks/cc/components/containers/rect.h" + +namespace mediapipe::tasks::vision::utils { + +// Calculates intersection over union for two bounds. +float CalculateIOU(const components::containers::Rect& a, + const components::containers::Rect& b); + +// Calculates area for face bound +float CalculateArea(const components::containers::Rect& rect); + +// Calucates intersection area of two face bounds +float CalculateIntersectionArea(const components::containers::Rect& a, + const components::containers::Rect& b); +} // namespace mediapipe::tasks::vision::utils + +#endif // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_ diff --git a/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc b/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc new file mode 100644 index 000000000..c30a5225b --- /dev/null +++ b/mediapipe/tasks/cc/vision/utils/landmarks_utils_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
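A short usage sketch for the helpers above. The `Rect` container is the float-valued `components::containers::Rect` with `left`/`top`/`right`/`bottom` fields; the first pair of boxes reproduces the 0.25 IOU case from the unit test below.

```
#include <iostream>

#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

int main() {
  using ::mediapipe::tasks::components::containers::Rect;
  using ::mediapipe::tasks::vision::utils::CalculateIOU;

  // Two unit-height boxes overlapping on [2, 3] along x: intersection area 1,
  // union area 4, so the IOU is 0.25.
  Rect a{/*left=*/0, /*top=*/0, /*right=*/3, /*bottom=*/1};
  Rect b{/*left=*/2, /*top=*/0, /*right=*/4, /*bottom=*/1};
  std::cout << CalculateIOU(a, b) << std::endl;  // 0.25

  // Degenerate (zero-area) boxes yield 0 instead of dividing by zero.
  Rect empty{/*left=*/1, /*top=*/1, /*right=*/1, /*bottom=*/1};
  std::cout << CalculateIOU(a, empty) << std::endl;  // 0
  return 0;
}
```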
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe::tasks::vision::utils { +namespace { + +TEST(LandmarkUtilsTest, CalculateIOU) { + // Do not intersect + EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 2, 3, 3})); + // No x intersection + EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {2, 0, 3, 1})); + // No y intersection + EXPECT_EQ(0, CalculateIOU({0, 0, 1, 1}, {0, 2, 1, 3})); + // Full intersection + EXPECT_EQ(1, CalculateIOU({0, 0, 2, 2}, {0, 0, 2, 2})); + + // Union is 4 intersection is 1 + EXPECT_EQ(0.25, CalculateIOU({0, 0, 3, 1}, {2, 0, 4, 1})); + + // Same in by y + EXPECT_EQ(0.25, CalculateIOU({0, 0, 1, 3}, {0, 2, 1, 4})); +} +} // namespace +} // namespace mediapipe::tasks::vision::utils diff --git a/mediapipe/tasks/examples/android/BUILD b/mediapipe/tasks/examples/android/BUILD new file mode 100644 index 000000000..c07af2d2c --- /dev/null +++ b/mediapipe/tasks/examples/android/BUILD @@ -0,0 +1,21 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +filegroup( + name = "resource_files", + srcs = glob(["res/**"]), + visibility = ["//mediapipe/tasks/examples/android:__subpackages__"], +) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml new file mode 100644 index 000000000..5c53dc269 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/AndroidManifest.xml @@ -0,0 +1,37 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD similarity index 53% rename from mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD rename to mediapipe/tasks/examples/android/objectdetector/src/main/BUILD index 8ba2705eb..89c1edcb3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/objectdetector/BUILD +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/BUILD @@ -12,33 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. 
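The `landmarks_duplicates_finder` target added above only declares the `DuplicatesFinder` interface. As a purely hypothetical illustration of how it can compose with `CalculateIOU` (this is not the implementation shipped in MediaPipe), a finder could mark landmark lists whose bounding boxes overlap too much:

```
#include <algorithm>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

namespace mediapipe::tasks::vision::utils {

// Hypothetical example: keeps the first of any pair of landmark lists whose
// pixel-space bounding rects overlap with IOU above `threshold`.
class IouDuplicatesFinder : public DuplicatesFinder {
 public:
  explicit IouDuplicatesFinder(float threshold) : threshold_(threshold) {}

  absl::StatusOr<absl::flat_hash_set<int>> FindDuplicates(
      const std::vector<mediapipe::NormalizedLandmarkList>& multi_landmarks,
      int input_width, int input_height) override {
    std::vector<components::containers::Rect> rects;
    rects.reserve(multi_landmarks.size());
    for (const auto& landmarks : multi_landmarks) {
      // Axis-aligned bounding rect of the landmark list, in pixels.
      components::containers::Rect rect{1e9f, 1e9f, -1e9f, -1e9f};
      for (const auto& landmark : landmarks.landmark()) {
        rect.left = std::min(rect.left, landmark.x() * input_width);
        rect.top = std::min(rect.top, landmark.y() * input_height);
        rect.right = std::max(rect.right, landmark.x() * input_width);
        rect.bottom = std::max(rect.bottom, landmark.y() * input_height);
      }
      rects.push_back(rect);
    }
    absl::flat_hash_set<int> to_remove;
    for (int i = 0; i < static_cast<int>(rects.size()); ++i) {
      if (to_remove.contains(i)) continue;
      for (int j = i + 1; j < static_cast<int>(rects.size()); ++j) {
        if (CalculateIOU(rects[i], rects[j]) > threshold_) to_remove.insert(j);
      }
    }
    return to_remove;
  }

 private:
  float threshold_;
};

}  // namespace mediapipe::tasks::vision::utils
```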
-package(default_visibility = ["//mediapipe/tasks:internal"]) - licenses(["notice"]) -android_library( +package(default_visibility = ["//visibility:private"]) + +android_binary( name = "objectdetector", - srcs = [ - "ObjectDetectionResult.java", - "ObjectDetector.java", + srcs = glob(["**/*.java"]), + assets = [ + "//mediapipe/tasks/testdata/vision:test_models", ], - javacopts = [ - "-Xep:AndroidJdkLibsChecker:OFF", - ], - manifest = ":AndroidManifest.xml", + assets_dir = "", + custom_package = "com.google.mediapipe.tasks.examples.objectdetector", + manifest = "AndroidManifest.xml", + manifest_values = { + "applicationId": "com.google.mediapipe.tasks.examples.objectdetector", + }, + multidex = "native", + resource_files = ["//mediapipe/tasks/examples/android:resource_files"], deps = [ - "//mediapipe/framework:calculator_options_java_proto_lite", - "//mediapipe/framework/formats:detection_java_proto_lite", - "//mediapipe/framework/formats:location_data_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", - "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core", - "//third_party:autovalue", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:core", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/vision:objectdetector", + "//third_party:androidx_appcompat", + "//third_party:androidx_constraint_layout", + "//third_party:opencv", + "@maven//:androidx_activity_activity", + "@maven//:androidx_concurrent_concurrent_futures", + "@maven//:androidx_exifinterface_exifinterface", + "@maven//:androidx_fragment_fragment", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java new file mode 100644 index 000000000..18c010a00 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/MainActivity.java @@ -0,0 +1,239 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.examples.objectdetector; + +import android.content.Intent; +import android.graphics.Bitmap; +import android.media.MediaMetadataRetriever; +import android.os.Bundle; +import android.provider.MediaStore; +import androidx.appcompat.app.AppCompatActivity; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.FrameLayout; +import androidx.activity.result.ActivityResultLauncher; +import androidx.activity.result.contract.ActivityResultContracts; +import androidx.exifinterface.media.ExifInterface; +// ContentResolver dependency +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions; +import java.io.IOException; +import java.io.InputStream; + +/** Main activity of MediaPipe Task Object Detector reference app. */ +public class MainActivity extends AppCompatActivity { + private static final String TAG = "MainActivity"; + private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite"; + + private ObjectDetector objectDetector; + + private enum InputSource { + UNKNOWN, + IMAGE, + VIDEO, + CAMERA, + } + + private InputSource inputSource = InputSource.UNKNOWN; + + // Image mode demo component. + private ActivityResultLauncher imageGetter; + // Video mode demo component. + private ActivityResultLauncher videoGetter; + private ObjectDetectionResultImageView imageView; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_main); + setupImageModeDemo(); + setupVideoModeDemo(); + // TODO: Adds live camera demo. + } + + /** Sets up the image mode demo. */ + private void setupImageModeDemo() { + imageView = new ObjectDetectionResultImageView(this); + // The Intent to access gallery and read images as bitmap. 
+ imageGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + Bitmap bitmap = null; + int rotation = 0; + try { + bitmap = + downscaleBitmap( + MediaStore.Images.Media.getBitmap( + this.getContentResolver(), resultIntent.getData())); + } catch (IOException e) { + Log.e(TAG, "Bitmap reading error:" + e); + } + try { + InputStream imageData = + this.getContentResolver().openInputStream(resultIntent.getData()); + rotation = getImageRotation(imageData); + } catch (IOException | MediaPipeException e) { + Log.e(TAG, "Bitmap rotation error:" + e); + } + if (bitmap != null) { + MPImage image = new BitmapImageBuilder(bitmap).build(); + ObjectDetectionResult detectionResult = + objectDetector.detect( + image, + ImageProcessingOptions.builder().setRotationDegrees(rotation).build()); + imageView.setData(image, detectionResult); + runOnUiThread(() -> imageView.update()); + } + } + } + }); + Button loadImageButton = findViewById(R.id.button_load_picture); + loadImageButton.setOnClickListener( + v -> { + if (inputSource != InputSource.IMAGE) { + createObjectDetector(RunningMode.IMAGE); + this.inputSource = InputSource.IMAGE; + updateLayout(); + } + // Reads images from gallery. + Intent pickImageIntent = new Intent(Intent.ACTION_PICK); + pickImageIntent.setDataAndType(MediaStore.Images.Media.INTERNAL_CONTENT_URI, "image/*"); + imageGetter.launch(pickImageIntent); + }); + } + + /** Sets up the video mode demo. */ + private void setupVideoModeDemo() { + imageView = new ObjectDetectionResultImageView(this); + // The Intent to access gallery and read a video file. + videoGetter = + registerForActivityResult( + new ActivityResultContracts.StartActivityForResult(), + result -> { + Intent resultIntent = result.getData(); + if (resultIntent != null) { + if (result.getResultCode() == RESULT_OK) { + MediaMetadataRetriever metaRetriever = new MediaMetadataRetriever(); + metaRetriever.setDataSource(this, resultIntent.getData()); + long duration = + Long.parseLong( + metaRetriever.extractMetadata( + MediaMetadataRetriever.METADATA_KEY_DURATION)); + int numFrames = + Integer.parseInt( + metaRetriever.extractMetadata( + MediaMetadataRetriever.METADATA_KEY_VIDEO_FRAME_COUNT)); + long frameIntervalMs = duration / numFrames; + for (int i = 0; i < numFrames; ++i) { + MPImage image = + new BitmapImageBuilder(metaRetriever.getFrameAtIndex(i)).build(); + ObjectDetectionResult detectionResult = + objectDetector.detectForVideo(image, frameIntervalMs * i); + // Currently only annotates the detection result on the first video frame and + // display it to verify the correctness. + // TODO: Annotates the detection result on every frame, save the + // annotated frames as a video file, and play back the video afterwards. + if (i == 0) { + imageView.setData(image, detectionResult); + runOnUiThread(() -> imageView.update()); + } + } + } + } + }); + Button loadVideoButton = findViewById(R.id.button_load_video); + loadVideoButton.setOnClickListener( + v -> { + createObjectDetector(RunningMode.VIDEO); + updateLayout(); + this.inputSource = InputSource.VIDEO; + + // Reads a video from gallery. 
+ Intent pickVideoIntent = new Intent(Intent.ACTION_PICK); + pickVideoIntent.setDataAndType(MediaStore.Video.Media.INTERNAL_CONTENT_URI, "video/*"); + videoGetter.launch(pickVideoIntent); + }); + } + + private void createObjectDetector(RunningMode mode) { + if (objectDetector != null) { + objectDetector.close(); + } + // Initializes a new MediaPipe ObjectDetector instance + ObjectDetectorOptions options = + ObjectDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setScoreThreshold(0.5f) + .setMaxResults(5) + .setRunningMode(mode) + .build(); + objectDetector = ObjectDetector.createFromOptions(this, options); + } + + private void updateLayout() { + // Updates the preview layout. + FrameLayout frameLayout = findViewById(R.id.preview_display_layout); + frameLayout.removeAllViewsInLayout(); + imageView.setImageDrawable(null); + frameLayout.addView(imageView); + imageView.setVisibility(View.VISIBLE); + } + + private Bitmap downscaleBitmap(Bitmap originalBitmap) { + double aspectRatio = (double) originalBitmap.getWidth() / originalBitmap.getHeight(); + int width = imageView.getWidth(); + int height = imageView.getHeight(); + if (((double) imageView.getWidth() / imageView.getHeight()) > aspectRatio) { + width = (int) (height * aspectRatio); + } else { + height = (int) (width / aspectRatio); + } + return Bitmap.createScaledBitmap(originalBitmap, width, height, false); + } + + private int getImageRotation(InputStream imageData) throws IOException, MediaPipeException { + int orientation = + new ExifInterface(imageData) + .getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL); + switch (orientation) { + case ExifInterface.ORIENTATION_NORMAL: + return 0; + case ExifInterface.ORIENTATION_ROTATE_90: + return 90; + case ExifInterface.ORIENTATION_ROTATE_180: + return 180; + case ExifInterface.ORIENTATION_ROTATE_270: + return 270; + default: + // TODO: use getRotationDegrees() and isFlipped() instead of switch once flip + // is supported. + throw new MediaPipeException( + MediaPipeException.StatusCode.UNIMPLEMENTED.ordinal(), + "Flipped images are not supported yet."); + } + } +} diff --git a/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java new file mode 100644 index 000000000..283e48857 --- /dev/null +++ b/mediapipe/tasks/examples/android/objectdetector/src/main/java/com/google/mediapipe/tasks/examples/objectdetector/ObjectDetectionResultImageView.java @@ -0,0 +1,77 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.examples.objectdetector; + +import android.content.Context; +import android.graphics.Bitmap; +import android.graphics.Canvas; +import android.graphics.Color; +import android.graphics.Matrix; +import android.graphics.Paint; +import androidx.appcompat.widget.AppCompatImageView; +import com.google.mediapipe.framework.image.BitmapExtractor; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Detection; +import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult; + +/** An ImageView implementation for displaying {@link ObjectDetectionResult}. */ +public class ObjectDetectionResultImageView extends AppCompatImageView { + private static final String TAG = "ObjectDetectionResultImageView"; + + private static final int BBOX_COLOR = Color.GREEN; + private static final int BBOX_THICKNESS = 5; // Pixels + private Bitmap latest; + + public ObjectDetectionResultImageView(Context context) { + super(context); + setScaleType(AppCompatImageView.ScaleType.FIT_CENTER); + } + + /** + * Sets a {@link MPImage} and an {@link ObjectDetectionResult} to render. + * + * @param image a {@link MPImage} object for annotation. + * @param result an {@link ObjectDetectionResult} object that contains the detection result. + */ + public void setData(MPImage image, ObjectDetectionResult result) { + if (image == null || result == null) { + return; + } + latest = BitmapExtractor.extract(image); + Canvas canvas = new Canvas(latest); + canvas.drawBitmap(latest, new Matrix(), null); + for (int i = 0; i < result.detections().size(); ++i) { + drawDetectionOnCanvas(result.detections().get(i), canvas); + } + } + + /** Updates the image view with the latest {@link ObjectDetectionResult}. */ + public void update() { + postInvalidate(); + if (latest != null) { + setImageBitmap(latest); + } + } + + private void drawDetectionOnCanvas(Detection detection, Canvas canvas) { + // TODO: Draws the category and the score per bounding box. + // Draws bounding box. + Paint bboxPaint = new Paint(); + bboxPaint.setColor(BBOX_COLOR); + bboxPaint.setStyle(Paint.Style.STROKE); + bboxPaint.setStrokeWidth(BBOX_THICKNESS); + canvas.drawRect(detection.boundingBox(), bboxPaint); + } +} diff --git a/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml new file mode 100644 index 000000000..c7bd21dbd --- /dev/null +++ b/mediapipe/tasks/examples/android/res/drawable-v24/ic_launcher_foreground.xml @@ -0,0 +1,34 @@ + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml new file mode 100644 index 000000000..01f0af0ad --- /dev/null +++ b/mediapipe/tasks/examples/android/res/drawable/ic_launcher_background.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/examples/android/res/layout/activity_main.xml b/mediapipe/tasks/examples/android/res/layout/activity_main.xml new file mode 100644 index 000000000..834e9a3e6 --- /dev/null +++ b/mediapipe/tasks/examples/android/res/layout/activity_main.xml @@ -0,0 +1,40 @@ + + + +