Syntax 2022-10-24 22:52:17 +00:00 committed by GitHub
commit 0f4379cd64
419 changed files with 25985 additions and 2387 deletions

View File

@@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
-RUN pip3 install tensorflow==2.2.0
+RUN pip3 install tensorflow
RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python

View File

@@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```
## Graph Options
It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.
In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:
```
node {
calculator: "FlowLimiterCalculator"
input_stream: "image"
output_stream: "throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
max_in_flight: 1
}
}
}
node {
calculator: "FaceDetectionSubgraph"
input_stream: "IMAGE:throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {
tensor_width: 192
tensor_height: 192
}
}
}
```
In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:
```
graph_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}
node: {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:multi_backend_image"
node_options: {
[type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
keep_aspect_ratio: true
border_mode: BORDER_ZERO
}
}
option_value: "output_tensor_width:options/tensor_width"
option_value: "output_tensor_height:options/tensor_height"
}
node {
calculator: "InferenceCalculator"
node_options: {
[type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
}
option_value: "delegate:options/delegate"
option_value: "model_path:options/model_path"
}
```
In this example, the `FaceDetectionSubgraph` accepts the graph options protobuf
`FaceDetectionOptions`. `FaceDetectionOptions` is then used to populate some
field values in the calculator options `ImageToTensorCalculatorOptions` and
some field values in the subgraph options `InferenceCalculatorOptions`. These
field values are defined using the `option_value:` syntax.
In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.
The syntax of `option_value:` is similar to the syntax of `input_stream:`. The
syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option
field and the RHS identifies a graph option field. More specifically, the LHS
and RHS each consist of a series of protobuf field names identifying nested
protobuf messages and fields separated by '/'. This is known as the "ProtoPath"
syntax. Nested messages that are referenced in the LHS or RHS must already be
defined in the enclosing protobuf in order to be traversed using
`option_value:`.
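As a rough illustration (not part of the MediaPipe docs or API), the same
subgraph fragment can also be assembled programmatically with the text-proto
helpers used elsewhere in this change. The helper name and the added
`type:`/`output_stream:` fields below are hypothetical, and
`mediapipe.FaceDetectionOptions` must be linked into the binary for the
extension to parse:
```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical helper (not part of MediaPipe): returns a subgraph config that
// accepts FaceDetectionOptions as graph options and maps its tensor_width /
// tensor_height fields into ImageToTensorCalculatorOptions via option_value.
mediapipe::CalculatorGraphConfig MakeFaceDetectionSubgraphConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    type: "FaceDetectionSubgraph"
    graph_options: {
      [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
    }
    input_stream: "IMAGE:image"
    node: {
      calculator: "ImageToTensorCalculator"
      input_stream: "IMAGE:image"
      output_stream: "TENSORS:tensors"
      node_options: {
        [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
          keep_aspect_ratio: true
        }
      }
      # Copies FaceDetectionOptions.tensor_width/tensor_height into the
      # calculator options at graph-expansion time.
      option_value: "output_tensor_width:options/tensor_width"
      option_value: "output_tensor_height:options/tensor_height"
    }
  )pb");
}
```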
## Cycles

<!-- TODO: add discussion of PreviousLoopbackCalculator -->

View File

@@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow
```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
-  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
+  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt
```
This will open up your webcam as long as it is connected and on. Any errors

View File

@@ -1410,3 +1410,45 @@ cc_library(
    ],
    alwayslink = 1,
)
mediapipe_proto_library(
name = "bypass_calculator_proto",
srcs = ["bypass_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bypass_calculator",
srcs = ["bypass_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":bypass_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "bypass_calculator_test",
srcs = ["bypass_calculator_test.cc"],
deps = [
":bypass_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@@ -0,0 +1,161 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/bypass_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
using mediapipe::BypassCalculatorOptions;
// Defines a "bypass" channel to use in place of a disabled feature subgraph.
// By default, all inputs are discarded and all outputs are ignored.
// Certain input streams can be passed to corresponding output streams
// by specifying them in "pass_input_stream" and "pass_output_stream" options.
// All output streams are updated with timestamp bounds indicating completed
// output.
//
// Note that this calculator is designed for use as a contained_node in a
// SwitchContainer. For this reason, any input and output tags are accepted,
// and stream semantics are specified through BypassCalculatorOptions.
//
// Example config:
// node {
// calculator: "BypassCalculator"
// input_stream: "APPEARANCES:appearances_post_facenet"
// input_stream: "VIDEO:video_frame"
// input_stream: "FEATURE_CONFIG:feature_config"
// input_stream: "ENABLE:gaze_enabled"
// output_stream: "APPEARANCES:analyzed_appearances"
// output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
// node_options: {
// [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
// pass_input_stream: "APPEARANCES"
// pass_output_stream: "APPEARANCES"
// }
// }
// }
//
class BypassCalculator : public Node {
public:
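// Note: no fixed streams are declared here. kNotNeeded is only an optional
// placeholder so the node contract can be stated; the real input and output
// streams are typed in UpdateContract() based on BypassCalculatorOptions.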
static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
using IdMap = std::map<CollectionItemId, CollectionItemId>;
// Returns the map of passthrough input and output stream ids.
static absl::StatusOr<IdMap> GetPassMap(
const BypassCalculatorOptions& options, const tool::TagMap& input_map,
const tool::TagMap& output_map) {
IdMap result;
auto& input_streams = options.pass_input_stream();
auto& output_streams = options.pass_output_stream();
int size = std::min(input_streams.size(), output_streams.size());
for (int i = 0; i < size; ++i) {
std::pair<std::string, int> in_tag, out_tag;
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
&in_tag.first, &in_tag.second));
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
&out_tag.first, &out_tag.second));
auto input_id = input_map.GetId(in_tag.first, in_tag.second);
auto output_id = output_map.GetId(out_tag.first, out_tag.second);
result[input_id] = output_id;
}
return result;
}
// Identifies all specified streams as "Any" packet type.
// Identifies passthrough streams as "Same" packet type.
static absl::Status UpdateContract(CalculatorContract* cc) {
auto options = cc->Options<BypassCalculatorOptions>();
RET_CHECK_EQ(options.pass_input_stream().size(),
options.pass_output_stream().size());
ASSIGN_OR_RETURN(
auto pass_streams,
GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams) {
pass_out.insert(entry.second);
cc->Inputs().Get(entry.first).SetAny();
cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
}
for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
if (pass_streams.count(id) == 0) {
cc->Inputs().Get(id).SetAny();
}
}
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetAny();
}
}
return absl::OkStatus();
}
// Saves the map of passthrough input and output stream ids.
absl::Status Open(CalculatorContext* cc) override {
auto options = cc->Options<BypassCalculatorOptions>();
ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
*cc->Outputs().TagMap()));
return absl::OkStatus();
}
// Copies packets between passthrough input and output streams.
// Updates timestamp bounds on all output streams.
absl::Status Process(CalculatorContext* cc) override {
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams_) {
pass_out.insert(entry.second);
auto& packet = cc->Inputs().Get(entry.first).Value();
if (packet.Timestamp() == cc->InputTimestamp()) {
cc->Outputs().Get(entry.first).AddPacket(packet);
}
}
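// For every other output stream, only advance the timestamp bound so that
// downstream calculators know no packet will arrive at this timestamp.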
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetNextTimestampBound(
std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
}
}
return absl::OkStatus();
}
// Closes all output streams.
absl::Status Close(CalculatorContext* cc) override {
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).Close();
}
return absl::OkStatus();
}
private:
IdMap pass_streams_;
};
MEDIAPIPE_REGISTER_NODE(BypassCalculator);
} // namespace api2
} // namespace mediapipe

View File

@@ -0,0 +1,31 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BypassCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BypassCalculatorOptions ext = 481259677;
}
// Names an input stream or streams to pass through, by "TAG:index".
repeated string pass_input_stream = 1;
// Names an output stream or streams to pass through, by "TAG:index".
repeated string pass_output_stream = 2;
}

View File

@@ -0,0 +1,302 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A graph that uses a BypassCalculator to pass through and ignore
// most of its inputs and outputs.
constexpr char kTestGraphConfig1[] = R"pb(
type: "AppearancesPassThroughSubgraph"
input_stream: "APPEARANCES:appearances"
input_stream: "VIDEO:video_frame"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "APPEARANCES:passthrough_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output"
node {
calculator: "BypassCalculator"
input_stream: "PASS:appearances"
input_stream: "TRUNCATE:0:video_frame"
input_stream: "TRUNCATE:1:feature_config"
output_stream: "PASS:passthrough_appearances"
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "PASS"
pass_output_stream: "PASS"
}
}
}
)pb";
// A graph that uses AppearancesPassThroughSubgraph as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig2[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
}
}
}
)pb";
// A graph that uses BypassCalculator as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig3[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: {
calculator: "BypassCalculator"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "APPEARANCES"
pass_output_stream: "APPEARANCES"
}
}
}
}
}
}
)pb";
// A graph that uses BypassCalculator as a disabled gate
// for input frames and appearances.
constexpr char kTestGraphConfig4[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "ENABLE:gaze_enabled"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "VIDEO:video_frame_out"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEATURE_CONFIG:feature_config_out"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "BypassCalculator" }
contained_node: { calculator: "PassThroughCalculator" }
}
}
}
)pb";
// Reports packet timestamp and string contents, or "<empty>".
std::string DebugString(Packet p) {
return absl::StrCat(p.Timestamp().DebugString(), ":",
p.IsEmpty() ? "<empty>" : p.Get<std::string>());
}
// Shows a bypass subgraph that passes through one stream.
TEST(BypassCalculatorTest, SubgraphChannel) {
CalculatorGraphConfig config_1 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1);
CalculatorGraphConfig config_2 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that passes through one stream.
TEST(BypassCalculatorTest, CalculatorChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that discards all inputs when ENABLE is false.
TEST(BypassCalculatorTest, GatedChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> video_frame;
MP_ASSERT_OK(graph.ObserveOutputStream(
"video_frame_out",
[&](const Packet& p) {
video_frame.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
// Close the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(false).At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 200.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Only timestamps arrive from the BypassCalculator.
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:<empty>"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>"));
// Open the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(true).At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 300.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f2").At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Packets arrive from the PassThroughCalculator.
EXPECT_THAT(analyzed_appearances,
testing::ElementsAre("200:<empty>", "300:a2"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>", "300:v2"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@@ -209,11 +209,18 @@ cc_library(
    alwayslink = 1,
)
+mediapipe_proto_library(
+    name = "rotation_mode_proto",
+    srcs = ["rotation_mode.proto"],
+    visibility = ["//visibility:public"],
+)
mediapipe_proto_library(
    name = "image_transformation_calculator_proto",
    srcs = ["image_transformation_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
+        ":rotation_mode_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/gpu:scale_mode_proto",
@@ -238,6 +245,7 @@ cc_library(
    }),
    visibility = ["//visibility:public"],
    deps = [
+        ":rotation_mode_cc_proto",
        ":image_transformation_calculator_cc_proto",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:timestamp",

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
+#include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"

View File

@@ -16,20 +16,10 @@ syntax = "proto2";
package mediapipe;
+import "mediapipe/calculators/image/rotation_mode.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";
-// Counterclockwise rotation.
-message RotationMode {
-  enum Mode {
-    UNKNOWN = 0;
-    ROTATION_0 = 1;
-    ROTATION_90 = 2;
-    ROTATION_180 = 3;
-    ROTATION_270 = 4;
-  }
-}
message ImageTransformationCalculatorOptions {
  extend CalculatorOptions {
    optional ImageTransformationCalculatorOptions ext = 251952830;

View File

@@ -0,0 +1,31 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "RotationModeProto";
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}
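For context, `RotationMode.Mode` is consumed through `ImageTransformationCalculatorOptions`. A minimal sketch of selecting a rotation in a node config (assuming the calculator's existing `rotation_mode` options field) could look like this; it is illustrative only:
```c++
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Illustrative sketch: a node that rotates frames 90 degrees counterclockwise.
// Assumes ImageTransformationCalculatorOptions exposes a rotation_mode field.
mediapipe::CalculatorGraphConfig::Node MakeRotateNode() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(
      R"pb(
        calculator: "ImageTransformationCalculator"
        input_stream: "IMAGE:input_frame"
        output_stream: "IMAGE:rotated_frame"
        options {
          [mediapipe.ImageTransformationCalculatorOptions.ext] {
            rotation_mode: ROTATION_90
          }
        }
      )pb");
}
```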

View File

@@ -161,6 +161,193 @@ cc_test(
    ],
)
mediapipe_proto_library(
name = "bert_preprocessor_calculator_proto",
srcs = ["bert_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bert_preprocessor_calculator",
srcs = ["bert_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":bert_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "bert_preprocessor_calculator_test",
srcs = ["bert_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":bert_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
name = "regex_preprocessor_calculator_proto",
srcs = ["regex_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "regex_preprocessor_calculator",
srcs = ["regex_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":regex_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "regex_preprocessor_calculator_test",
srcs = ["regex_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":regex_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "text_to_tensor_calculator_test",
srcs = ["text_to_tensor_calculator_test.cc"],
deps = [
":text_to_tensor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:options_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "universal_sentence_encoder_preprocessor_calculator",
srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "universal_sentence_encoder_preprocessor_calculator_test",
srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
deps = [
":universal_sentence_encoder_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
    name = "inference_calculator_proto",
    srcs = ["inference_calculator.proto"],
@@ -320,6 +507,8 @@ cc_library(
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite:framework_stable",
+        "@org_tensorflow//tensorflow/lite:string_util",
+        "@org_tensorflow//tensorflow/lite/c:c_api_types",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)

View File

@@ -0,0 +1,251 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr absl::string_view kInputIdsTensorName = "ids";
constexpr absl::string_view kInputMasksTensorName = "mask";
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
constexpr absl::string_view kClassifierToken = "[CLS]";
constexpr absl::string_view kSeparatorToken = "[SEP]";
// Preprocesses input text into three int32 input tensors for a BERT model using
// a tokenizer.
// The associated BERT model is expected to contain input tensors with names:
//
// Tensor | Metadata Name
// ---------------- | --------------
// IDs | "ids"
// Segment IDs | "segment_ids"
// Mask | "mask"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have names corresponding to the above
// metadata names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// This calculator is currently configured for the TextClassifier Task but it
// will eventually be generalized for other Text Tasks.
// TODO: Handle preprocessing for other Text Tasks too.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the BERT model. Used to determine the order of
// the three input Tensors for the BERT model and to extract the metadata to
// construct the tokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the BERT model:
// (1): the token ids of the tokenized input string. A classifier token
// ("[CLS]") will be prepended to the input tokens and a separator
// token ("[SEP]") will be appended to the input tokens.
// (2): the segment ids, which are all 0 for now but will have different
// values to distinguish between different sentences in the input
// text for other Text tasks.
// (3): the input mask ids, which are 1 at each of the input token indices
// and 0 elsewhere.
// The Tensors will have size equal to the max sequence length for the BERT
// model.
//
// Example:
// node {
// calculator: "BertPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.BertPreprocessorCalculatorOptions.ext] {
// bert_max_seq_len: 128
// }
// }
// }
class BertPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
// The max sequence length accepted by the BERT model.
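// Defaults to 2, the minimum needed to hold the "[CLS]" and "[SEP]" tokens.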
int bert_max_seq_len_ = 2;
// Indices of the three input tensors for the BERT model. They should form the
// set {0, 1, 2}.
int input_ids_tensor_index_ = 0;
int segment_ids_tensor_index_ = 1;
int input_masks_tensor_index_ = 2;
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
// clips the vector of tokens to have length at most `bert_max_seq_len_`.
std::vector<std::string> TokenizeInputText(absl::string_view input_text);
// Processes the `input_tokens` to generate the three input tensors for the
// BERT model.
std::vector<Tensor> GenerateInputTensors(
const std::vector<std::string>& input_tokens);
};
absl::Status BertPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
RET_CHECK_GE(options.bert_max_seq_len(), 2)
<< "bert_max_seq_len must be at least 2";
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::ProcessUnit* tokenizer_metadata =
metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
tokenizer_metadata, metadata_extractor));
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
input_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputIdsTensorName);
segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kSegmentIdsTensorName);
input_masks_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputMasksTensorName);
absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
segment_ids_tensor_index_,
input_masks_tensor_index_};
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
input_ids_tensor_index_, segment_ids_tensor_index_,
input_masks_tensor_index_));
}
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
bert_max_seq_len_ = options.bert_max_seq_len();
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
kTensorsOut(cc).Send(
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
return absl::OkStatus();
}
std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
absl::string_view input_text) {
std::string processed_input = std::string(input_text);
absl::AsciiStrToLower(&processed_input);
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(processed_input);
// Offset by 2 to account for [CLS] and [SEP]
int input_tokens_size =
std::min(bert_max_seq_len_,
static_cast<int>(tokenizer_result.subwords.size()) + 2);
std::vector<std::string> input_tokens;
input_tokens.reserve(input_tokens_size);
input_tokens.push_back(std::string(kClassifierToken));
for (int i = 0; i < input_tokens_size - 2; ++i) {
input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
}
input_tokens.push_back(std::string(kSeparatorToken));
return input_tokens;
}
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
const std::vector<std::string>& input_tokens) {
std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
// Convert tokens back into ids and set mask
for (int i = 0; i < input_tokens.size(); ++i) {
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_masks[i] = 1;
}
// |<--------bert_max_seq_len_--------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
}
std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_ids.data(), input_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[segment_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
segment_ids.data(), segment_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[input_masks_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_masks.data(), input_masks.size() * sizeof(int32_t));
return input_tensors;
}
MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BertPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BertPreprocessorCalculatorOptions ext = 462509271;
}
// The maximum input sequence length for the calculator's BERT model.
optional int32 bert_max_seq_len = 1;
}

View File

@@ -0,0 +1,154 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kBertMaxSeqLen = 128;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
absl::string_view text, absl::string_view model_path) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "BertPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
options {
[mediapipe.BertPreprocessorCalculatorOptions.ext] {
bert_max_seq_len: $0
}
}
}
)",
kBertMaxSeqLen));
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer = tasks::core::LoadBinaryContent(model_path.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForBert) {
return absl::InvalidArgumentError(
absl::Substitute("tensor_vec has size $0, expected $1",
tensor_vec.size(), kNumInputTensorsForBert));
}
std::vector<std::vector<int>> results;
for (int i = 0; i < kNumInputTensorsForBert; i++) {
const Tensor& tensor = tensor_vec[i];
if (tensor.element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32");
}
auto* buffer = tensor.GetCpuReadView().buffer<int>();
std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen);
results.push_back(buffer_view);
}
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return results;
}
TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) {
std::vector<std::vector<int>> expected_result = {
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}};
// segment_ids
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
// input_masks
expected_result.push_back(std::vector(expected_result[0].size(), 1));
expected_result[2].resize(kBertMaxSeqLen);
// padding input_ids
expected_result[0].resize(kBertMaxSeqLen);
MP_ASSERT_OK_AND_ASSIGN(
std::vector<std::vector<int>> processed_tensor_values,
RunBertPreprocessorCalculator(
"it's a charming and often affecting journey", kTestModelPath));
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
TEST(BertPreprocessorCalculatorTest, LongInput) {
std::stringstream long_input;
long_input
<< "it's a charming and often affecting journey and this is a long";
for (int i = 0; i < kBertMaxSeqLen; ++i) {
long_input << " long";
}
long_input << " movie review";
std::vector<std::vector<int>> expected_result = {
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023,
2003, 1037}};
// "long" id
expected_result[0].resize(kBertMaxSeqLen - 1, 2146);
// "[SEP]" id
expected_result[0].push_back(102);
// segment_ids
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
// input_masks
expected_result.push_back(std::vector(kBertMaxSeqLen, 1));
MP_ASSERT_OK_AND_ASSIGN(
std::vector<std::vector<int>> processed_tensor_values,
RunBertPreprocessorCalculator(long_input.str(), kTestModelPath));
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
} // namespace
} // namespace mediapipe

View File

@@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
    }
    ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
-   const Size size{image->width(), image->height()};
-   RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
+   RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
    ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
                                          options_.output_tensor_height(),
                                          options_.keep_aspect_ratio(), &roi));
@@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
    }
    if (kOutMatrix(cc).IsConnected()) {
      std::array<float, 16> matrix;
-     GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
-                                            /*flip_horizontaly=*/false,
-                                            &matrix);
+     GetRotatedSubRectToRectTransformMatrix(
+         roi, image->width(), image->height(),
+         /*flip_horizontaly=*/false, &matrix);
      kOutMatrix(cc).Send(std::move(matrix));
    }
    // Lazy initialization of the GPU or CPU converter.
    MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
-   ASSIGN_OR_RETURN(Tensor tensor,
-                    (image->UsesGpu() ? gpu_converter_ : cpu_converter_)
-                        ->Convert(*image, roi, {output_width_, output_height_},
-                                  range_min_, range_max_));
+   Tensor::ElementType output_tensor_type =
+       GetOutputTensorType(image->UsesGpu());
+   Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
+                                      GetNumOutputChannels(*image)});
+   MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
+                          ->Convert(*image, roi, range_min_, range_max_,
+                                    /*tensor_buffer_offset=*/0, tensor));
    auto result = std::make_unique<std::vector<Tensor>>();
    result->push_back(std::move(tensor));
@@ -292,15 +295,31 @@ class ImageToTensorCalculator : public Node {
      }
    }

-  Tensor::ElementType GetOutputTensorType() {
-    if (is_float_output_) {
-      return Tensor::ElementType::kFloat32;
-    }
-    if (range_min_ < 0) {
-      return Tensor::ElementType::kInt8;
-    } else {
-      return Tensor::ElementType::kUInt8;
-    }
-  }
+  Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
+    if (!uses_gpu) {
+      if (is_float_output_) {
+        return Tensor::ElementType::kFloat32;
+      }
+      if (range_min_ < 0) {
+        return Tensor::ElementType::kInt8;
+      } else {
+        return Tensor::ElementType::kUInt8;
+      }
+    }
+    // Always use float32 when GPU is enabled.
+    return Tensor::ElementType::kFloat32;
+  }
+
+  int GetNumOutputChannels(const Image& image) {
+#if !MEDIAPIPE_DISABLE_GPU
+#if MEDIAPIPE_METAL_ENABLED
+    if (image.UsesGpu()) {
+      return 4;
+    }
+#endif  // MEDIAPIPE_METAL_ENABLED
+#endif  // !MEDIAPIPE_DISABLE_GPU
+    // All of the processors except for Metal expect 3 channels.
+    return 3;
+  }

  absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
@@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
#if !MEDIAPIPE_DISABLE_OPENCV
    ASSIGN_OR_RETURN(
        cpu_converter_,
-       CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
+       CreateOpenCvConverter(cc, GetBorderMode(),
+                             GetOutputTensorType(/*uses_gpu=*/false)));
#else
    LOG(FATAL) << "Cannot create image to tensor opencv converter since "
                  "MEDIAPIPE_DISABLE_OPENCV is defined.";

View File

@@ -42,13 +42,16 @@ class ImageToTensorConverter {
  // @image contains image to extract from.
  // @roi describes region of interest within the image to extract (absolute
  // values).
- // @output_dims dimensions of output tensor.
  // @range_min/max describes the output tensor range image pixels should be
  // converted to.
+ // @tensor_buffer_offset an integer representing the offset of the tensor
+ // buffer the result should be written to.
+ // @output_tensor a tensor with a pre-defined shape. "Convert" is responsible
+ // for populating the content of the output tensor.
- virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                                        const RotatedRect& roi,
-                                        const Size& output_dims,
-                                        float range_min, float range_max) = 0;
+ virtual absl::Status Convert(const mediapipe::Image& input,
+                              const RotatedRect& roi, float range_min,
+                              float range_max, int tensor_buffer_offset,
+                              Tensor& output_tensor) = 0;
};

}  // namespace mediapipe
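A caller-side sketch of the new contract follows (illustrative only; it mirrors how ImageToTensorCalculator drives the converter after this change, and the function name, header paths, and fixed 3-channel float shape are assumptions). The caller now pre-allocates the output tensor and passes it in by reference, together with a byte offset into its buffer:
```c++
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/image_to_tensor_converter.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"

// Sketch only: run a converter over an image and fill a caller-owned tensor.
absl::Status ConvertToCallerOwnedTensor(
    mediapipe::ImageToTensorConverter& converter,
    const mediapipe::Image& image, const mediapipe::RotatedRect& roi,
    int output_width, int output_height) {
  // The caller decides the element type and shape up front.
  mediapipe::Tensor tensor(mediapipe::Tensor::ElementType::kFloat32,
                           {1, output_height, output_width, /*channels=*/3});
  // Pixels are rescaled into [0, 1]; the whole buffer is written, so the
  // tensor_buffer_offset is 0.
  return converter.Convert(image, roi, /*range_min=*/0.0f, /*range_max=*/1.0f,
                           /*tensor_buffer_offset=*/0, tensor);
}
```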

View File

@@ -264,57 +264,58 @@ class GlProcessor : public ImageToTensorConverter {
    });
  }
-  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                                 const RotatedRect& roi,
-                                 const Size& output_dims, float range_min,
-                                 float range_max) override {
-    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
-      return InvalidArgumentError(absl::StrCat(
-          "Only 4-channel texture input formats are supported, passed format: ",
-          static_cast<uint32_t>(input.format())));
-    }
-
-    constexpr int kNumChannels = 3;
-    Tensor tensor(Tensor::ElementType::kFloat32,
-                  {1, output_dims.height, output_dims.width, kNumChannels});
-
-    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi,
-                                                  &output_dims, range_min,
-                                                  range_max]() -> absl::Status {
-      constexpr int kRgbaNumChannels = 4;
-      auto source_texture = gl_helper_.CreateSourceTexture(input);
-      tflite::gpu::gl::GlTexture input_texture(
-          GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
-          source_texture.width() * source_texture.height() * kRgbaNumChannels *
-              sizeof(uint8_t),
-          /*layer=*/0,
-          /*owned=*/false);
-
-      constexpr float kInputImageRangeMin = 0.0f;
-      constexpr float kInputImageRangeMax = 1.0f;
-      ASSIGN_OR_RETURN(
-          auto transform,
-          GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
-                                      range_min, range_max));
-
-      auto buffer_view = tensor.GetOpenGlBufferWriteView();
-      tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
-                                       buffer_view.name(), tensor.bytes(),
-                                       /*offset=*/0,
-                                       /*has_ownership=*/false);
-      MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
-          input_texture,
-          tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
-          /*flip_horizontaly=*/false, transform.scale, transform.offset,
-          tflite::gpu::HW(output_dims.height, output_dims.width),
-          command_queue_.get(), &output));
-
-      return absl::OkStatus();
-    }));
-
-    return tensor;
+  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
+                       float range_min, float range_max,
+                       int tensor_buffer_offset,
+                       Tensor& output_tensor) override {
+    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
+      return InvalidArgumentError(absl::StrCat(
+          "Unsupported format: ", static_cast<uint32_t>(input.format())));
+    }
+    const auto& output_shape = output_tensor.shape();
+    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
+
+    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
+        [this, &output_tensor, &input, &roi, &output_shape, range_min,
+         range_max, tensor_buffer_offset]() -> absl::Status {
+          const int input_num_channels = input.channels();
+          auto source_texture = gl_helper_.CreateSourceTexture(input);
+          tflite::gpu::gl::GlTexture input_texture(
+              GL_TEXTURE_2D, source_texture.name(),
+              input_num_channels == 4 ? GL_RGB : GL_RGBA,
+              source_texture.width() * source_texture.height() *
+                  input_num_channels * sizeof(uint8_t),
+              /*layer=*/0,
+              /*owned=*/false);
+
+          constexpr float kInputImageRangeMin = 0.0f;
+          constexpr float kInputImageRangeMax = 1.0f;
+          ASSIGN_OR_RETURN(auto transform,
+                           GetValueRangeTransformation(kInputImageRangeMin,
+                                                       kInputImageRangeMax,
+                                                       range_min, range_max));
+
+          const int output_size = output_tensor.bytes() / output_shape.dims[0];
+          auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
+          tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
+                                           buffer_view.name(), output_size,
+                                           /*offset=*/tensor_buffer_offset,
+                                           /*has_ownership=*/false);
+          MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
+              input_texture,
+              tflite::gpu::HW(source_texture.height(), source_texture.width()),
+              roi,
+              /*flip_horizontaly=*/false, transform.scale, transform.offset,
+              tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
+              command_queue_.get(), &output));
+
+          return absl::OkStatus();
+        }));
+
+    return absl::OkStatus();
  }
~GlProcessor() override { ~GlProcessor() override {
@ -326,6 +327,17 @@ class GlProcessor : public ImageToTensorConverter {
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 3)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
  std::unique_ptr<SubRectExtractorGl> extractor_;
  mediapipe::GlCalculatorHelper gl_helper_;

View File

@ -168,26 +168,26 @@ class GlProcessor : public ImageToTensorConverter {
        });
  }

-  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                                 const RotatedRect& roi,
-                                 const Size& output_dims, float range_min,
-                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
-          "Only 4-channel texture input formats are supported, passed format: ",
-          static_cast<uint32_t>(input.format())));
          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    // TODO: support tensor_buffer_offset > 0 scenario.
    RET_CHECK_EQ(tensor_buffer_offset, 0)
        << "The non-zero tensor_buffer_offset input is not supported yet.";
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

-    constexpr int kNumChannels = 3;
-    Tensor tensor(
-        Tensor::ElementType::kFloat32,
-        Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});
-    MP_RETURN_IF_ERROR(
-        gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
-                                   range_min, range_max]() -> absl::Status {
    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
        [this, &output_tensor, &input, &roi, &output_shape, range_min,
         range_max]() -> absl::Status {
          auto input_texture = gl_helper_.CreateSourceTexture(input);

          constexpr float kInputImageRangeMin = 0.0f;
@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
                           GetValueRangeTransformation(kInputImageRangeMin,
                                                       kInputImageRangeMax,
                                                       range_min, range_max));
-          auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
          auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
          MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
                                            /*flip_horizontaly=*/false,
                                            transform.scale, transform.offset,
-                                            output_dims, &tensor_view));
                                            output_shape, &tensor_view));
          return absl::OkStatus();
        }));
-    return tensor;
    return absl::OkStatus();
  }

  absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
                              const RotatedRect& sub_rect,
                              bool flip_horizontaly, float alpha, float beta,
-                              const Size& output_dims,
                              const Tensor::Shape& output_shape,
                              Tensor::OpenGlTexture2dView* output) {
    const int output_height = output_shape.dims[1];
    const int output_width = output_shape.dims[2];
    std::array<float, 16> transform_mat;

    glDisable(GL_DEPTH_TEST);
    glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
-    glViewport(0, 0, output_dims.width, output_dims.height);
    glViewport(0, 0, output_width, output_height);

    glActiveTexture(GL_TEXTURE0);
    glBindTexture(GL_TEXTURE_2D, output->name());
@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 3)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  mediapipe::GlCalculatorHelper gl_helper_;
  bool use_custom_zero_border_ = false;
  BorderMode border_mode_ = BorderMode::kReplicate;

View File

@ -262,7 +262,6 @@ class SubRectExtractorMetal {
    RET_CHECK(pipeline_state != nil);

    std::string output_type_def;
-    MTLPixelFormat pixel_format;
    switch (output_format) {
      case OutputFormat::kF16C4:
        output_type_def = R"(
@ -348,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
    return absl::OkStatus();
  }

-  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                                 const RotatedRect& roi,
-                                 const Size& output_dims, float range_min,
-                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@ -359,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ", "Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format()))); static_cast<uint32_t>(input.format())));
} }
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@autoreleasepool { @autoreleasepool {
id<MTLTexture> texture = id<MTLTexture> texture =
[metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()]; [metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
constexpr int kNumChannels = 4;
Tensor tensor(Tensor::ElementType::kFloat32,
Tensor::Shape{1, output_dims.height, output_dims.width,
kNumChannels});
constexpr float kInputImageRangeMin = 0.0f; constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f; constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
@ -377,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
          range_min, range_max));

      id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
-      const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
      const auto& buffer_view =
          output_tensor.GetMtlBufferWriteView(command_buffer);
      MP_RETURN_IF_ERROR(extractor_->Execute(
          texture, roi,
          /*flip_horizontaly=*/false, transform.scale, transform.offset,
-          tflite::gpu::HW(output_dims.height, output_dims.width),
          tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
          command_buffer, buffer_view.buffer()));
      [command_buffer commit];
-      return tensor;
      return absl::OkStatus();
    }
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 4)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  MPPMetalHelper* metal_helper_ = nil;
  std::unique_ptr<SubRectExtractorMetal> extractor_;
};

View File

@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
    }
  }

-  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
-                                 const RotatedRect& roi,
-                                 const Size& output_dims, float range_min,
-                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.image_format() != mediapipe::ImageFormat::SRGB &&
        input.image_format() != mediapipe::ImageFormat::SRGBA) {
      return InvalidArgumentError(
          absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
                       static_cast<uint32_t>(input.image_format())));
    }
-    auto src = mediapipe::formats::MatView(&input);
    // TODO: Remove the check once tensor_buffer_offset > 0 is
    // supported.
    RET_CHECK_EQ(tensor_buffer_offset, 0)
        << "The non-zero tensor_buffer_offset input is not supported yet.";
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

-    constexpr int kNumChannels = 3;
-    Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
-                                              output_dims.width, kNumChannels});
-    auto buffer_view = tensor.GetCpuWriteView();
    const int output_height = output_shape.dims[1];
    const int output_width = output_shape.dims[2];
    const int output_channels = output_shape.dims[3];
    auto buffer_view = output_tensor.GetCpuWriteView();
    cv::Mat dst;
    switch (tensor_type_) {
      case Tensor::ElementType::kInt8:
-        dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
        dst = cv::Mat(output_height, output_width, mat_type_,
                      buffer_view.buffer<int8>());
        break;
      case Tensor::ElementType::kFloat32:
-        dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
        dst = cv::Mat(output_height, output_width, mat_type_,
                      buffer_view.buffer<float>());
        break;
      case Tensor::ElementType::kUInt8:
-        dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
        dst = cv::Mat(output_height, output_width, mat_type_,
                      buffer_view.buffer<uint8>());
        break;
      default:
@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
    cv::Mat src_points;
    cv::boxPoints(rotated_rect, src_points);

-    const float dst_width = output_dims.width;
-    const float dst_height = output_dims.height;
    const float dst_width = output_width;
    const float dst_height = output_height;
    /* clang-format off */
    float dst_corners[8] = {0.0f, dst_height,
                            0.0f, 0.0f,
@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
                            dst_width, dst_height};
    /* clang-format on */

    auto src = mediapipe::formats::MatView(&input);
    cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
    cv::Mat projection_matrix =
        cv::getPerspectiveTransform(src_points, dst_points);
@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
                        /*flags=*/cv::INTER_LINEAR,
                        /*borderMode=*/border_mode_);

-    if (transformed.channels() > kNumChannels) {
    if (transformed.channels() > output_channels) {
      cv::Mat proper_channels_mat;
      cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
      transformed = proper_channels_mat;
@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
        GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
                                    range_min, range_max));
    transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
-    return tensor;
    return absl::OkStatus();
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 3)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  enum cv::BorderTypes border_mode_;
  Tensor::ElementType tensor_type_;
  int mat_type_;

View File

@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
    std::vector<Tensor>& output_tensors) {
  return gpu_helper_.RunInGlContext(
-      [this, &input_tensors, &output_tensors]() -> absl::Status {
      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        // Explicitly copy input.
        for (int i = 0; i < input_tensors.size(); ++i) {
          glBindBuffer(GL_COPY_READ_BUFFER,
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
        }

        // Run inference.
-        RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
        {
          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
          RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
        }

        output_tensors.reserve(output_size_);
        for (int i = 0; i < output_size_; ++i) {

View File

@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h" #include "mediapipe/util/android/file/base/helpers.h"
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
      const mediapipe::InferenceCalculatorOptions::Delegate& delegate);

  absl::StatusOr<std::vector<Tensor>> Process(
-      const std::vector<Tensor>& input_tensors);
      CalculatorContext* cc, const std::vector<Tensor>& input_tensors);

  absl::Status Close();
@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
-    const std::vector<Tensor>& input_tensors) {
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
  std::vector<Tensor> output_tensors;
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
-      [this, &input_tensors, &output_tensors]() -> absl::Status {
      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        for (int i = 0; i < input_tensors.size(); ++i) {
          MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
              input_tensors[i].GetOpenGlBufferReadView().name(), i));
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
              output_tensors.back().GetOpenGlBufferWriteView().name(), i));
        }
        // Run inference.
-        return tflite_gpu_runner_->Invoke();
        {
          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
          return tflite_gpu_runner_->Invoke();
        }
      }));

  return output_tensors;
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
  auto output_tensors = absl::make_unique<std::vector<Tensor>>();
  ASSIGN_OR_RETURN(*output_tensors,
-                   gpu_inference_runner_->Process(input_tensors));
                   gpu_inference_runner_->Process(cc, input_tensors));
  kOutTensors(cc).Send(std::move(output_tensors));

  return absl::OkStatus();

View File

@ -224,9 +224,6 @@ absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
void InferenceCalculatorMetalImpl::AddDelegate(
    CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
-  const auto& calculator_opts =
-      cc->Options<mediapipe::InferenceCalculatorOptions>();
  // Configure and create the delegate.
  TFLGpuDelegateOptions options;
  // `enable_quantization` enables the run of sparse models i.e. the models with

View File

@ -21,8 +21,10 @@
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe { namespace mediapipe {
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
  std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}
template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>();
tflite::DynamicBuffer dynamic_buffer;
dynamic_buffer.AddString(input_tensor_buffer,
input_tensor.shape().num_elements());
dynamic_buffer.WriteToTensorAsVector(
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
}
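As a side note: the specialization above packs the MediaPipe string tensor into the TfLite input tensor through tflite::DynamicBuffer. Below is a hedged sketch, not part of this commit, of the reverse direction, reading the stored string back with the same tensorflow/lite/string_util.h helpers; the function name and the assumption that exactly one entry was stored are illustrative.

```
#include <string>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/string_util.h"

std::string ReadFirstString(const tflite::Interpreter& interpreter,
                            int input_tensor_index) {
  const TfLiteTensor* tensor =
      interpreter.tensor(interpreter.inputs()[input_tensor_index]);
  // WriteToTensorAsVector above stored a single string entry at index 0.
  const tflite::StringRef ref = tflite::GetString(tensor, /*string_index=*/0);
  return std::string(ref.str, ref.len);
}
```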
template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                     int output_tensor_index,
@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
        break;
      }
      case TfLiteType::kTfLiteUInt8: {
-        CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
-                                             interpreter_.get(), i);
        CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
                                               interpreter_.get(), i);
        break;
      }
      case TfLiteType::kTfLiteInt8: {
-        CopyTensorBufferToInterpreter<int8>(input_tensors[i],
-                                            interpreter_.get(), i);
        CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
                                              interpreter_.get(), i);
        break;
      }
      case TfLiteType::kTfLiteInt32: {
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
                                             interpreter_.get(), i);
        break;
      }
case TfLiteType::kTfLiteString: {
CopyTensorBufferToInterpreter<char>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteBool:
// No current use-case for copying MediaPipe Tensors with bool type to
// TfLiteTensors.
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
        CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
                                                 &output_tensors.back());
        break;
case TfLiteType::kTfLiteBool:
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
Tensor::QuantizationParameters{1.0f, 0});
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteString:
// No current use-case for copying TfLiteTensors with string type to
// MediaPipe Tensors.
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported output tensor type:",

View File

@ -0,0 +1,174 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
// Preprocesses input text into one int32 input tensor for a text model using
// a RegexTokenizer.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the text model. Used to extract the metadata
// to construct the RegexTokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor which is the text model's input tensor.
// Depending on the tokenizer metadata, the tensor may start with
// the id of the tokenizer's <START> token. The following tensor values will
// be the ids of the tokens of the input text. Any out-of-vocab tokens will
// have the id of the <UNKNOWN> token. The tensor will be padded with the
// <PAD> token id to have size equal to the max sequence length for the text
// model.
//
// Example:
// node {
// calculator: "RegexPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
// max_seq_len: 256
// }
// }
// }
class RegexPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
// The max sequence length accepted by the text model.
int max_seq_len_ = 0;
};
absl::Status RegexPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::TensorMetadata* tensor_metadata =
metadata_extractor->GetInputTensorMetadata(0);
if (tensor_metadata == nullptr) {
return absl::InvalidArgumentError("No tensor metadata found");
}
ASSIGN_OR_RETURN(
const auto* tokenizer_metadata,
metadata_extractor->FindFirstProcessUnit(
*tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
if (tokenizer_metadata == nullptr) {
return absl::InvalidArgumentError("No tokenizer metadata found");
}
const tflite::RegexTokenizerOptions* regex_tokenizer_options =
tokenizer_metadata->options_as<tflite::RegexTokenizerOptions>();
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
regex_tokenizer_options, metadata_extractor));
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
max_seq_len_ = options.max_seq_len();
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(kTextIn(cc).Get());
int unknown_token_id = 0;
tokenizer_->GetUnknownToken(&unknown_token_id);
int pad_token_id = 0;
tokenizer_->GetPadToken(&pad_token_id);
std::vector<int> input_tokens(max_seq_len_, pad_token_id);
int start_token_id = 0;
int input_token_index = 0;
if (tokenizer_->GetStartToken(&start_token_id)) {
input_tokens[0] = start_token_id;
input_token_index = 1;
}
for (int i = 0; (i < tokenizer_result.subwords.size()) &&
(input_token_index < max_seq_len_);
++i, ++input_token_index) {
const std::string& token = tokenizer_result.subwords[i];
int token_id = 0;
if (tokenizer_->LookupId(token, &token_id)) {
input_tokens[input_token_index] = token_id;
} else {
input_tokens[input_token_index] = unknown_token_id;
}
}
// |<-------sentence_length-------->|
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
// <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's
// not found in the tokenizer vocab.
std::vector<Tensor> result;
result.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
kTensorsOut(cc).Send(std::move(result));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message RegexPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional RegexPreprocessorCalculatorOptions ext = 463716697;
}
// The maximum input sequence length for the calculator's text model.
optional int32 max_seq_len = 1;
}

View File

@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kMaxSeqLen = 256;
constexpr char kTestModelPath[] =
"mediapipe/tasks/testdata/text/"
"test_model_text_classifier_with_regex_tokenizer.tflite";
absl::StatusOr<std::vector<int>> RunRegexPreprocessorCalculator(
absl::string_view text) {
auto graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "RegexPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
options {
[mediapipe.RegexPreprocessorCalculatorOptions.ext] {
max_seq_len: $0
}
}
}
)pb",
kMaxSeqLen));
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath);
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(), 1));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32");
}
auto* buffer = tensor_vec[0].GetCpuReadView().buffer<int>();
std::vector<int> result(buffer, buffer + kMaxSeqLen);
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return result;
}
TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) {
MP_ASSERT_OK_AND_ASSIGN(
std::vector<int> processed_tensor_values,
RunRegexPreprocessorCalculator("This is the best movie Ive seen in "
"recent years. Strongly recommend it!"));
static const int expected_result[kMaxSeqLen] = {
1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12};
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
TEST(RegexPreprocessorCalculatorTest, LongInput) {
std::stringstream long_input;
long_input << "This is the best";
for (int i = 0; i < kMaxSeqLen; ++i) {
long_input << " best";
}
long_input << "movie Ive seen in recent years. Strongly recommend it!";
MP_ASSERT_OK_AND_ASSIGN(std::vector<int> processed_tensor_values,
RunRegexPreprocessorCalculator(long_input.str()));
std::vector<int> expected_result = {1, 2, 9, 4, 118};
// "best" id
expected_result.resize(kMaxSeqLen, 118);
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
} // namespace
} // namespace mediapipe

View File

@ -296,7 +296,6 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
  output_tensors->emplace_back(Tensor::ElementType::kFloat32,
                               Tensor::Shape{1, height, width, channels});
#if MEDIAPIPE_METAL_ENABLED
-  id<MTLDevice> device = gpu_helper_.mtlDevice;
  id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
  command_buffer.label = @"TensorConverterCalculatorConvert";
  id<MTLComputeCommandEncoder> compute_encoder =

View File

@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
    case Tensor::ElementType::kInt8:
      Dequantize<int8>(input_tensor, &output_tensors->back());
      break;
    case Tensor::ElementType::kBool:
      Dequantize<bool>(input_tensor, &output_tensors->back());
      break;
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "Unsupported input tensor type: ", input_tensor.element_type()));

View File

@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
  ValidateResult(GetOutput(), {-1.007874, 0, 1});
}
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
std::vector<bool> tensor = {true, false, true};
PushTensor(Tensor::ElementType::kBool, tensor,
Tensor::QuantizationParameters{1.0f, 0});
MP_ASSERT_OK(runner_.Run());
ValidateResult(GetOutput(), {1, 0, 1});
}
}  // namespace
}  // namespace mediapipe

View File

@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
}

absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
  const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
  const auto& input_tensors = *kInTensors(cc);
  RET_CHECK_EQ(input_tensors.size(), 1);
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
  auto raw_scores = view.buffer<float>();
  auto classification_list = absl::make_unique<ClassificationList>();
if (options.has_tensor_index()) {
classification_list->set_tensor_index(options.tensor_index());
}
if (options.has_tensor_name()) {
classification_list->set_tensor_name(options.tensor_name());
}
  if (is_binary_classification_) {
    Classification* class_first = classification_list->add_classification();
    Classification* class_second = classification_list->add_classification();

View File

@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
  // that are not in the `allow_classes` field will be completely ignored.
  // `ignore_classes` and `allow_classes` are mutually exclusive.
  repeated int32 allow_classes = 8 [packed = true];
// The optional index of the tensor these classifications originate from.
optional int32 tensor_index = 10;
// The optional name of the tensor these classifications originate from.
optional string tensor_name = 11;
}

View File

@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
  }
}
TEST_F(TensorsToClassificationCalculatorTest,
CorrectOutputWithTensorNameAndIndex) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
tensor_index: 1
tensor_name: "foo"
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(3, classification_list.classification_size());
// Verify that the tensor_index and tensor_name fields are correctly set.
EXPECT_EQ(classification_list.tensor_index(), 1);
EXPECT_EQ(classification_list.tensor_name(), "foo");
}
TEST_F(TensorsToClassificationCalculatorTest,
       ClassNameAllowlistWithLabelItems) {
  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(

View File

@ -532,7 +532,6 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
                                          detection_classes.data(),
                                          output_detections));
#elif MEDIAPIPE_METAL_ENABLED
-  id<MTLDevice> device = gpu_helper_.mtlDevice;
  if (!anchors_init_) {
    if (input_tensors.size() == kNumInputTensorsWithAnchors) {
      RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);

View File

@ -67,9 +67,7 @@ absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
"tensor_vec has size $0, expected 1", tensor_vec.size())); "tensor_vec has size $0, expected 1", tensor_vec.size()));
} }
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute( return absl::InvalidArgumentError("Expected tensor element type kChar");
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
} }
const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>(); const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
return std::string(buffer, text.length()); return std::string(buffer, text.length());

View File

@ -0,0 +1,167 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <cstring>
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr absl::string_view kQueryTextMetadataName = "inp_text";
constexpr absl::string_view kResponseContextMetadataName = "res_context";
constexpr absl::string_view kResponseTextMetadataName = "res_text";
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
// Preprocesses input text into three kTfLiteString input tensors for a
// Universal Sentence Encoder (USE) model.
//
// The associated USE model is expected to contain input tensors with metadata
// names:
//
// Tensor | Metadata Name
// ---------------- | ------------------
// Query text | "inp_text"
// Response context | "res_context"
// Response text | "res_text"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have metadata names corresponding to the
// above names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// Inputs:
// TEXT - std::string
// The text to be embedded.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the USE model. Used to determine the order of
// the three input Tensors for the USE model.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the USE model. The tensors
// fit a question-answering setting and store a query text, a response
// context, and a response text. This calculator will just be preprocessing
// a single input text that will be stored in the response text tensor. The
// query text and response context tensors will store empty strings.
//
// Example:
// node {
// calculator: "UniversalSentenceEncoderPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// }
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Indices of the three input tensors for the USE model. They should form the
// set {0, 1, 2}.
int query_text_tensor_index_ = 0;
int response_context_tensor_index_ = 1;
int response_text_tensor_index_ = 2;
// Tensor shapes for the model's input tensors.
// The query text and response context tensors will only hold the empty
// string, so their tensors will have shape [0], but the Universal Sentence
// Encoder model's input signature requires them to be present. The response
// text tensor will store the embedding text and have shape
// [embedding_text_len].
std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
};
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
query_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kQueryTextMetadataName);
response_context_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseContextMetadataName);
response_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseTextMetadataName);
absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
{query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_});
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_));
}
return absl::OkStatus();
}
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
CalculatorContext* cc) {
absl::string_view text = kTextIn(cc).Get();
const int text_len = static_cast<int>(text.length());
tensor_shapes_[response_text_tensor_index_] = text_len;
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
}
std::memcpy(
input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_context_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_text_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
text.data(), text_len * sizeof(char));
kTensorsOut(cc).Send(std::move(input_tensors));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,109 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::IsOkAndHolds;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/"
"universal_sentence_encoder_qa_with_metadata.tflite";
absl::StatusOr<std::vector<std::string>>
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "UniversalSentenceEncoderPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer =
tasks::core::LoadBinaryContent(kTestModelPath.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(),
kNumInputTensorsForUniversalSentenceEncoder));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError("Expected tensor element type kChar");
}
std::vector<std::string> results;
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
results.push_back(
{tensor_vec[i].GetCpuReadView().buffer<char>(),
static_cast<size_t>(tensor_vec[i].shape().num_elements())});
}
return results;
}
TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
ASSERT_THAT(
RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
}
} // namespace
} // namespace mediapipe

View File

@ -331,6 +331,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
], ],
alwayslink = 1, alwayslink = 1,

View File

@ -499,7 +499,6 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
  gpu_data_out_ = absl::make_unique<GPUData>();
  gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
  const bool include_alpha = (max_num_channels_ == 4);
-  const bool single_channel = (max_num_channels_ == 1);
  if (!(format == mediapipe::ImageFormat::GRAY8 ||
        format == mediapipe::ImageFormat::SRGB ||
        format == mediapipe::ImageFormat::SRGBA))
@ -509,6 +508,7 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED

#if MEDIAPIPE_TFLITE_GL_INFERENCE
  const bool single_channel = (max_num_channels_ == 1);
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &include_alpha, &input, &single_channel]() -> absl::Status {
        // Device memory.

View File

@ -16,6 +16,7 @@
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
@ -81,6 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
    }

    if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP
      model_packet = cc->InputSidePackets().Tag("MODEL_FD");
      const auto& model_fd =
          model_packet.Get<std::tuple<int, size_t, size_t>>();
@ -89,6 +91,10 @@ class TfLiteModelCalculator : public CalculatorBase {
          tflite::DefaultErrorReporter());
      model = tflite::FlatBufferModel::BuildFromAllocation(
          std::move(model_allocation), tflite::DefaultErrorReporter());
#else
      return absl::FailedPreconditionError(
          "Loading by file descriptor is not supported on this platform.");
#endif
    }

    RET_CHECK(model) << "Failed to load TfLite model from blob.";

View File

@ -143,9 +143,7 @@ mediapipe_proto_library(
cc_library(
    name = "packet_frequency_calculator",
    srcs = ["packet_frequency_calculator.cc"],
-    visibility = [
-        "//visibility:public",
-    ],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
        "//mediapipe/calculators/util:packet_frequency_cc_proto",
@ -190,9 +188,7 @@ cc_test(
cc_library(
    name = "packet_latency_calculator",
    srcs = ["packet_latency_calculator.cc"],
-    visibility = [
-        "//visibility:public",
-    ],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/util:latency_cc_proto",
        "//mediapipe/calculators/util:packet_latency_calculator_cc_proto",

View File

@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
text->set_left(label_left_px_); text->set_left(label_left_px_);
text->set_baseline(label_baseline_px + i * label_height_px_); text->set_baseline(label_baseline_px + i * label_height_px_);
text->set_font_face(options_.font_face()); text->set_font_face(options_.font_face());
if (options_.outline_thickness() > 0) {
text->set_outline_thickness(options_.outline_thickness());
if (options_.outline_color_size() > 0) {
*(text->mutable_outline_color()) =
options_.outline_color(i % options_.outline_color_size());
} else {
text->mutable_outline_color()->set_r(0);
text->mutable_outline_color()->set_g(0);
text->mutable_outline_color()->set_b(0);
}
}
} }
cc->Outputs() cc->Outputs()
.Tag(kRenderDataTag) .Tag(kRenderDataTag)

View File

@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
// Thickness for drawing the label(s). // Thickness for drawing the label(s).
optional double thickness = 2 [default = 2]; optional double thickness = 2 [default = 2];
// Color of outline around each character, if any. One per label, as with
// color attribute.
repeated Color outline_color = 12;
// Thickness of outline around each character.
optional double outline_thickness = 11;
// The font height in absolute pixels. // The font height in absolute pixels.
optional int32 font_height_px = 3 [default = 50]; optional int32 font_height_px = 3 [default = 50];
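For reference, a minimal sketch of populating the new outline fields through the generated C++ API; the values are illustrative, and the setter names assume the usual protobuf conventions for an optional double, a repeated `Color`, and its `r`/`g`/`b` fields:

```
// Hedged sketch: an options message requesting a black outline around labels.
mediapipe::LabelsToRenderDataCalculatorOptions MakeOutlinedLabelOptions() {
  mediapipe::LabelsToRenderDataCalculatorOptions options;
  options.set_outline_thickness(1.0);
  mediapipe::Color* outline = options.add_outline_color();
  outline->set_r(0);
  outline->set_g(0);
  outline->set_b(0);
  return options;
}
```

As in the calculator change above, when fewer outline colors than labels are supplied, the colors are reused cyclically (`i % outline_color_size()`).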

View File

@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context; import android.content.Context;
import android.net.Uri; import android.net.Uri;
import android.os.Bundle; import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText; import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet; import android.util.AttributeSet;
import android.util.Log; import android.util.Log;
import android.view.inputmethod.EditorInfo; import android.view.inputmethod.EditorInfo;

View File

@ -1685,10 +1685,3 @@ cc_test(
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -14,15 +14,10 @@ cc_library(
name = "builder", name = "builder",
hdrs = ["builder.h"], hdrs = ["builder.h"],
deps = [ deps = [
":const_str",
":contract",
":node",
":packet",
":port", ":port",
"//mediapipe/framework:calculator_base", "//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -5,12 +5,7 @@
#include <type_traits> #include <type_traits>
#include "absl/container/btree_map.h" #include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h" #include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
@ -112,6 +107,17 @@ class MultiPort : public Single {
std::vector<std::unique_ptr<Base>>& vec_; std::vector<std::unique_ptr<Base>>& vec_;
}; };
namespace internal_builder {
template <typename T, typename U>
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>>;
} // namespace internal_builder
template <bool IsSide, typename T = internal::Generic>
class SourceImpl;
// These classes wrap references to the underlying source/destination // These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API. // endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic> template <bool IsSide, typename T = internal::Generic>
@ -122,16 +128,21 @@ class DestinationImpl {
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec) explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {} : DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
explicit DestinationImpl(DestinationBase* base) : base_(*base) {} explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
template <typename U,
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
DestinationImpl<IsSide, U> Cast() {
return DestinationImpl<IsSide, U>(&base_);
}
private:
DestinationBase& base_; DestinationBase& base_;
template <bool Source_IsSide, typename Source_T>
friend class SourceImpl;
}; };
template <bool IsSide, typename T> template <bool IsSide, typename T>
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
public:
using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
};
template <bool IsSide, typename T = internal::Generic>
class SourceImpl { class SourceImpl {
public: public:
using Base = SourceBase; using Base = SourceBase;
@ -171,12 +182,8 @@ class SourceImpl {
return AddTarget(dest); return AddTarget(dest);
} }
template <typename U> template <typename U,
struct AllowCast std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>> {};
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
SourceImpl<IsSide, U> Cast() { SourceImpl<IsSide, U> Cast() {
return SourceImpl<IsSide, U>(base_); return SourceImpl<IsSide, U>(base_);
} }
@ -186,12 +193,6 @@ class SourceImpl {
SourceBase* base_; SourceBase* base_;
}; };
template <bool IsSide, typename T>
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
public:
using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
};
// A source and a destination correspond to an output/input stream on a node, // A source and a destination correspond to an output/input stream on a node,
// and a side source and side destination correspond to an output/input side // and a side source and side destination correspond to an output/input side
// packet. // packet.
@ -201,20 +202,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
template <typename T = internal::Generic> template <typename T = internal::Generic>
using Source = SourceImpl<false, T>; using Source = SourceImpl<false, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSource = MultiSourceImpl<false, T>; using MultiSource = MultiPort<Source<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using SideSource = SourceImpl<true, T>; using SideSource = SourceImpl<true, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSideSource = MultiSourceImpl<true, T>; using MultiSideSource = MultiPort<SideSource<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using Destination = DestinationImpl<false, T>; using Destination = DestinationImpl<false, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using SideDestination = DestinationImpl<true, T>; using SideDestination = DestinationImpl<true, T>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiDestination = MultiDestinationImpl<false, T>; using MultiDestination = MultiPort<Destination<T>>;
template <typename T = internal::Generic> template <typename T = internal::Generic>
using MultiSideDestination = MultiDestinationImpl<true, T>; using MultiSideDestination = MultiPort<SideDestination<T>>;
class NodeBase { class NodeBase {
public: public:
@ -439,8 +440,9 @@ class Graph {
// Creates a node of a specific type. Should be used for pure interfaces, // Creates a node of a specific type. Should be used for pure interfaces,
// which do not have a built-in type string. // which do not have a built-in type string.
template <class Calc> template <class Calc>
Node<Calc>& AddNode(const std::string& type) { Node<Calc>& AddNode(absl::string_view type) {
auto node = std::make_unique<Node<Calc>>(type); auto node =
std::make_unique<Node<Calc>>(std::string(type.data(), type.size()));
auto node_p = node.get(); auto node_p = node.get();
nodes_.emplace_back(std::move(node)); nodes_.emplace_back(std::move(node));
return *node_p; return *node_p;
@ -448,16 +450,18 @@ class Graph {
// Creates a generic node, with no compile-time checking of inputs and // Creates a generic node, with no compile-time checking of inputs and
// outputs. This can be used for calculators whose contract is not visible. // outputs. This can be used for calculators whose contract is not visible.
GenericNode& AddNode(const std::string& type) { GenericNode& AddNode(absl::string_view type) {
auto node = std::make_unique<GenericNode>(type); auto node =
std::make_unique<GenericNode>(std::string(type.data(), type.size()));
auto node_p = node.get(); auto node_p = node.get();
nodes_.emplace_back(std::move(node)); nodes_.emplace_back(std::move(node));
return *node_p; return *node_p;
} }
// For legacy PacketGenerators. // For legacy PacketGenerators.
PacketGenerator& AddPacketGenerator(const std::string& type) { PacketGenerator& AddPacketGenerator(absl::string_view type) {
auto node = std::make_unique<PacketGenerator>(type); auto node = std::make_unique<PacketGenerator>(
std::string(type.data(), type.size()));
auto node_p = node.get(); auto node_p = node.get();
packet_gens_.emplace_back(std::move(node)); packet_gens_.emplace_back(std::move(node));
return *node_p; return *node_p;
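For reference, a minimal sketch of the builder API with the new `absl::string_view` overloads; the calculator name and stream tags are placeholders, not part of this change:

```
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"

// Hedged sketch: AddNode now accepts absl::string_view, so string literals and
// other string-like types no longer force a std::string at the call site.
mediapipe::CalculatorGraphConfig BuildExampleGraph() {
  mediapipe::api2::builder::Graph graph;
  auto& node = graph.AddNode("SomeCalculator");  // placeholder calculator name
  graph.In("INPUT") >> node.In("INPUT");         // placeholder stream tags
  node.Out("OUTPUT") >> graph.Out("OUTPUT");
  return graph.GetConfig();
}
```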

View File

@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>(); node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output"); any_type_output.SetName("any_type_output");
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
CalculatorGraphConfig expected = CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node { node {
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
output_stream: "ANY_OUTPUT:any_type_output" output_stream: "ANY_OUTPUT:any_type_output"
} }
input_stream: "GRAPH_ANY_INPUT:__stream_0" input_stream: "GRAPH_ANY_INPUT:__stream_0"
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
)pb"); )pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
} }

View File

@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract. // Functions for checking that the calculator has the required GetContract.
template <class T> template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
typedef absl::Status (*GetContractType)(CalculatorContract * cc); typedef absl::Status (*GetContractType)(CalculatorContract* cc);
return std::is_same<decltype(&T::GetContract), GetContractType>::value; return std::is_same<decltype(&T::GetContract), GetContractType>::value;
} }
template <class T> template <class T>

View File

@ -133,7 +133,12 @@ message GraphTrace {
TPU_TASK = 13; TPU_TASK = 13;
GPU_CALIBRATION = 14; GPU_CALIBRATION = 14;
PACKET_QUEUED = 15; PACKET_QUEUED = 15;
GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17;
} }
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
// )
// The timing for one packet set being processed at one calculator node. // The timing for one packet set being processed at one calculator node.
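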
message CalculatorTrace { message CalculatorTrace {

View File

@ -334,13 +334,6 @@ mediapipe_register_type(
deps = [":landmark_cc_proto"], deps = [":landmark_cc_proto"],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)
cc_library( cc_library(
name = "image", name = "image",
srcs = ["image.cc"], srcs = ["image.cc"],
@ -469,6 +462,10 @@ cc_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}), }),
defines = select({
"//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"],
"//conditions:default": [],
}),
linkopts = select({ linkopts = select({
"//mediapipe:ios": [ "//mediapipe:ios": [
"-framework CoreVideo", "-framework CoreVideo",

View File

@ -33,10 +33,3 @@ mediapipe_proto_library(
srcs = ["rasterization.proto"], srcs = ["rasterization.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -37,6 +37,10 @@ message Classification {
// Group of Classification protos. // Group of Classification protos.
message ClassificationList { message ClassificationList {
repeated Classification classification = 1; repeated Classification classification = 1;
// Optional index of the tensor that produced these classifications.
optional int32 tensor_index = 2;
// Optional name of the tensor that produced these classifications.
optional string tensor_name = 3;
} }
// Group of ClassificationList protos. // Group of ClassificationList protos.
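For reference, a minimal sketch of filling the new fields from C++; the classification values, tensor index, and tensor name are illustrative only:

```
// Hedged sketch: a ClassificationList that records which tensor produced it.
mediapipe::ClassificationList MakeExampleClassificationList() {
  mediapipe::ClassificationList list;
  mediapipe::Classification* c = list.add_classification();
  c->set_index(0);
  c->set_score(0.95f);
  c->set_label("example_label");
  list.set_tensor_index(2);                 // index of the producing tensor
  list.set_tensor_name("classification_scores");
  return list;
}
```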

View File

@ -31,11 +31,12 @@
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
#import <Metal/Metal.h> #import <Metal/Metal.h>
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#ifndef MEDIAPIPE_NO_JNI
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) #if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#define MEDIAPIPE_TENSOR_USE_AHWB 1 #define MEDIAPIPE_TENSOR_USE_AHWB 1
#endif // __ANDROID_API__ >= 26 || #endif // __ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__) // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#endif // MEDIAPIPE_NO_JNI
#ifdef MEDIAPIPE_TENSOR_USE_AHWB #ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <android/hardware_buffer.h> #include <android/hardware_buffer.h>
@ -43,7 +44,6 @@
#include "third_party/GL/gl/include/EGL/egl.h" #include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h" #include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB #endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_context.h"
@ -97,8 +97,8 @@ class Tensor {
kUInt8, kUInt8,
kInt8, kInt8,
kInt32, kInt32,
// TODO: Update the inference runner to handle kTfLiteString. kChar,
kChar kBool
}; };
struct Shape { struct Shape {
Shape() = default; Shape() = default;
@ -330,6 +330,8 @@ class Tensor {
return sizeof(int32_t); return sizeof(int32_t);
case ElementType::kChar: case ElementType::kChar:
return sizeof(char); return sizeof(char);
case ElementType::kBool:
return sizeof(bool);
} }
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) { if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method. // EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync. // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
glFinish(); gl_context_->Run([]() { glFinish(); });
} else if (valid_ & kValidAHardwareBuffer) { } else if (valid_ & kValidAHardwareBuffer) {
CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the " CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
"completion function to be set"; "completion function to be set";

View File

@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
} }
TEST(Cpu, TestMemoryAllocation) { TEST(Cpu, TestMemoryAllocation) {

View File

@ -109,6 +109,11 @@ struct TraceEvent {
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK; static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION; static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED; static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/calculator_profile.proto:event_type,
// )
}; };
// Packet trace log buffer. // Packet trace log buffer.

View File

@ -64,7 +64,7 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
std::vector<TraceEventType> basic_types = { std::vector<TraceEventType> basic_types = {
{TraceEvent::UNKNOWN, "An uninitialized trace-event."}, {TraceEvent::UNKNOWN, "An uninitialized trace-event."},
{TraceEvent::OPEN, "A call to Calculator::Open.", true, true}, {TraceEvent::OPEN, "A call to Calculator::Open.", true, true},
{TraceEvent::PROCESS, "A call to Calculator::Open.", true, true}, {TraceEvent::PROCESS, "A call to Calculator::Process.", true, true},
{TraceEvent::CLOSE, "A call to Calculator::Close.", true, true}, {TraceEvent::CLOSE, "A call to Calculator::Close.", true, true},
{TraceEvent::NOT_READY, "A calculator cannot process packets yet."}, {TraceEvent::NOT_READY, "A calculator cannot process packets yet."},

View File

@ -150,7 +150,7 @@ cc_library(
name = "executor_util", name = "executor_util",
srcs = ["executor_util.cc"], srcs = ["executor_util.cc"],
hdrs = ["executor_util.h"], hdrs = ["executor_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:mediapipe_options_cc_proto", "//mediapipe/framework:mediapipe_options_cc_proto",

View File

@ -378,8 +378,11 @@ cc_library(
], ],
}), }),
deps = [ deps = [
":gl_texture_buffer",
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage", ":gpu_buffer_storage",
":image_frame_view",
"//mediapipe/framework/formats:image_frame",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
], ],
) )
@ -1050,7 +1053,7 @@ objc_library(
alwayslink = 1, alwayslink = 1,
) )
MIN_IOS_VERSION = "9.0" # For thread_local. MIN_IOS_VERSION = "11.0"
test_suite( test_suite(
name = "ios", name = "ios",

View File

@ -111,7 +111,8 @@ typedef CVOpenGLESTextureCacheRef CVTextureCacheType;
- (CVMetalTextureCacheRef)mtlTextureCache { - (CVMetalTextureCacheRef)mtlTextureCache {
@synchronized(self) { @synchronized(self) {
if (!_mtlTextureCache) { if (!_mtlTextureCache) {
CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache); CVReturn __unused err =
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err); NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err);
// TODO: register and flush metal caches too. // TODO: register and flush metal caches too.
} }

View File

@ -47,6 +47,8 @@ static void EglThreadExitCallback(void* key_value) {
// implementations, and should be considered as an undocumented vendor // implementations, and should be considered as an undocumented vendor
// extension. // extension.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
//
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE, eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT); EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif #endif

View File

@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) {
16, 16,
0}; 0};
pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs]; pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1];
} }
if (!pixel_format_) { if (!pixel_format_) {
// On several Forge machines, the default config fails. For now let's do // On several Forge machines, the default config fails. For now let's do

View File

@ -144,14 +144,23 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
context](std::shared_ptr<GlSyncPoint> sync_token) { context](std::shared_ptr<GlSyncPoint> sync_token) {
CHECK_NE(name_, 0); CHECK_NE(name_, 0);
GLuint name_to_delete = name_; GLuint name_to_delete = name_;
context->RunWithoutWaiting([name_to_delete, sync_token]() { context->RunWithoutWaiting([name_to_delete]() {
if (sync_token) { // Note that we do not wait for consumers to be done before deleting the
// TODO: maybe we do not actually have to wait for the // texture. Based on a reading of the GLES 3.0 spec, appendix D:
// consumer sync here. Check docs. // - when a texture is deleted, it is _not_ automatically unbound from
sync_token->WaitOnGpu(); // bind points in other contexts;
} else { // - when a texture is deleted, its name becomes immediately invalid, but
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback"; // the actual object is not deleted until it is no longer in use, i.e.
} // attached to a container object or bound to a context;
// - deleting an object is not an operation that changes its contents;
// - within each context, commands are executed sequentially, so it seems
// like an unbind that follows a command that reads a texture should not
// take effect until the GPU has actually finished executing the
// previous commands.
// The final point is the least explicit in the docs, but it is implied by
// normal single-context behavior. E.g. if you do bind, delete, render,
// unbind, the object is not deleted until the unbind, and it waits for
// the render to finish.
DLOG_IF(ERROR, !glIsTexture(name_to_delete)) DLOG_IF(ERROR, !glIsTexture(name_to_delete))
<< "Deleting invalid texture id: " << name_to_delete; << "Deleting invalid texture id: " << name_to_delete;
glDeleteTextures(1, &name_to_delete); glDeleteTextures(1, &name_to_delete);
@ -185,7 +194,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
<< "Updated existing texture which had not been marked for reuse!"; << "Updated existing texture which had not been marked for reuse!";
CHECK(prod_token); CHECK(prod_token);
producer_sync_ = std::move(prod_token); producer_sync_ = std::move(prod_token);
producer_context_ = producer_sync_->GetContext(); const auto& synced_context = producer_sync_->GetContext();
if (synced_context) {
producer_context_ = synced_context;
}
} }
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const { void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {

View File

@ -65,6 +65,7 @@ class GlTextureView {
friend class GpuBuffer; friend class GpuBuffer;
friend class GlTextureBuffer; friend class GlTextureBuffer;
friend class GpuBufferStorageCvPixelBuffer; friend class GpuBufferStorageCvPixelBuffer;
friend class GpuBufferStorageAhwb;
GlTextureView(GlContext* context, GLenum target, GLuint name, int width, GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane, int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
DetachFn detach, DoneWritingFn done_writing) DetachFn detach, DoneWritingFn done_writing)

View File

@ -18,6 +18,7 @@
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h" #include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h" #include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
#include "mediapipe/gpu/gpu_test_base.h" #include "mediapipe/gpu/gpu_test_base.h"
#include "stb_image.h" #include "stb_image.h"

View File

@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor; import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor; import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image; import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.ImageProperties; import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
// TODO: use Preconditions in this file. // TODO: use Preconditions in this file.
@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
} }
/** /**
* Creates an Image packet from an {@link Image}. * Creates a MediaPipe Image packet from a {@link MPImage}.
* *
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
*/ */
public Packet createImage(Image image) { public Packet createImage(MPImage image) {
// TODO: Choose the best storage from multiple containers. // TODO: Choose the best storage from multiple containers.
ImageProperties properties = image.getContainedImageProperties().get(0); MPImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
ByteBuffer buffer = ByteBufferExtractor.extract(image); ByteBuffer buffer = ByteBufferExtractor.extract(image);
int numChannels = 0; int numChannels = 0;
switch (properties.getImageFormat()) { switch (properties.getImageFormat()) {
case Image.IMAGE_FORMAT_RGBA: case MPImage.IMAGE_FORMAT_RGBA:
numChannels = 4; numChannels = 4;
break; break;
case Image.IMAGE_FORMAT_RGB: case MPImage.IMAGE_FORMAT_RGB:
numChannels = 3; numChannels = 3;
break; break;
case Image.IMAGE_FORMAT_ALPHA: case MPImage.IMAGE_FORMAT_ALPHA:
numChannels = 1; numChannels = 1;
break; break;
default: // fall out default: // fall out
@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
int height = image.getHeight(); int height = image.getHeight();
return createImage(buffer, width, height, numChannels); return createImage(buffer, width, height, numChannels);
} }
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image); Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");

View File

@ -30,3 +30,10 @@ android_library(
"@maven//:com_google_guava_guava", "@maven//:com_google_guava_guava",
], ],
) )
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap; import android.graphics.Bitmap;
/** /**
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
* {@link IllegalArgumentException} will be thrown. * {@link IllegalArgumentException} will be thrown.
*/ */
public final class BitmapExtractor { public final class BitmapExtractor {
/** /**
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}. * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
* *
* @param image the image to extract {@link android.graphics.Bitmap} from. * @param image the image to extract {@link android.graphics.Bitmap} from.
* @return the {@link android.graphics.Bitmap} stored in {@link Image} * @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type * @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions. * conversions.
*/ */
public static Bitmap extract(Image image) { public static Bitmap extract(MPImage image) {
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
if (imageContainer != null) { if (imageContainer != null) {
return ((BitmapImageContainer) imageContainer).getBitmap(); return ((BitmapImageContainer) imageContainer).getBitmap();
} else { } else {
// TODO: Support ByteBuffer -> Bitmap conversion. // TODO: Support ByteBuffer -> Bitmap conversion.
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting Bitmap from an Image created by objects other than Bitmap is not" "Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
+ " supported"); + " supported");
} }
} }

View File

@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException; import java.io.IOException;
/** /**
* Builds {@link Image} from {@link android.graphics.Bitmap}. * Builds {@link MPImage} from {@link android.graphics.Bitmap}.
* *
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
@ -49,7 +49,7 @@ public class BitmapImageBuilder {
} }
/** /**
* Creates the builder to build {@link Image} from a file. * Creates the builder to build {@link MPImage} from a file.
* *
* @param context the application context. * @param context the application context.
* @param uri the path to the resource file. * @param uri the path to the resource file.
@ -58,15 +58,15 @@ public class BitmapImageBuilder {
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
BitmapImageBuilder setTimestamp(long timestamp) { BitmapImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image( return new MPImage(
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
} }
} }

View File

@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.graphics.Bitmap; import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
class BitmapImageContainer implements ImageContainer { class BitmapImageContainer implements MPImageContainer {
private final Bitmap bitmap; private final Bitmap bitmap;
private final ImageProperties properties; private final MPImageProperties properties;
public BitmapImageContainer(Bitmap bitmap) { public BitmapImageContainer(Bitmap bitmap) {
this.bitmap = bitmap; this.bitmap = bitmap;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setImageFormat(convertFormatCode(bitmap.getConfig())) .setImageFormat(convertFormatCode(bitmap.getConfig()))
.setStorageType(Image.STORAGE_TYPE_BITMAP) .setStorageType(MPImage.STORAGE_TYPE_BITMAP)
.build(); .build();
} }
@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
bitmap.recycle(); bitmap.recycle();
} }
@ImageFormat @MPImageFormat
static int convertFormatCode(Bitmap.Config config) { static int convertFormatCode(Bitmap.Config config) {
switch (config) { switch (config) {
case ALPHA_8: case ALPHA_8:
return Image.IMAGE_FORMAT_ALPHA; return MPImage.IMAGE_FORMAT_ALPHA;
case ARGB_8888: case ARGB_8888:
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
default: default:
return Image.IMAGE_FORMAT_UNKNOWN; return MPImage.IMAGE_FORMAT_UNKNOWN;
} }
} }
} }

View File

@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION; import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.Locale; import java.util.Locale;
/** /**
* Utility for extracting {@link ByteBuffer} from {@link Image}. * Utility for extracting {@link ByteBuffer} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
* {@link IllegalArgumentException} will be thrown. * otherwise {@link IllegalArgumentException} will be thrown.
*/ */
public class ByteBufferExtractor { public class ByteBufferExtractor {
/** /**
* Extracts a {@link ByteBuffer} from an {@link Image}. * Extracts a {@link ByteBuffer} from a {@link MPImage}.
* *
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
* *
* @see Image#getContainedImageProperties() * @see MPImage#getContainedImageProperties()
* @return A read-only {@link ByteBuffer}. * @return A read-only {@link ByteBuffer}.
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/ */
@SuppressLint("SwitchIntDef") @SuppressLint("SwitchIntDef")
public static ByteBuffer extract(Image image) { public static ByteBuffer extract(MPImage image) {
ImageContainer container = image.getContainer(); MPImageContainer container = image.getContainer();
switch (container.getImageProperties().getStorageType()) { switch (container.getImageProperties().getStorageType()) {
case Image.STORAGE_TYPE_BYTEBUFFER: case MPImage.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
default: default:
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
+ " supported"); + " supported");
} }
} }
/** /**
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
* *
* <p>Format conversion spec: * <p>Format conversion spec:
* *
@ -70,26 +70,26 @@ public class ByteBufferExtractor {
* *
* @param image the image to extract buffer from. * @param image the image to extract buffer from.
* @param targetFormat the image format of the result bytebuffer. * @param targetFormat the image format of the result bytebuffer.
* @return the readonly {@link ByteBuffer} stored in {@link Image} * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type * @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions. * conversions.
*/ */
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
ImageContainer container; MPImageContainer container;
ImageProperties byteBufferProperties = MPImageProperties byteBufferProperties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(targetFormat) .setImageFormat(targetFormat)
.build(); .build();
if ((container = image.getContainer(byteBufferProperties)) != null) { if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer(); .asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer = ByteBuffer byteBuffer =
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
@ -98,85 +98,89 @@ public class ByteBufferExtractor {
return byteBuffer; return byteBuffer;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by objects other than Bitmap or" "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
+ " Bytebuffer is not supported"); + " Bytebuffer is not supported");
} }
} }
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
@AutoValue @AutoValue
abstract static class Result { abstract static class Result {
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ /**
* Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
public abstract ByteBuffer buffer(); public abstract ByteBuffer buffer();
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ /**
@ImageFormat * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
@MPImageFormat
public abstract int format(); public abstract int format();
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
} }
} }
/** /**
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
* *
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
* *
* @return the readonly {@link ByteBuffer} stored in {@link Image} * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
* given {@code imageFormat} * given {@code imageFormat}
*/ */
static Result extractInRecommendedFormat(Image image) { static Result extractInRecommendedFormat(MPImage image) {
ImageContainer container; MPImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
@ImageFormat int format = adviseImageFormat(bitmap); @MPImageFormat int format = adviseImageFormat(bitmap);
Result result = Result result =
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
boolean unused = boolean unused =
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
return result; return result;
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return Result.create( return Result.create(
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
byteBufferImageContainer.getImageFormat()); byteBufferImageContainer.getImageFormat());
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
+ " is not supported"); + " is not supported");
} }
} }
@ImageFormat @MPImageFormat
private static int adviseImageFormat(Bitmap bitmap) { private static int adviseImageFormat(Bitmap bitmap) {
if (bitmap.getConfig() == Config.ARGB_8888) { if (bitmap.getConfig() == Config.ARGB_8888) {
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
} else { } else {
throw new IllegalArgumentException( throw new IllegalArgumentException(
String.format( String.format(
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
+ " supported", + " supported",
bitmap.getConfig())); bitmap.getConfig()));
} }
} }
private static ByteBuffer extractByteBufferFromBitmap( private static ByteBuffer extractByteBufferFromBitmap(
Bitmap bitmap, @ImageFormat int imageFormat) { Bitmap bitmap, @MPImageFormat int imageFormat) {
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
+ " supported"); + " supported");
} }
if (bitmap.getConfig() == Config.ARGB_8888) { if (bitmap.getConfig() == Config.ARGB_8888) {
if (imageFormat == Image.IMAGE_FORMAT_RGBA) { if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer); bitmap.copyPixelsToBuffer(buffer);
buffer.rewind(); buffer.rewind();
return buffer; return buffer;
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) { } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster. // TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
int w = bitmap.getWidth(); int w = bitmap.getWidth();
int h = bitmap.getHeight(); int h = bitmap.getHeight();
@ -196,14 +200,14 @@ public class ByteBufferExtractor {
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
String.format( String.format(
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
+ " %d is not supported", + " %d is not supported",
bitmap.getConfig(), imageFormat)); bitmap.getConfig(), imageFormat));
} }
private static ByteBuffer convertByteBuffer( private static ByteBuffer convertByteBuffer(
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
// array reversely to convert in-place. // array reversely to convert in-place.
@ -221,7 +225,8 @@ public class ByteBufferExtractor {
target.put(array, 0, target.capacity()); target.put(array, 0, target.capacity());
target.rewind(); target.rewind();
return target; return target;
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
&& targetFormat == MPImage.IMAGE_FORMAT_RGB) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
// array to convert in-place. // array to convert in-place.

View File

@ -15,11 +15,11 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
/** /**
* Builds a {@link Image} from a {@link ByteBuffer}. * Builds a {@link MPImage} from a {@link ByteBuffer}.
* *
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
private final ByteBuffer buffer; private final ByteBuffer buffer;
private final int width; private final int width;
private final int height; private final int height;
@ImageFormat private final int imageFormat; @MPImageFormat private final int imageFormat;
// Optional fields. // Optional fields.
private long timestamp; private long timestamp;
@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
* @param imageFormat how the data encode the image. * @param imageFormat how the data encode the image.
*/ */
public ByteBufferImageBuilder( public ByteBufferImageBuilder(
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
this.buffer = byteBuffer; this.buffer = byteBuffer;
this.width = width; this.width = width;
this.height = height; this.height = height;
@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
this.timestamp = 0; this.timestamp = 0;
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
ByteBufferImageBuilder setTimestamp(long timestamp) { ByteBufferImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
} }
} }
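For reference, a minimal sketch of the renamed Java API end to end; the dimensions, format, and buffer contents are placeholders:

```
// Hedged sketch: build an MPImage from a ByteBuffer, then extract a read-only
// view of the same storage. Assumes a direct RGBA buffer of matching size.
ByteBuffer pixels = ByteBuffer.allocateDirect(16 * 16 * 4);
MPImage image =
    new ByteBufferImageBuilder(pixels, /* width= */ 16, /* height= */ 16,
            MPImage.IMAGE_FORMAT_RGBA)
        .build();
ByteBuffer view = ByteBufferExtractor.extract(image);
image.close();  // releases the single reference taken at construction
```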

View File

@ -15,21 +15,19 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer { class ByteBufferImageContainer implements MPImageContainer {
private final ByteBuffer buffer; private final ByteBuffer buffer;
private final ImageProperties properties; private final MPImageProperties properties;
public ByteBufferImageContainer( public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
ByteBuffer buffer,
@ImageFormat int imageFormat) {
this.buffer = buffer; this.buffer = buffer;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(imageFormat) .setImageFormat(imageFormat)
.build(); .build();
} }
@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
/** /** Returns the image format. */
* Returns the image format. @MPImageFormat
*/
@ImageFormat
public int getImageFormat() { public int getImageFormat() {
return properties.getImageFormat(); return properties.getImageFormat();
} }

View File

@ -29,10 +29,10 @@ import java.util.Map.Entry;
/** /**
* The wrapper class for image objects. * The wrapper class for image objects.
* *
* <p>{@link Image} is designed to be an immutable image container, which could be shared * <p>{@link MPImage} is designed to be an immutable image container, which could be shared
* cross-platforms. * cross-platforms.
* *
* <p>To construct an {@link Image}, use the provided builders: * <p>To construct a {@link MPImage}, use the provided builders:
* *
* <ul> * <ul>
* <li>{@link ByteBufferImageBuilder} * <li>{@link ByteBufferImageBuilder}
@ -40,7 +40,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageBuilder} * <li>{@link MediaImageBuilder}
* </ul> * </ul>
* *
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the * <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release * reference count is 1. Developer can call {@link #close()} to reduce reference count to release
* internal storage earlier, otherwise Java garbage collection will release the storage eventually. * internal storage earlier, otherwise Java garbage collection will release the storage eventually.
* *
@ -53,7 +53,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageExtractor} * <li>{@link MediaImageExtractor}
* </ul> * </ul>
*/ */
public class Image implements Closeable { public class MPImage implements Closeable {
/** Specifies the image format of an image. */ /** Specifies the image format of an image. */
@IntDef({ @IntDef({
@ -69,7 +69,7 @@ public class Image implements Closeable {
IMAGE_FORMAT_JPEG, IMAGE_FORMAT_JPEG,
}) })
@Retention(RetentionPolicy.SOURCE) @Retention(RetentionPolicy.SOURCE)
public @interface ImageFormat {} public @interface MPImageFormat {}
public static final int IMAGE_FORMAT_UNKNOWN = 0; public static final int IMAGE_FORMAT_UNKNOWN = 0;
public static final int IMAGE_FORMAT_RGBA = 1; public static final int IMAGE_FORMAT_RGBA = 1;
@ -98,14 +98,14 @@ public class Image implements Closeable {
public static final int STORAGE_TYPE_IMAGE_PROXY = 4; public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
/** /**
* Returns a list of supported image properties for this {@link Image}. * Returns a list of supported image properties for this {@link MPImage}.
* *
* <p>Currently {@link Image} only support single storage type so the size of return list will * <p>Currently {@link MPImage} only support single storage type so the size of return list will
* always be 1. * always be 1.
* *
* @see ImageProperties * @see MPImageProperties
*/ */
public List<ImageProperties> getContainedImageProperties() { public List<MPImageProperties> getContainedImageProperties() {
return Collections.singletonList(getContainer().getImageProperties()); return Collections.singletonList(getContainer().getImageProperties());
} }
@ -124,7 +124,7 @@ public class Image implements Closeable {
return height; return height;
} }
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
private synchronized void acquire() { private synchronized void acquire() {
referenceCount += 1; referenceCount += 1;
} }
@ -132,7 +132,7 @@ public class Image implements Closeable {
/** /**
* Removes a reference that was previously acquired or init. * Removes a reference that was previously acquired or init.
* *
* <p>When {@link Image} is created, it has 1 reference count. * <p>When {@link MPImage} is created, it has 1 reference count.
* *
* <p>When the reference count becomes 0, it will release the resource under the hood. * <p>When the reference count becomes 0, it will release the resource under the hood.
*/ */
@ -141,24 +141,24 @@ public class Image implements Closeable {
public synchronized void close() { public synchronized void close() {
referenceCount -= 1; referenceCount -= 1;
if (referenceCount == 0) { if (referenceCount == 0) {
for (ImageContainer imageContainer : containerMap.values()) { for (MPImageContainer imageContainer : containerMap.values()) {
imageContainer.close(); imageContainer.close();
} }
} }
} }
/** Advanced API access for {@link Image}. */ /** Advanced API access for {@link MPImage}. */
static final class Internal { static final class Internal {
/** /**
* Acquires a reference on this {@link Image}. This will increase the reference count by 1. * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
* *
* <p>This method is mostly useful for an image consumer to acquire a reference so the image resource * <p>This method is mostly useful for an image consumer to acquire a reference so the image resource
* will not be closed accidentally. As the image creator, a normal developer doesn't need to call this * will not be closed accidentally. As the image creator, a normal developer doesn't need to call this
* method. * method.
* *
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link * <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link Image} anymore. * #close()} to indicate it doesn't need this {@link MPImage} anymore.
* *
* @see #close() * @see #close()
*/ */
@ -166,10 +166,10 @@ public class Image implements Closeable {
image.acquire(); image.acquire();
} }
private final Image image; private final MPImage image;
// Only Image creates the internal helper. // Only MPImage creates the internal helper.
private Internal(Image image) { private Internal(MPImage image) {
this.image = image; this.image = image;
} }
} }
@ -179,15 +179,15 @@ public class Image implements Closeable {
return new Internal(this); return new Internal(this);
} }
private final Map<ImageProperties, ImageContainer> containerMap; private final Map<MPImageProperties, MPImageContainer> containerMap;
private final long timestamp; private final long timestamp;
private final int width; private final int width;
private final int height; private final int height;
private int referenceCount; private int referenceCount;
/** Constructs an {@link Image} with a built container. */ /** Constructs a {@link MPImage} with a built container. */
Image(ImageContainer container, long timestamp, int width, int height) { MPImage(MPImageContainer container, long timestamp, int width, int height) {
this.containerMap = new HashMap<>(); this.containerMap = new HashMap<>();
containerMap.put(container.getImageProperties(), container); containerMap.put(container.getImageProperties(), container);
this.timestamp = timestamp; this.timestamp = timestamp;
@ -201,10 +201,10 @@ public class Image implements Closeable {
* *
* @return the current container. * @return the current container.
*/ */
ImageContainer getContainer() { MPImageContainer getContainer() {
// According to the design, in the future we will support multiple containers in one image. // According to the design, in the future we will support multiple containers in one image.
// Currently just return the original container. // Currently just return the original container.
// TODO: Cache multiple containers in Image. // TODO: Cache multiple containers in MPImage.
return containerMap.values().iterator().next(); return containerMap.values().iterator().next();
} }
@ -214,8 +214,8 @@ public class Image implements Closeable {
* <p>If there are multiple containers with required {@code storageType}, returns the first one. * <p>If there are multiple containers with required {@code storageType}, returns the first one.
*/ */
@Nullable @Nullable
ImageContainer getContainer(@StorageType int storageType) { MPImageContainer getContainer(@StorageType int storageType) {
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
if (entry.getKey().getStorageType() == storageType) { if (entry.getKey().getStorageType() == storageType) {
return entry.getValue(); return entry.getValue();
} }
@ -225,13 +225,13 @@ public class Image implements Closeable {
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
@Nullable @Nullable
ImageContainer getContainer(ImageProperties imageProperties) { MPImageContainer getContainer(MPImageProperties imageProperties) {
return containerMap.get(imageProperties); return containerMap.get(imageProperties);
} }
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
boolean addContainer(ImageContainer container) { boolean addContainer(MPImageContainer container) {
ImageProperties imageProperties = container.getImageProperties(); MPImageProperties imageProperties = container.getImageProperties();
if (containerMap.containsKey(imageProperties)) { if (containerMap.containsKey(imageProperties)) {
return false; return false;
} }

View File

@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that can receive {@link Image} */ /** Lightweight abstraction for an object that can receive {@link MPImage} */
public interface ImageConsumer { public interface MPImageConsumer {
/** /**
* Called when an {@link Image} is available. * Called when a {@link MPImage} is available.
* *
* <p>The argument is only guaranteed to be available until this method returns. If you need to * <p>The argument is only guaranteed to be available until this method returns. If you need to
* extend its lifetime, acquire it, then release it when done. * extend its lifetime, acquire it, then release it when done.
*/ */
void onNewImage(Image image); void onNewMPImage(MPImage image);
} }

View File

@ -16,9 +16,9 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Manages internal image data storage. The interface is package-private. */ /** Manages internal image data storage. The interface is package-private. */
interface ImageContainer { interface MPImageContainer {
/** Returns the properties of the contained image. */ /** Returns the properties of the contained image. */
ImageProperties getImageProperties(); MPImageProperties getImageProperties();
/** Close the image container and releases the image resource inside. */ /** Close the image container and releases the image resource inside. */
void close(); void close();

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that produce {@link Image} */ /** Lightweight abstraction for an object that produce {@link MPImage} */
public interface ImageProducer { public interface MPImageProducer {
/** Sets the consumer that receives the {@link Image}. */ /** Sets the consumer that receives the {@link MPImage}. */
void setImageConsumer(ImageConsumer imageConsumer); void setMPImageConsumer(MPImageConsumer imageConsumer);
} }

View File

@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;
import com.google.auto.value.AutoValue; import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized; import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType; import com.google.mediapipe.framework.image.MPImage.StorageType;
/** Groups a set of properties to describe how an image is stored. */ /** Groups a set of properties to describe how an image is stored. */
@AutoValue @AutoValue
public abstract class ImageProperties { public abstract class MPImageProperties {
/** /**
* Gets the pixel format of the image. * Gets the pixel format of the image.
* *
* @see Image.ImageFormat * @see MPImage.MPImageFormat
*/ */
@ImageFormat @MPImageFormat
public abstract int getImageFormat(); public abstract int getImageFormat();
/** /**
* Gets the storage type of the image. * Gets the storage type of the image.
* *
* @see Image.StorageType * @see MPImage.StorageType
*/ */
@StorageType @StorageType
public abstract int getStorageType(); public abstract int getStorageType();
@ -45,36 +45,36 @@ public abstract class ImageProperties {
public abstract int hashCode(); public abstract int hashCode();
/** /**
* Creates a builder of {@link ImageProperties}. * Creates a builder of {@link MPImageProperties}.
* *
* @see ImageProperties.Builder * @see MPImageProperties.Builder
*/ */
static Builder builder() { static Builder builder() {
return new AutoValue_ImageProperties.Builder(); return new AutoValue_MPImageProperties.Builder();
} }
/** Builds a {@link ImageProperties}. */ /** Builds a {@link MPImageProperties}. */
@AutoValue.Builder @AutoValue.Builder
abstract static class Builder { abstract static class Builder {
/** /**
* Sets the {@link Image.ImageFormat}. * Sets the {@link MPImage.MPImageFormat}.
* *
* @see ImageProperties#getImageFormat * @see MPImageProperties#getImageFormat
*/ */
abstract Builder setImageFormat(@ImageFormat int value); abstract Builder setImageFormat(@MPImageFormat int value);
/** /**
* Sets the {@link Image.StorageType}. * Sets the {@link MPImage.StorageType}.
* *
* @see ImageProperties#getStorageType * @see MPImageProperties#getStorageType
*/ */
abstract Builder setStorageType(@StorageType int value); abstract Builder setStorageType(@StorageType int value);
/** Builds the {@link ImageProperties}. */ /** Builds the {@link MPImageProperties}. */
abstract ImageProperties build(); abstract MPImageProperties build();
} }
// Hide the constructor. // Hide the constructor.
ImageProperties() {} MPImageProperties() {}
} }

View File

@ -15,11 +15,12 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
/** /**
* Builds {@link Image} from {@link android.media.Image}. * Builds {@link MPImage} from {@link android.media.Image}.
* *
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
* content in it. * content in it.
@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
public class MediaImageBuilder { public class MediaImageBuilder {
// Mandatory fields. // Mandatory fields.
private final android.media.Image mediaImage; private final Image mediaImage;
// Optional fields. // Optional fields.
private long timestamp; private long timestamp;
@ -40,20 +41,20 @@ public class MediaImageBuilder {
* *
* @param mediaImage image data object. * @param mediaImage image data object.
*/ */
public MediaImageBuilder(android.media.Image mediaImage) { public MediaImageBuilder(Image mediaImage) {
this.mediaImage = mediaImage; this.mediaImage = mediaImage;
this.timestamp = 0; this.timestamp = 0;
} }
/** Sets value for {@link Image#getTimestamp()}. */ /** Sets value for {@link MPImage#getTimestamp()}. */
MediaImageBuilder setTimestamp(long timestamp) { MediaImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp; this.timestamp = timestamp;
return this; return this;
} }
/** Builds an {@link Image} instance. */ /** Builds a {@link MPImage} instance. */
public Image build() { public MPImage build() {
return new Image( return new MPImage(
new MediaImageContainer(mediaImage), new MediaImageContainer(mediaImage),
timestamp, timestamp,
mediaImage.getWidth(), mediaImage.getWidth(),

View File

@ -15,33 +15,34 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build; import android.os.Build;
import android.os.Build.VERSION; import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat; import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
@RequiresApi(VERSION_CODES.KITKAT) @RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer { class MediaImageContainer implements MPImageContainer {
private final android.media.Image mediaImage; private final Image mediaImage;
private final ImageProperties properties; private final MPImageProperties properties;
public MediaImageContainer(android.media.Image mediaImage) { public MediaImageContainer(Image mediaImage) {
this.mediaImage = mediaImage; this.mediaImage = mediaImage;
this.properties = this.properties =
ImageProperties.builder() MPImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
.setImageFormat(convertFormatCode(mediaImage.getFormat())) .setImageFormat(convertFormatCode(mediaImage.getFormat()))
.build(); .build();
} }
public android.media.Image getImage() { public Image getImage() {
return mediaImage; return mediaImage;
} }
@Override @Override
public ImageProperties getImageProperties() { public MPImageProperties getImageProperties() {
return properties; return properties;
} }
@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
mediaImage.close(); mediaImage.close();
} }
@ImageFormat @MPImageFormat
static int convertFormatCode(int graphicsFormat) { static int convertFormatCode(int graphicsFormat) {
// We only cover the format mentioned in // We only cover the format mentioned in
// https://developer.android.com/reference/android/media/Image#getFormat() // https://developer.android.com/reference/android/media/Image#getFormat()
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
return Image.IMAGE_FORMAT_RGBA; return MPImage.IMAGE_FORMAT_RGBA;
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
return Image.IMAGE_FORMAT_RGB; return MPImage.IMAGE_FORMAT_RGB;
} }
} }
switch (graphicsFormat) { switch (graphicsFormat) {
case android.graphics.ImageFormat.JPEG: case android.graphics.ImageFormat.JPEG:
return Image.IMAGE_FORMAT_JPEG; return MPImage.IMAGE_FORMAT_JPEG;
case android.graphics.ImageFormat.YUV_420_888: case android.graphics.ImageFormat.YUV_420_888:
return Image.IMAGE_FORMAT_YUV_420_888; return MPImage.IMAGE_FORMAT_YUV_420_888;
default: default:
return Image.IMAGE_FORMAT_UNKNOWN; return MPImage.IMAGE_FORMAT_UNKNOWN;
} }
} }
} }

View File

@ -15,13 +15,14 @@ limitations under the License.
package com.google.mediapipe.framework.image; package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES; import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi; import androidx.annotation.RequiresApi;
/** /**
* Utility for extracting {@link android.media.Image} from {@link Image}. * Utility for extracting {@link android.media.Image} from {@link MPImage}.
* *
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
* otherwise {@link IllegalArgumentException} will be thrown. * otherwise {@link IllegalArgumentException} will be thrown.
*/ */
@RequiresApi(VERSION_CODES.KITKAT) @RequiresApi(VERSION_CODES.KITKAT)
@ -30,20 +31,20 @@ public class MediaImageExtractor {
private MediaImageExtractor() {} private MediaImageExtractor() {}
/** /**
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
* {@link Image} that built from {@link MediaImageBuilder}. * {@link MPImage} that built from {@link MediaImageBuilder}.
* *
* @param image the image to extract {@link android.media.Image} from. * @param image the image to extract {@link android.media.Image} from.
* @return {@link android.media.Image} that stored in {@link Image}. * @return {@link android.media.Image} that stored in {@link MPImage}.
* @throws IllegalArgumentException if the extraction failed. * @throws IllegalArgumentException if the extraction failed.
*/ */
public static android.media.Image extract(Image image) { public static Image extract(MPImage image) {
ImageContainer container; MPImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
return ((MediaImageContainer) container).getImage(); return ((MediaImageContainer) container).getImage();
} }
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Extract Media Image from an Image created by objects other than Media Image" "Extract Media Image from a MPImage created by objects other than Media Image"
+ " is not supported"); + " is not supported");
} }
} }

View File

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors. # Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -89,10 +89,6 @@ def mediapipe_aar(
calculators = calculators, calculators = calculators,
) )
_mediapipe_proto(
name = name + "_proto",
)
native.genrule( native.genrule(
name = name + "_aar_manifest_generator", name = name + "_aar_manifest_generator",
outs = ["AndroidManifest.xml"], outs = ["AndroidManifest.xml"],
@ -115,19 +111,10 @@ EOF
"//mediapipe/java/com/google/mediapipe/components:java_src", "//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src", "//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src", "//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java", ] + mediapipe_java_proto_srcs() +
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
] +
select({ select({
"//conditions:default": [], "//conditions:default": [],
"enable_stats_logging": [ "enable_stats_logging": mediapipe_logging_java_proto_srcs(),
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
],
}), }),
manifest = "AndroidManifest.xml", manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"], proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@ -177,93 +164,9 @@ EOF
assets_dir = assets_dir, assets_dir = assets_dir,
) )
_aar_with_jni(name, name + "_android_lib") mediapipe_build_aar_with_jni(
name = name,
def _mediapipe_proto(name): android_library = name + "_android_lib",
"""Generates MediaPipe java proto libraries.
Args:
name: the name of the target.
"""
_proto_java_src_generator(
name = "mediapipe_log_extension_proto",
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "mediapipe_logging_enums_proto",
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "calculator_proto",
proto_src = "mediapipe/framework/calculator.proto",
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
srcs = ["//mediapipe/framework:protos_src"],
)
_proto_java_src_generator(
name = "landmark_proto",
proto_src = "mediapipe/framework/formats/landmark.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
srcs = ["//mediapipe/framework/formats:protos_src"],
)
_proto_java_src_generator(
name = "rasterization_proto",
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
)
_proto_java_src_generator(
name = "location_data_proto",
proto_src = "mediapipe/framework/formats/location_data.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "detection_proto",
proto_src = "mediapipe/framework/formats/detection.proto",
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "classification_proto",
proto_src = "mediapipe/framework/formats/classification.proto",
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
],
)
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
native.genrule(
name = name + "_proto_java_src_generator",
srcs = srcs + [
"@com_google_protobuf//:lite_well_known_protos",
],
outs = [java_lite_out],
cmd = "$(location @com_google_protobuf//:protoc) " +
"--proto_path=. --proto_path=$(GENDIR) " +
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
tools = [
"@com_google_protobuf//:protoc",
],
) )
def _mediapipe_jni(name, gen_libmediapipe, calculators = []): def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
@ -303,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
alwayslink = 1, alwayslink = 1,
) )
def _aar_with_jni(name, android_library): def mediapipe_build_aar_with_jni(name, android_library):
"""Builds MediaPipe AAR with jni.
Args:
name: The bazel target name.
android_library: the android library that contains jni.
"""
# Generates dummy AndroidManifest.xml for dummy apk usage # Generates dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app target below) # (dummy apk is generated by <name>_dummy_app target below)
native.genrule( native.genrule(
@ -314,7 +224,7 @@ cat > $(OUTS) <<EOF
<manifest <manifest
xmlns:android="http://schemas.android.com/apk/res/android" xmlns:android="http://schemas.android.com/apk/res/android"
package="dummy.package.for.so"> package="dummy.package.for.so">
<uses-sdk android:minSdkVersion="21"/> <uses-sdk android:minSdkVersion="24"/>
</manifest> </manifest>
EOF EOF
""", """,
@ -341,7 +251,133 @@ chmod +w $(location :{}.aar)
origdir=$$PWD origdir=$$PWD
cd $$(mktemp -d) cd $$(mktemp -d)
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*" unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
find lib -name *_dummy_app.so -delete
cp -r lib jni cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name), """.format(android_library, name, name, name, name),
) )
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
"""Extracts the generated MediaPipe java proto source code from the target.
Args:
target: The java proto lite target to be built and extracted.
src_out: The output java proto src code path.
name: The optional bazel target name.
Returns:
The output java proto src code path.
"""
if not name:
name = target.split(":")[-1] + "_proto_java_src_extractor"
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
native.genrule(
name = name + "_proto_java_src_extractor",
srcs = [target],
outs = [src_out],
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
src_out + " $$(dirname $(location " + src_out + "))",
)
return src_out
def mediapipe_java_proto_srcs(name = ""):
"""Extracts the generated MediaPipe framework java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:stream_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StreamHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_factory_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketFactoryProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_generator_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:status_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StatusHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:mediapipe_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:detection_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
))
return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""):
"""Extracts the generated logging-related MediaPipe java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe logging-related java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
))
return proto_src_list
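As an aside, here is a minimal Starlark sketch (Python-like syntax) of how `mediapipe_java_proto_src_extractor` can be reused for an additional proto; the target label and output path below are hypothetical placeholders, not part of this change.

```
# Hypothetical: extract one generated Java proto source from its
# java_proto_lite target so it can be compiled into an android_library.
example_src = mediapipe_java_proto_src_extractor(
    target = "//your/project:example_java_proto_lite",
    src_out = "com/your/project/proto/ExampleProto.java",
)
# The returned path can then be appended to the srcs list that feeds the
# AAR's android_library, just as mediapipe_java_proto_srcs() is above.
```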

View File

@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,61 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "data_util",
srcs = ["data_util.py"],
)
py_test(
name = "data_util_test",
srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"],
deps = [":data_util"],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
srcs_version = "PY3",
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
deps = [
":dataset",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)
py_library(
name = "classification_dataset",
srcs = ["classification_dataset.py"],
deps = [":dataset"],
)
py_test(
name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"],
deps = [":classification_dataset"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
super().__init__(dataset, size)
self._index_by_label = index_by_label
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self._index_by_label)
@property
def index_by_label(self: ds._DatasetT) -> Any:
return self._index_by_label
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the data set into training and testing sets.
Args:
fraction: float, the fraction of the original data that goes into the first
returned sub-dataset.
Returns:
The two split sub-datasets.
"""
return self._split(fraction, self._index_by_label)
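A minimal usage sketch (the data, size, and label names are illustrative assumptions, not from this change): wrap an in-memory tf.data.Dataset and split it into train and validation subsets.

```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Four (feature, label) pairs over two classes; values are placeholders.
raw = tf.data.Dataset.from_tensor_slices(
    ([[0.1], [0.2], [0.3], [0.4]], [0, 1, 0, 1]))
data = classification_dataset.ClassificationDataset(
    dataset=raw, size=4, index_by_label=['negative', 'positive'])
train_data, validation_data = data.split(fraction=0.75)
print(len(train_data), train_data.num_classes)  # 3 2
```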

View File

@ -0,0 +1,82 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
_DatasetT = TypeVar(
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
class ClassificationDatasetTest(tf.test.TestCase):
def test_split(self):
class MagicClassificationDataset(
classification_dataset.ClassificationDataset):
"""A mock classification dataset class for testing purpose.
Attributes:
value: A value variable stored by the mock dataset class for testing.
"""
def __init__(self, dataset: tf.data.Dataset, size: int,
index_by_label: Any, value: Any):
super().__init__(
dataset=dataset, size=size, index_by_label=index_by_label)
self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_by_label, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_by_label = (False, True)
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset(
dataset=ds,
size=len(ds),
index_by_label=index_by_label,
value=magic_value)
# Train/Test data split.
fraction = .25
train_data, test_data = data.split(fraction=fraction)
# `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataset)
self.assertIsInstance(test_data, MagicClassificationDataset)
# Make sure the numbers of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
self.assertLen(train_data, fraction * len(ds))
self.assertLen(test_data, len(ds) - len(train_data))
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_by_label, index_by_label)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utility library."""
import cv2
import numpy as np
import tensorflow as tf
def load_image(path: str) -> np.ndarray:
"""Loads an image as an RGB numpy array.
Args:
path: input image file absolute path.
Returns:
An RGB image in numpy.ndarray.
"""
tf.compat.v1.logging.info('Loading RGB image %s', path)
# TODO Replace the OpenCV image load and conversion library by
# MediaPipe image utility library once it is ready.
image = cv2.imread(path)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

View File

@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import flags
import tensorflow as tf
from mediapipe.model_maker.python.core.data import data_util
_WORKSPACE = "mediapipe"
_TEST_DATA_DIR = os.path.join(
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
FLAGS = flags.FLAGS
class DataUtilTest(tf.test.TestCase):
def test_load_rgb_image(self):
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
image_data = data_util.load_image(image_path)
self.assertEqual(image_data.shape, (5184, 3456, 3))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common dataset for model training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from typing import Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
class Dataset(object):
"""A generic dataset class for loading model training and evaluation dataset.
For each ML task, such as image classification, text classification etc., a
subclass can be derived from this class to provide task-specific data loading
utilities.
"""
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
"""Initializes Dataset class.
To build a dataset from raw data, consider using the task-specific utilities,
e.g. from_folder().
Args:
tf_dataset: A tf.data.Dataset object that contains a potentially large set
of elements, where each element is a pair of (input_data, target). The
`input_data` means the raw input data, like an image, a text etc., while
the `target` means the ground truth of the raw input data, e.g. the
classification label of the image etc.
size: The size of the dataset. tf.data.Dataset doesn't support a function
to get the length directly since it's lazy-loaded and may be infinite.
"""
self._dataset = tf_dataset
self._size = size
@property
def size(self) -> Optional[int]:
"""Returns the size of the dataset.
Note that this function may return None because the exact size of the
dataset isn't a necessary parameter to create an instance of this class,
and tf.data.Dataset doesn't support a function to get the length directly
since it's lazy-loaded and may be infinite.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset will be preprocessed,
and this function can return an int representing the size of the dataset.
"""
return self._size
def gen_tf_dataset(self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., bool]] = None,
drop_remainder: bool = False) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation.
Args:
batch_size: An integer, the returned dataset will be batched by this size.
is_training: A boolean, when True, the returned dataset will be optionally
shuffled and repeated as an endless dataset.
shuffle: A boolean, when True, the returned dataset will be shuffled to
create randomness during model training.
preprocess: A function taking three arguments in order, feature, label and
boolean is_training.
drop_remainder: boolean, whether to drop the final batch if it is smaller than batch_size.
Returns:
A TF dataset ready to be consumed by Keras model.
"""
dataset = self._dataset
if preprocess:
preprocess = functools.partial(preprocess, is_training=is_training)
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
if is_training:
if shuffle:
# Shuffle size should be bigger than the batch_size. Otherwise it's only
# shuffling within the batch, which is equivalent to not shuffling at all.
buffer_size = 3 * batch_size
# But since we are doing shuffle before repeat, it doesn't make sense to
# shuffle more than total available entries.
# TODO: Investigate if shuffling before / after repeat
# dataset can get a better performance?
# Shuffle after repeat will give a more randomized dataset and mix the
# epoch boundary: https://www.tensorflow.org/guide/data
if self._size:
buffer_size = min(self._size, buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# TODO: Consider converting dataset to distributed dataset
# here.
return dataset
def __len__(self):
"""Returns the number of element of the dataset."""
if self._size is not None:
return self._size
else:
return len(self._dataset)
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the data set into training and testing sets.
Args:
fraction: A float value that defines the fraction of the original data that
goes into the first returned sub-dataset.
Returns:
The two split sub-datasets.
"""
return self._split(fraction)
def _split(self: _DatasetT, fraction: float,
*args) -> Tuple[_DatasetT, _DatasetT]:
"""Implementation for `split` method and returns sub-class instances.
Child DataLoader classes, if requires additional constructor arguments,
should implement their own `split` method by calling `_split` with all
arguments to the constructor.
Args:
fraction: A float value that defines the fraction of the original data that
goes into the first returned sub-dataset.
*args: additional arguments passed to the sub-class constructor.
Returns:
The two split sub-datasets.
"""
assert (fraction > 0 and fraction < 1)
dataset = self._dataset
train_size = int(self._size * fraction)
trainset = self.__class__(dataset.take(train_size), train_size, *args)
test_size = self._size - train_size
testset = self.__class__(dataset.skip(train_size), test_size, *args)
return trainset, testset
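A short sketch of driving `gen_tf_dataset` with a `preprocess` callback, which the tests below do not exercise; the shapes, sizes, and normalization here are illustrative assumptions.

```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds

def normalize(feature, label, is_training):
  # Scale features to [0, 1]; is_training could gate augmentation steps.
  del is_training  # unused in this sketch
  return tf.cast(feature, tf.float32) / 255.0, label

raw = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([8, 4], maxval=255),
     tf.constant([0, 1, 0, 1, 0, 1, 0, 1])))
data = ds.Dataset(raw, size=8)
train_ds = data.gen_tf_dataset(
    batch_size=2, is_training=True, shuffle=True, preprocess=normalize)
for features, labels in train_ds:
  pass  # e.g. feed each batch to a Keras model
```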

View File

@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import test_util
class DatasetTest(tf.test.TestCase):
def test_split(self):
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, 4)
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 2)
self.assertIsInstance(train_data, ds.Dataset)
self.assertIsInstance(test_data, ds.Dataset)
for i, elem in enumerate(train_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
def test_len(self):
size = 4
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, size)
self.assertLen(data, size)
def test_gen_tf_dataset(self):
input_dim = 8
data = test_util.create_dataset(
data_size=2, input_shape=[input_dim], num_classes=2)
dataset = data.gen_tf_dataset()
self.assertLen(dataset, 2)
for (feature, label) in dataset:
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
dataset2 = data.gen_tf_dataset(batch_size=2)
self.assertLen(dataset2, 1)
for (feature, label) in dataset2:
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
self.assertEqual(dataset3.cardinality(), 1)
for (feature, label) in dataset3.take(10):
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
package(
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
mediapipe_files(srcs = ["test.jpg"])
filegroup(
name = "testdata",
srcs = ["test.jpg"],
)

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training models. Shared across tasks."""
import dataclasses
import tempfile
from typing import Optional
# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass
class BaseHParams:
"""Hyperparameters used for training models.
A common set of hyperparameters shared by the training jobs of all model
maker tasks.
Attributes:
learning_rate: The learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
steps_per_epoch: An optional integer indicating the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size.
shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to
use. Accepted values are 'off', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to
use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy
documentation for more details:
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
num_gpus: How many GPUs to use at each worker with the
DistributionStrategies API. The default is -1, which means utilize all
available GPUs.
tpu: The Cloud TPU to use for training. This should be either the name used
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
"""
# Parameters for train configuration
learning_rate: float
batch_size: int
epochs: int
steps_per_epoch: Optional[int] = None
# Dataset-related parameters
shuffle: bool = False
# Parameters for model / checkpoint files
export_dir: str = tempfile.mkdtemp()
# Parameters for hardware acceleration
distribution_strategy: str = 'off'
num_gpus: int = -1 # default value of -1 means use all available GPUs
tpu: str = ''
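A brief construction sketch; the import path and values are assumptions for illustration (only learning_rate, batch_size, and epochs are required fields).

```
# Assumed module location; adjust to wherever hyperparameters.py lives.
from mediapipe.model_maker.python.core import hyperparameters

hparams = hyperparameters.BaseHParams(
    learning_rate=1e-3,
    batch_size=32,
    epochs=10,
    shuffle=True,
    export_dir='/tmp/model_maker_export',  # placeholder path
)
print(hparams.distribution_strategy, hparams.num_gpus)  # off -1
```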

Some files were not shown because too many files have changed in this diff