Syntax 2022-10-24 22:52:17 +00:00 committed by GitHub
commit 0f4379cd64
419 changed files with 25985 additions and 2387 deletions

View File

@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==2.2.0
RUN pip3 install tensorflow
RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python

View File

@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```
## Graph Options
It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.
In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:
```
node {
calculator: "FlowLimiterCalculator"
input_stream: "image"
output_stream: "throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
max_in_flight: 1
}
}
}
node {
calculator: "FaceDetectionSubgraph"
input_stream: "IMAGE:throttled_image"
node_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {
tensor_width: 192
tensor_height: 192
}
}
}
```
In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:
```
graph_options: {
[type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}
node: {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:multi_backend_image"
node_options: {
[type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
keep_aspect_ratio: true
border_mode: BORDER_ZERO
}
}
option_value: "output_tensor_width:options/tensor_width"
option_value: "output_tensor_height:options/tensor_height"
}
node {
calculator: "InferenceCalculator"
node_options: {
[type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
}
option_value: "delegate:options/delegate"
option_value: "model_path:options/model_path"
}
```
In this example, the `FaceDetectionSubgraph` accepts graph option protobuf
`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field
values in the calculator options `ImageToTensorCalculatorOptions` and some field
values in the subgraph options `InferenceCalculatorOptions`. The field values
are defined using the `option_value:` syntax.
In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.
The syntax of `option_value:` is similar to the syntax of `input_stream:`. The
syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option
field and the RHS identifies a graph option field. More specifically, the LHS
and the RHS each consist of a series of protobuf field names, separated by '/',
that identify nested protobuf messages and fields. This is known as the
"ProtoPath" syntax. Nested messages referenced in the LHS or RHS must already
be defined in the enclosing protobuf in order to be traversed using
`option_value:`.
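For nested fields, the ProtoPath simply chains field names with '/'. Below is a
minimal sketch of this; `ExampleSubgraphOptions` and its nested
`detector_options` message are hypothetical names used only to illustrate the
syntax:
```
graph_options: {
  [type.googleapis.com/mediapipe.ExampleSubgraphOptions] {
    # The nested message is defined here so that it can be traversed by the
    # option_value: paths below.
    detector_options: { tensor_width: 192 tensor_height: 192 }
  }
}
node: {
  calculator: "ImageToTensorCalculator"
  input_stream: "IMAGE:image"
  # LHS names a field of ImageToTensorCalculatorOptions; RHS walks
  # ExampleSubgraphOptions.detector_options.tensor_width.
  option_value: "output_tensor_width:options/detector_options/tensor_width"
  option_value: "output_tensor_height:options/detector_options/tensor_height"
}
```
As in the earlier examples, the leading `options/` on the RHS refers to the
graph options of the enclosing graph.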
## Cycles
<!-- TODO: add discussion of PreviousLoopbackCalculator -->

View File

@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow
```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt
```
This will open up your webcam as long as it is connected and on. Any errors

View File

@ -1410,3 +1410,45 @@ cc_library(
],
alwayslink = 1,
)
mediapipe_proto_library(
name = "bypass_calculator_proto",
srcs = ["bypass_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bypass_calculator",
srcs = ["bypass_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":bypass_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_test(
name = "bypass_calculator_test",
srcs = ["bypass_calculator_test.cc"],
deps = [
":bypass_calculator",
":pass_through_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:switch_container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -0,0 +1,161 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/calculators/core/bypass_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
using mediapipe::BypassCalculatorOptions;
// Defines a "bypass" channel to use in place of a disabled feature subgraph.
// By default, all inputs are discarded and all outputs are ignored.
// Certain input streams can be passed to corresponding output streams
// by specifying them in "pass_input_stream" and "pass_output_stream" options.
// All output streams are updated with timestamp bounds indicating completed
// output.
//
// Note that this calculator is designed for use as a contained_node in a
// SwitchContainer. For this reason, any input and output tags are accepted,
// and stream semantics are specified through BypassCalculatorOptions.
//
// Example config:
// node {
// calculator: "BypassCalculator"
// input_stream: "APPEARANCES:appearances_post_facenet"
// input_stream: "VIDEO:video_frame"
// input_stream: "FEATURE_CONFIG:feature_config"
// input_stream: "ENABLE:gaze_enabled"
// output_stream: "APPEARANCES:analyzed_appearances"
// output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
// node_options: {
// [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
// pass_input_stream: "APPEARANCES"
// pass_output_stream: "APPEARANCES"
// }
// }
// }
//
class BypassCalculator : public Node {
public:
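// kNotNeeded is a placeholder optional input declared only so that
// MEDIAPIPE_NODE_CONTRACT has a port to reference; the calculator accepts
// arbitrary input and output tags, which are configured in UpdateContract()
// based on BypassCalculatorOptions.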
static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
using IdMap = std::map<CollectionItemId, CollectionItemId>;
// Returns the map of passthrough input and output stream ids.
static absl::StatusOr<IdMap> GetPassMap(
const BypassCalculatorOptions& options, const tool::TagMap& input_map,
const tool::TagMap& output_map) {
IdMap result;
auto& input_streams = options.pass_input_stream();
auto& output_streams = options.pass_output_stream();
int size = std::min(input_streams.size(), output_streams.size());
for (int i = 0; i < size; ++i) {
std::pair<std::string, int> in_tag, out_tag;
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
&in_tag.first, &in_tag.second));
MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
&out_tag.first, &out_tag.second));
auto input_id = input_map.GetId(in_tag.first, in_tag.second);
auto output_id = output_map.GetId(out_tag.first, out_tag.second);
result[input_id] = output_id;
}
return result;
}
// Identifies all specified streams as "Any" packet type.
// Identifies passthrough streams as "Same" packet type.
static absl::Status UpdateContract(CalculatorContract* cc) {
auto options = cc->Options<BypassCalculatorOptions>();
RET_CHECK_EQ(options.pass_input_stream().size(),
options.pass_output_stream().size());
ASSIGN_OR_RETURN(
auto pass_streams,
GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams) {
pass_out.insert(entry.second);
cc->Inputs().Get(entry.first).SetAny();
cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
}
for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
if (pass_streams.count(id) == 0) {
cc->Inputs().Get(id).SetAny();
}
}
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetAny();
}
}
return absl::OkStatus();
}
// Saves the map of passthrough input and output stream ids.
absl::Status Open(CalculatorContext* cc) override {
auto options = cc->Options<BypassCalculatorOptions>();
ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
*cc->Outputs().TagMap()));
return absl::OkStatus();
}
// Copies packets between passthrough input and output streams.
// Updates timestamp bounds on all output streams.
absl::Status Process(CalculatorContext* cc) override {
std::set<CollectionItemId> pass_out;
for (auto entry : pass_streams_) {
pass_out.insert(entry.second);
auto& packet = cc->Inputs().Get(entry.first).Value();
if (packet.Timestamp() == cc->InputTimestamp()) {
cc->Outputs().Get(entry.second).AddPacket(packet);
}
}
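// For every non-pass-through output, advance the timestamp bound past the
// current input timestamp so that downstream calculators are not left
// waiting for packets that will never arrive.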
Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
if (pass_out.count(id) == 0) {
cc->Outputs().Get(id).SetNextTimestampBound(
std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
}
}
return absl::OkStatus();
}
// Closes all output streams.
absl::Status Close(CalculatorContext* cc) override {
for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).Close();
}
return absl::OkStatus();
}
private:
IdMap pass_streams_;
};
MEDIAPIPE_REGISTER_NODE(BypassCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,31 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BypassCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BypassCalculatorOptions ext = 481259677;
}
// Names an input stream or streams to pass through, by "TAG:index".
repeated string pass_input_stream = 1;
// Names an output stream or streams to pass through, by "TAG:index".
repeated string pass_output_stream = 2;
}

View File

@ -0,0 +1,302 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// A graph using a BypassCalculator to pass through and ignore
// most of its inputs and outputs.
constexpr char kTestGraphConfig1[] = R"pb(
type: "AppearancesPassThroughSubgraph"
input_stream: "APPEARANCES:appearances"
input_stream: "VIDEO:video_frame"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "APPEARANCES:passthrough_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output"
node {
calculator: "BypassCalculator"
input_stream: "PASS:appearances"
input_stream: "TRUNCATE:0:video_frame"
input_stream: "TRUNCATE:1:feature_config"
output_stream: "PASS:passthrough_appearances"
output_stream: "TRUNCATE:passthrough_federated_gaze_output"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "PASS"
pass_output_stream: "PASS"
}
}
}
)pb";
// A graph using AppearancesPassThroughSubgraph as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig2[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "AppearancesPassThroughSubgraph" }
}
}
}
)pb";
// A graph using BypassCalculator as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig3[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "ENABLE:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: {
calculator: "BypassCalculator"
node_options: {
[type.googleapis.com/mediapipe.BypassCalculatorOptions] {
pass_input_stream: "APPEARANCES"
pass_output_stream: "APPEARANCES"
}
}
}
}
}
}
)pb";
// A graph using BypassCalculator as a disabled gate
// for input frames and appearances.
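// The SwitchContainer routes its inputs to the first contained node
// (BypassCalculator) when ENABLE is false and to the second
// (PassThroughCalculator) when ENABLE is true.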
constexpr char kTestGraphConfig4[] = R"pb(
input_stream: "VIDEO_FULL_RES:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
input_stream: "GAZE_ENABLED:gaze_enabled"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
node {
calculator: "SwitchContainer"
input_stream: "ENABLE:gaze_enabled"
input_stream: "VIDEO:video_frame"
input_stream: "APPEARANCES:input_appearances"
input_stream: "FEATURE_CONFIG:feature_config"
output_stream: "VIDEO:video_frame_out"
output_stream: "APPEARANCES:analyzed_appearances"
output_stream: "FEATURE_CONFIG:feature_config_out"
options {
[mediapipe.SwitchContainerOptions.ext] {
contained_node: { calculator: "BypassCalculator" }
contained_node: { calculator: "PassThroughCalculator" }
}
}
}
)pb";
// Reports packet timestamp and string contents, or "<empty>".
std::string DebugString(Packet p) {
return absl::StrCat(p.Timestamp().DebugString(), ":",
p.IsEmpty() ? "<empty>" : p.Get<std::string>());
}
// Shows a bypass subgraph that passes through one stream.
TEST(BypassCalculatorTest, SubgraphChannel) {
CalculatorGraphConfig config_1 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1);
CalculatorGraphConfig config_2 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that passes through one stream.
TEST(BypassCalculatorTest, CalculatorChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> federated_gaze_output;
MP_ASSERT_OK(graph.ObserveOutputStream(
"federated_gaze_output",
[&](const Packet& p) {
federated_gaze_output.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows a BypassCalculator that discards all inputs when ENABLE is false.
TEST(BypassCalculatorTest, GatedChannel) {
CalculatorGraphConfig config_3 =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize({config_3}, {}));
std::vector<std::string> analyzed_appearances;
MP_ASSERT_OK(graph.ObserveOutputStream(
"analyzed_appearances",
[&](const Packet& p) {
analyzed_appearances.push_back(DebugString(p));
return absl::OkStatus();
},
true));
std::vector<std::string> video_frame;
MP_ASSERT_OK(graph.ObserveOutputStream(
"video_frame_out",
[&](const Packet& p) {
video_frame.push_back(DebugString(p));
return absl::OkStatus();
},
true));
MP_ASSERT_OK(graph.StartRun({}));
// Close the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(false).At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 200.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Only timestamps arrive from the BypassCalculator.
EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:<empty>"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>"));
// Open the gate.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"gaze_enabled", MakePacket<bool>(true).At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Send packets at timestamp 300.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_appearances", MakePacket<std::string>("a2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"video_frame", MakePacket<std::string>("v2").At(Timestamp(300))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"feature_config", MakePacket<std::string>("f2").At(Timestamp(300))));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Packets arrive from the PassThroughCalculator.
EXPECT_THAT(analyzed_appearances,
testing::ElementsAre("200:<empty>", "300:a2"));
EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>", "300:v2"));
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace
} // namespace mediapipe

View File

@ -209,11 +209,18 @@ cc_library(
alwayslink = 1,
)
mediapipe_proto_library(
name = "rotation_mode_proto",
srcs = ["rotation_mode.proto"],
visibility = ["//visibility:public"],
)
mediapipe_proto_library(
name = "image_transformation_calculator_proto",
srcs = ["image_transformation_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
":rotation_mode_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/gpu:scale_mode_proto",
@ -238,6 +245,7 @@ cc_library(
}),
visibility = ["//visibility:public"],
deps = [
":rotation_mode_cc_proto",
":image_transformation_calculator_cc_proto",
"//mediapipe/framework:packet",
"//mediapipe/framework:timestamp",

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"

View File

@ -16,20 +16,10 @@ syntax = "proto2";
package mediapipe;
import "mediapipe/calculators/image/rotation_mode.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}
message ImageTransformationCalculatorOptions {
extend CalculatorOptions {
optional ImageTransformationCalculatorOptions ext = 251952830;

View File

@ -0,0 +1,31 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "RotationModeProto";
// Counterclockwise rotation.
message RotationMode {
enum Mode {
UNKNOWN = 0;
ROTATION_0 = 1;
ROTATION_90 = 2;
ROTATION_180 = 3;
ROTATION_270 = 4;
}
}

View File

@ -161,6 +161,193 @@ cc_test(
],
)
mediapipe_proto_library(
name = "bert_preprocessor_calculator_proto",
srcs = ["bert_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "bert_preprocessor_calculator",
srcs = ["bert_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":bert_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "bert_preprocessor_calculator_test",
srcs = ["bert_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":bert_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
name = "regex_preprocessor_calculator_proto",
srcs = ["regex_preprocessor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "regex_preprocessor_calculator",
srcs = ["regex_preprocessor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":regex_preprocessor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
"//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "regex_preprocessor_calculator_test",
srcs = ["regex_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
linkopts = ["-ldl"],
deps = [
":regex_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:sink",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "text_to_tensor_calculator",
srcs = ["text_to_tensor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "text_to_tensor_calculator_test",
srcs = ["text_to_tensor_calculator_test.cc"],
deps = [
":text_to_tensor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_graph",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:options_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "universal_sentence_encoder_preprocessor_calculator",
srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "universal_sentence_encoder_preprocessor_calculator_test",
srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
deps = [
":universal_sentence_encoder_preprocessor_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:options_map",
"//mediapipe/tasks/cc/core:utils",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)
mediapipe_proto_library(
name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"],
@ -320,6 +507,8 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)

View File

@ -0,0 +1,251 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr absl::string_view kInputIdsTensorName = "ids";
constexpr absl::string_view kInputMasksTensorName = "mask";
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
constexpr absl::string_view kClassifierToken = "[CLS]";
constexpr absl::string_view kSeparatorToken = "[SEP]";
// Preprocesses input text into three int32 input tensors for a BERT model using
// a tokenizer.
// The associated BERT model is expected to contain input tensors with names:
//
// Tensor | Metadata Name
// ---------------- | --------------
// IDs | "ids"
// Segment IDs | "segment_ids"
// Mask | "mask"
//
// This calculator returns an error if the model does not have exactly three
// input tensors, or if those tensors do not carry the metadata names listed
// above (in any order). Additional details regarding these input tensors are
// given in the Calculator "Outputs" section below.
//
// This calculator is currently configured for the TextClassifier Task but it
// will eventually be generalized for other Text Tasks.
// TODO: Handle preprocessing for other Text Tasks too.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the BERT model. Used to determine the order of
// the three input Tensors for the BERT model and to extract the metadata to
// construct the tokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the BERT model:
// (1): the token ids of the tokenized input string. A classifier token
// ("[CLS]") will be prepended to the input tokens and a separator
// token ("[SEP]") will be appended to the input tokens.
// (2): the segment ids, which are all 0 for now but will have different
// values to distinguish between different sentences in the input
// text for other Text tasks.
// (3): the input mask ids, which are 1 at each of the input token indices
// and 0 elsewhere.
// The Tensors will have size equal to the max sequence length for the BERT
// model.
//
// Example:
// node {
// calculator: "BertPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.BertPreprocessorCalculatorOptions.ext] {
// bert_max_seq_len: 128
// }
// }
// }
class BertPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
// The max sequence length accepted by the BERT model.
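// Defaults to 2, the smallest value accepted by UpdateContract() (room for
// just the [CLS] and [SEP] tokens); overwritten from the calculator options
// in Open().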
int bert_max_seq_len_ = 2;
// Indices of the three input tensors for the BERT model. They should form the
// set {0, 1, 2}.
int input_ids_tensor_index_ = 0;
int segment_ids_tensor_index_ = 1;
int input_masks_tensor_index_ = 2;
// Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
// This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
// clips the vector of tokens to have length at most `bert_max_seq_len_`.
std::vector<std::string> TokenizeInputText(absl::string_view input_text);
// Processes the `input_tokens` to generate the three input tensors for the
// BERT model.
std::vector<Tensor> GenerateInputTensors(
const std::vector<std::string>& input_tokens);
};
absl::Status BertPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
RET_CHECK_GE(options.bert_max_seq_len(), 2)
<< "bert_max_seq_len must be at least 2";
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::ProcessUnit* tokenizer_metadata =
metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
tokenizer_metadata, metadata_extractor));
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
input_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputIdsTensorName);
segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kSegmentIdsTensorName);
input_masks_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kInputMasksTensorName);
absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
segment_ids_tensor_index_,
input_masks_tensor_index_};
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
input_ids_tensor_index_, segment_ids_tensor_index_,
input_masks_tensor_index_));
}
const auto& options =
cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
bert_max_seq_len_ = options.bert_max_seq_len();
return absl::OkStatus();
}
absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
kTensorsOut(cc).Send(
GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
return absl::OkStatus();
}
std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
absl::string_view input_text) {
std::string processed_input = std::string(input_text);
absl::AsciiStrToLower(&processed_input);
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(processed_input);
// Offset by 2 to account for [CLS] and [SEP]
int input_tokens_size =
std::min(bert_max_seq_len_,
static_cast<int>(tokenizer_result.subwords.size()) + 2);
std::vector<std::string> input_tokens;
input_tokens.reserve(input_tokens_size);
input_tokens.push_back(std::string(kClassifierToken));
for (int i = 0; i < input_tokens_size - 2; ++i) {
input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
}
input_tokens.push_back(std::string(kSeparatorToken));
return input_tokens;
}
std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
const std::vector<std::string>& input_tokens) {
std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
// Convert tokens back into ids and set mask
for (int i = 0; i < input_tokens.size(); ++i) {
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_masks[i] = 1;
}
// |<--------bert_max_seq_len_--------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForBert);
for (int i = 0; i < kNumInputTensorsForBert; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
}
std::memcpy(input_tensors[input_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_ids.data(), input_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[segment_ids_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
segment_ids.data(), segment_ids.size() * sizeof(int32_t));
std::memcpy(input_tensors[input_masks_tensor_index_]
.GetCpuWriteView()
.buffer<int32_t>(),
input_masks.data(), input_masks.size() * sizeof(int32_t));
return input_tensors;
}
MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message BertPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional BertPreprocessorCalculatorOptions ext = 462509271;
}
// The maximum input sequence length for the calculator's BERT model.
optional int32 bert_max_seq_len = 1;
}

View File

@ -0,0 +1,154 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForBert = 3;
constexpr int kBertMaxSeqLen = 128;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/bert_text_classifier.tflite";
absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
absl::string_view text, absl::string_view model_path) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "BertPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
options {
[mediapipe.BertPreprocessorCalculatorOptions.ext] {
bert_max_seq_len: $0
}
}
}
)",
kBertMaxSeqLen));
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer = tasks::core::LoadBinaryContent(model_path.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForBert) {
return absl::InvalidArgumentError(
absl::Substitute("tensor_vec has size $0, expected $1",
tensor_vec.size(), kNumInputTensorsForBert));
}
std::vector<std::vector<int>> results;
for (int i = 0; i < kNumInputTensorsForBert; i++) {
const Tensor& tensor = tensor_vec[i];
if (tensor.element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32");
}
auto* buffer = tensor.GetCpuReadView().buffer<int>();
std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen);
results.push_back(buffer_view);
}
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return results;
}
TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) {
std::vector<std::vector<int>> expected_result = {
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}};
// segment_ids
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
// input_masks
expected_result.push_back(std::vector(expected_result[0].size(), 1));
expected_result[2].resize(kBertMaxSeqLen);
// padding input_ids
expected_result[0].resize(kBertMaxSeqLen);
MP_ASSERT_OK_AND_ASSIGN(
std::vector<std::vector<int>> processed_tensor_values,
RunBertPreprocessorCalculator(
"it's a charming and often affecting journey", kTestModelPath));
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
TEST(BertPreprocessorCalculatorTest, LongInput) {
std::stringstream long_input;
long_input
<< "it's a charming and often affecting journey and this is a long";
for (int i = 0; i < kBertMaxSeqLen; ++i) {
long_input << " long";
}
long_input << " movie review";
std::vector<std::vector<int>> expected_result = {
{101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023,
2003, 1037}};
// "long" id
expected_result[0].resize(kBertMaxSeqLen - 1, 2146);
// "[SEP]" id
expected_result[0].push_back(102);
// segment_ids
expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
// input_masks
expected_result.push_back(std::vector(kBertMaxSeqLen, 1));
MP_ASSERT_OK_AND_ASSIGN(
std::vector<std::vector<int>> processed_tensor_values,
RunBertPreprocessorCalculator(long_input.str(), kTestModelPath));
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
} // namespace
} // namespace mediapipe

View File

@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
}
ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
const Size size{image->width(), image->height()};
RotatedRect roi = GetRoi(size.width, size.height, norm_rect);
RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
options_.output_tensor_height(),
options_.keep_aspect_ratio(), &roi));
@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
}
if (kOutMatrix(cc).IsConnected()) {
std::array<float, 16> matrix;
GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
/*flip_horizontaly=*/false,
&matrix);
GetRotatedSubRectToRectTransformMatrix(
roi, image->width(), image->height(),
/*flip_horizontaly=*/false, &matrix);
kOutMatrix(cc).Send(std::move(matrix));
}
// Lazy initialization of the GPU or CPU converter.
MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));
ASSIGN_OR_RETURN(Tensor tensor,
(image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, {output_width_, output_height_},
range_min_, range_max_));
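// Allocate the output tensor here in the calculator and let the converter
// populate it in place, instead of having the converter allocate and
// return a new tensor.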
Tensor::ElementType output_tensor_type =
GetOutputTensorType(image->UsesGpu());
Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
GetNumOutputChannels(*image)});
MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
->Convert(*image, roi, range_min_, range_max_,
/*tensor_buffer_offset=*/0, tensor));
auto result = std::make_unique<std::vector<Tensor>>();
result->push_back(std::move(tensor));
@ -292,15 +295,31 @@ class ImageToTensorCalculator : public Node {
}
}
Tensor::ElementType GetOutputTensorType() {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
if (!uses_gpu) {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
// Always use float32 when GPU is enabled.
return Tensor::ElementType::kFloat32;
}
int GetNumOutputChannels(const Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
if (image.UsesGpu()) {
return 4;
}
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
// All of the processors except for Metal expect 3 channels.
return 3;
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(
cpu_converter_,
CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
CreateOpenCvConverter(cc, GetBorderMode(),
GetOutputTensorType(/*uses_gpu=*/false)));
#else
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined.";

View File

@ -42,13 +42,16 @@ class ImageToTensorConverter {
// @image contains image to extract from.
// @roi describes region of interest within the image to extract (absolute
// values).
// @output_dims dimensions of output tensor.
// @range_min/max describe the output tensor range that image pixels should
// be converted to.
virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims,
float range_min, float range_max) = 0;
// @tensor_buffer_offset an integer offset into the tensor buffer at which
// the result should be written.
// @output_tensor a tensor with a pre-defined shape. "Convert" is responsible
// for populating the content of the output tensor.
virtual absl::Status Convert(const mediapipe::Image& input,
const RotatedRect& roi, float range_min,
float range_max, int tensor_buffer_offset,
Tensor& output_tensor) = 0;
};
} // namespace mediapipe

View File

@ -264,57 +264,58 @@ class GlProcessor : public ImageToTensorConverter {
});
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
"Unsupported format: ", static_cast<uint32_t>(input.format())));
}
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
constexpr int kNumChannels = 3;
Tensor tensor(Tensor::ElementType::kFloat32,
{1, output_dims.height, output_dims.width, kNumChannels});
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
[this, &output_tensor, &input, &roi, &output_shape, range_min,
range_max, tensor_buffer_offset]() -> absl::Status {
const int input_num_channels = input.channels();
auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture(
GL_TEXTURE_2D, source_texture.name(),
input_num_channels == 4 ? GL_RGB : GL_RGBA,
source_texture.width() * source_texture.height() *
input_num_channels * sizeof(uint8_t),
/*layer=*/0,
/*owned=*/false);
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi,
&output_dims, range_min,
range_max]() -> absl::Status {
constexpr int kRgbaNumChannels = 4;
auto source_texture = gl_helper_.CreateSourceTexture(input);
tflite::gpu::gl::GlTexture input_texture(
GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
source_texture.width() * source_texture.height() * kRgbaNumChannels *
sizeof(uint8_t),
/*layer=*/0,
/*owned=*/false);
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(auto transform,
GetValueRangeTransformation(kInputImageRangeMin,
kInputImageRangeMax,
range_min, range_max));
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(
auto transform,
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
range_min, range_max));
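// Bytes per batch element; the batch dimension was validated to be 1
// above, so this is the size of the entire output tensor.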
const int output_size = output_tensor.bytes() / output_shape.dims[0];
auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
buffer_view.name(), output_size,
/*offset=*/tensor_buffer_offset,
/*has_ownership=*/false);
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
input_texture,
tflite::gpu::HW(source_texture.height(), source_texture.width()),
roi,
/*flip_horizontaly=*/false, transform.scale, transform.offset,
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
command_queue_.get(), &output));
auto buffer_view = tensor.GetOpenGlBufferWriteView();
tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
buffer_view.name(), tensor.bytes(),
/*offset=*/0,
/*has_ownership=*/false);
MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
input_texture,
tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
/*flip_horizontaly=*/false, transform.scale, transform.offset,
tflite::gpu::HW(output_dims.height, output_dims.width),
command_queue_.get(), &output));
return absl::OkStatus();
}));
return absl::OkStatus();
}));
return tensor;
return absl::OkStatus();
}
~GlProcessor() override {
@ -326,6 +327,17 @@ class GlProcessor : public ImageToTensorConverter {
}
private:
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
std::unique_ptr<SubRectExtractorGl> extractor_;
mediapipe::GlCalculatorHelper gl_helper_;

View File

@ -168,26 +168,26 @@ class GlProcessor : public ImageToTensorConverter {
});
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
input.format() != mediapipe::GpuBufferFormat::kRGB24) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
"Unsupported format: ", static_cast<uint32_t>(input.format())));
}
// TODO: support tensor_buffer_offset > 0 scenario.
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
constexpr int kNumChannels = 3;
Tensor tensor(
Tensor::ElementType::kFloat32,
Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});
MP_RETURN_IF_ERROR(
gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
range_min, range_max]() -> absl::Status {
MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
[this, &output_tensor, &input, &roi, &output_shape, range_min,
range_max]() -> absl::Status {
auto input_texture = gl_helper_.CreateSourceTexture(input);
constexpr float kInputImageRangeMin = 0.0f;
@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
GetValueRangeTransformation(kInputImageRangeMin,
kInputImageRangeMax,
range_min, range_max));
auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
/*flip_horizontaly=*/false,
transform.scale, transform.offset,
output_dims, &tensor_view));
output_shape, &tensor_view));
return absl::OkStatus();
}));
return tensor;
return absl::OkStatus();
}
absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
const RotatedRect& sub_rect,
bool flip_horizontaly, float alpha, float beta,
const Size& output_dims,
const Tensor::Shape& output_shape,
Tensor::OpenGlTexture2dView* output) {
const int output_height = output_shape.dims[1];
const int output_width = output_shape.dims[2];
std::array<float, 16> transform_mat;
glDisable(GL_DEPTH_TEST);
glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
glViewport(0, 0, output_dims.width, output_dims.height);
glViewport(0, 0, output_width, output_height);
glActiveTexture(GL_TEXTURE0);
glBindTexture(GL_TEXTURE_2D, output->name());
@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
}
private:
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
mediapipe::GlCalculatorHelper gl_helper_;
bool use_custom_zero_border_ = false;
BorderMode border_mode_ = BorderMode::kReplicate;

View File

@ -262,7 +262,6 @@ class SubRectExtractorMetal {
RET_CHECK(pipeline_state != nil);
std::string output_type_def;
MTLPixelFormat pixel_format;
switch (output_format) {
case OutputFormat::kF16C4:
output_type_def = R"(
@ -348,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
return absl::OkStatus();
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
@ -359,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@autoreleasepool {
id<MTLTexture> texture =
[metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];
constexpr int kNumChannels = 4;
Tensor tensor(Tensor::ElementType::kFloat32,
Tensor::Shape{1, output_dims.height, output_dims.width,
kNumChannels});
constexpr float kInputImageRangeMin = 0.0f;
constexpr float kInputImageRangeMax = 1.0f;
ASSIGN_OR_RETURN(
@ -377,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
range_min, range_max));
id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
const auto& buffer_view =
output_tensor.GetMtlBufferWriteView(command_buffer);
MP_RETURN_IF_ERROR(extractor_->Execute(
texture, roi,
/*flip_horizontaly=*/false, transform.scale, transform.offset,
tflite::gpu::HW(output_dims.height, output_dims.width),
tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
command_buffer, buffer_view.buffer()));
[command_buffer commit];
return tensor;
return absl::OkStatus();
}
}
private:
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
RET_CHECK_EQ(output_shape.dims[3], 4)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
MPPMetalHelper* metal_helper_ = nil;
std::unique_ptr<SubRectExtractorMetal> extractor_;
};

View File

@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
}
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
float range_min, float range_max,
int tensor_buffer_offset,
Tensor& output_tensor) override {
if (input.image_format() != mediapipe::ImageFormat::SRGB &&
input.image_format() != mediapipe::ImageFormat::SRGBA) {
return InvalidArgumentError(
absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
static_cast<uint32_t>(input.image_format())));
}
auto src = mediapipe::formats::MatView(&input);
// TODO: Remove the check once tensor_buffer_offset > 0 is
// supported.
RET_CHECK_EQ(tensor_buffer_offset, 0)
<< "The non-zero tensor_buffer_offset input is not supported yet.";
const auto& output_shape = output_tensor.shape();
MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
constexpr int kNumChannels = 3;
Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
output_dims.width, kNumChannels});
auto buffer_view = tensor.GetCpuWriteView();
const int output_height = output_shape.dims[1];
const int output_width = output_shape.dims[2];
const int output_channels = output_shape.dims[3];
auto buffer_view = output_tensor.GetCpuWriteView();
cv::Mat dst;
switch (tensor_type_) {
case Tensor::ElementType::kInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
dst = cv::Mat(output_height, output_width, mat_type_,
buffer_view.buffer<int8>());
break;
case Tensor::ElementType::kFloat32:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
dst = cv::Mat(output_height, output_width, mat_type_,
buffer_view.buffer<float>());
break;
case Tensor::ElementType::kUInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
dst = cv::Mat(output_height, output_width, mat_type_,
buffer_view.buffer<uint8>());
break;
default:
@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
cv::Mat src_points;
cv::boxPoints(rotated_rect, src_points);
const float dst_width = output_dims.width;
const float dst_height = output_dims.height;
const float dst_width = output_width;
const float dst_height = output_height;
/* clang-format off */
float dst_corners[8] = {0.0f, dst_height,
0.0f, 0.0f,
@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
dst_width, dst_height};
/* clang-format on */
auto src = mediapipe::formats::MatView(&input);
cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
cv::Mat projection_matrix =
cv::getPerspectiveTransform(src_points, dst_points);
@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
/*flags=*/cv::INTER_LINEAR,
/*borderMode=*/border_mode_);
if (transformed.channels() > kNumChannels) {
if (transformed.channels() > output_channels) {
cv::Mat proper_channels_mat;
cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
transformed = proper_channels_mat;
@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
range_min, range_max));
transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
return tensor;
return absl::OkStatus();
}
private:
absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
RET_CHECK_EQ(output_shape.dims.size(), 4)
<< "Wrong output dims size: " << output_shape.dims.size();
RET_CHECK_EQ(output_shape.dims[0], 1)
<< "Handling batch dimension not equal to 1 is not implemented in this "
"converter.";
RET_CHECK_EQ(output_shape.dims[3], 3)
<< "Wrong output channel: " << output_shape.dims[3];
return absl::OkStatus();
}
enum cv::BorderTypes border_mode_;
Tensor::ElementType tensor_type_;
int mat_type_;

View File

@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe {
namespace api2 {
@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
std::vector<Tensor>& output_tensors) {
return gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status {
[this, cc, &input_tensors, &output_tensors]() -> absl::Status {
// Explicitly copy input.
for (int i = 0; i < input_tensors.size(); ++i) {
glBindBuffer(GL_COPY_READ_BUFFER,
@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
}
// Run inference.
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
{
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
output_tensors.reserve(output_size_);
for (int i = 0; i < output_size_; ++i) {

View File

@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h"
#endif // MEDIAPIPE_ANDROID
#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe
namespace mediapipe {
namespace api2 {
@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
const mediapipe::InferenceCalculatorOptions::Delegate& delegate);
absl::StatusOr<std::vector<Tensor>> Process(
const std::vector<Tensor>& input_tensors);
CalculatorContext* cc, const std::vector<Tensor>& input_tensors);
absl::Status Close();
@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(
absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
const std::vector<Tensor>& input_tensors) {
CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
std::vector<Tensor> output_tensors;
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status {
[this, cc, &input_tensors, &output_tensors]() -> absl::Status {
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].GetOpenGlBufferReadView().name(), i));
@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
output_tensors.back().GetOpenGlBufferWriteView().name(), i));
}
// Run inference.
return tflite_gpu_runner_->Invoke();
{
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
return tflite_gpu_runner_->Invoke();
}
}));
return output_tensors;
@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
ASSIGN_OR_RETURN(*output_tensors,
gpu_inference_runner_->Process(input_tensors));
gpu_inference_runner_->Process(cc, input_tensors));
kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus();

View File

@ -224,9 +224,6 @@ absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
void InferenceCalculatorMetalImpl::AddDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
const auto& calculator_opts =
cc->Options<mediapipe::InferenceCalculatorOptions>();
// Configure and create the delegate.
TFLGpuDelegateOptions options;
// `enable_quantization` enables the run of sparse models i.e. the models with

View File

@ -21,8 +21,10 @@
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe {
@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}
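// Specialization for string tensors: kTfLiteString inputs cannot be filled
// with a plain memcpy, so the contents of the MediaPipe kChar tensor are
// written into the interpreter's tensor through tflite::DynamicBuffer.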
template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
const char* input_tensor_buffer =
input_tensor.GetCpuReadView().buffer<char>();
tflite::DynamicBuffer dynamic_buffer;
dynamic_buffer.AddString(input_tensor_buffer,
input_tensor.shape().num_elements());
dynamic_buffer.WriteToTensorAsVector(
interpreter->tensor(interpreter->inputs()[input_tensor_index]));
}
template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
int output_tensor_index,
@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
break;
}
case TfLiteType::kTfLiteUInt8: {
CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
interpreter_.get(), i);
CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteInt8: {
CopyTensorBufferToInterpreter<int8>(input_tensors[i],
interpreter_.get(), i);
CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteInt32: {
@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteString: {
CopyTensorBufferToInterpreter<char>(input_tensors[i],
interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteBool:
// No current use-case for copying MediaPipe Tensors with bool type to
// TfLiteTensors.
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported input tensor type:", input_tensor_type));
@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteBool:
output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
Tensor::QuantizationParameters{1.0f, 0});
CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
&output_tensors.back());
break;
case TfLiteType::kTfLiteString:
// No current use-case for copying TfLiteTensors with string type to
// MediaPipe Tensors.
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported output tensor type:",

View File

@ -0,0 +1,174 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
// Preprocesses input text into one int32 input tensor for a text model using
// a RegexTokenizer.
//
// Inputs:
// TEXT - std::string
// The input text.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the text model. Used to extract the metadata
// to construct the RegexTokenizer.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor which is the text model's input tensor.
// Depending on the tokenizer metadata, the tensor may start with
// the id of the tokenizer's <START> token. The following tensor values will
// be the ids of the tokens of the input text. Any out-of-vocab tokens will
// have the id of the <UNKNOWN> token. The tensor will be padded with the
// <PAD> token id to have size equal to the max sequence length for the text
// model.
//
// Example:
// node {
// calculator: "RegexPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// options {
// [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
// max_seq_len: 256
// }
// }
// }
class RegexPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
// The max sequence length accepted by the text model.
int max_seq_len_ = 0;
};
absl::Status RegexPreprocessorCalculator::UpdateContract(
CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
const tflite::TensorMetadata* tensor_metadata =
metadata_extractor->GetInputTensorMetadata(0);
if (tensor_metadata == nullptr) {
return absl::InvalidArgumentError("No tensor metadata found");
}
ASSIGN_OR_RETURN(
const auto* tokenizer_metadata,
metadata_extractor->FindFirstProcessUnit(
*tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
if (tokenizer_metadata == nullptr) {
return absl::InvalidArgumentError("No tokenizer metadata found");
}
const tflite::RegexTokenizerOptions* regex_tokenizer_options =
tokenizer_metadata->options_as<tflite::RegexTokenizerOptions>();
ASSIGN_OR_RETURN(tokenizer_,
tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
regex_tokenizer_options, metadata_extractor));
const auto& options =
cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
max_seq_len_ = options.max_seq_len();
return absl::OkStatus();
}
absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
tasks::text::tokenizers::TokenizerResult tokenizer_result =
tokenizer_->Tokenize(kTextIn(cc).Get());
int unknown_token_id = 0;
tokenizer_->GetUnknownToken(&unknown_token_id);
int pad_token_id = 0;
tokenizer_->GetPadToken(&pad_token_id);
std::vector<int> input_tokens(max_seq_len_, pad_token_id);
int start_token_id = 0;
int input_token_index = 0;
if (tokenizer_->GetStartToken(&start_token_id)) {
input_tokens[0] = start_token_id;
input_token_index = 1;
}
for (int i = 0; (i < tokenizer_result.subwords.size()) &&
(input_token_index < max_seq_len_);
++i, ++input_token_index) {
const std::string& token = tokenizer_result.subwords[i];
int token_id = 0;
if (tokenizer_->LookupId(token, &token_id)) {
input_tokens[input_token_index] = token_id;
} else {
input_tokens[input_token_index] = unknown_token_id;
}
}
// |<-------sentence_length-------->|
// input_tensor <START>, t1, t2... <PAD>, <PAD>...
// <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's
// not found in the tokenizer vocab.
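// For example, with hypothetical vocab ids <START>=1, "good"=9, <UNKNOWN>=2
// and max_seq_len_ = 6, the input "good zzzz" (with "zzzz" out-of-vocab)
// yields input_tokens = {1, 9, 2, <PAD>, <PAD>, <PAD>}.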
std::vector<Tensor> result;
result.push_back(
{Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
input_tokens.data(), input_tokens.size() * sizeof(int32_t));
kTensorsOut(cc).Send(std::move(result));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message RegexPreprocessorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional RegexPreprocessorCalculatorOptions ext = 463716697;
}
// The maximum input sequence length for the calculator's text model.
optional int32 max_seq_len = 1;
}

View File

@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kMaxSeqLen = 256;
constexpr char kTestModelPath[] =
"mediapipe/tasks/testdata/text/"
"test_model_text_classifier_with_regex_tokenizer.tflite";
absl::StatusOr<std::vector<int>> RunRegexPreprocessorCalculator(
absl::string_view text) {
auto graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "RegexPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
options {
[mediapipe.RegexPreprocessorCalculatorOptions.ext] {
max_seq_len: $0
}
}
}
)pb",
kMaxSeqLen));
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath);
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(), 1));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) {
return absl::InvalidArgumentError("Expected tensor element type kInt32");
}
auto* buffer = tensor_vec[0].GetCpuReadView().buffer<int>();
std::vector<int> result(buffer, buffer + kMaxSeqLen);
MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
MP_RETURN_IF_ERROR(graph.WaitUntilDone());
return result;
}
TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) {
MP_ASSERT_OK_AND_ASSIGN(
std::vector<int> processed_tensor_values,
RunRegexPreprocessorCalculator("This is the best movie Ive seen in "
"recent years. Strongly recommend it!"));
static const int expected_result[kMaxSeqLen] = {
1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12};
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
TEST(RegexPreprocessorCalculatorTest, LongInput) {
std::stringstream long_input;
long_input << "This is the best";
for (int i = 0; i < kMaxSeqLen; ++i) {
long_input << " best";
}
long_input << "movie Ive seen in recent years. Strongly recommend it!";
MP_ASSERT_OK_AND_ASSIGN(std::vector<int> processed_tensor_values,
RunRegexPreprocessorCalculator(long_input.str()));
std::vector<int> expected_result = {1, 2, 9, 4, 118};
// "best" id
expected_result.resize(kMaxSeqLen, 118);
EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}
} // namespace
} // namespace mediapipe

View File

@ -296,7 +296,6 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
output_tensors->emplace_back(Tensor::ElementType::kFloat32,
Tensor::Shape{1, height, width, channels});
#if MEDIAPIPE_METAL_ENABLED
id<MTLDevice> device = gpu_helper_.mtlDevice;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TensorConverterCalculatorConvert";
id<MTLComputeCommandEncoder> compute_encoder =

View File

@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
case Tensor::ElementType::kInt8:
Dequantize<int8>(input_tensor, &output_tensors->back());
break;
case Tensor::ElementType::kBool:
Dequantize<bool>(input_tensor, &output_tensors->back());
break;
default:
return absl::InvalidArgumentError(absl::StrCat(
"Unsupported input tensor type: ", input_tensor.element_type()));

View File

@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
ValidateResult(GetOutput(), {-1.007874, 0, 1});
}
TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
std::vector<bool> tensor = {true, false, true};
PushTensor(Tensor::ElementType::kBool, tensor,
Tensor::QuantizationParameters{1.0f, 0});
MP_ASSERT_OK(runner_.Run());
ValidateResult(GetOutput(), {1, 0, 1});
}
} // namespace
} // namespace mediapipe

View File

@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
}
absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
const auto& input_tensors = *kInTensors(cc);
RET_CHECK_EQ(input_tensors.size(), 1);
RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);
@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
auto raw_scores = view.buffer<float>();
auto classification_list = absl::make_unique<ClassificationList>();
if (options.has_tensor_index()) {
classification_list->set_tensor_index(options.tensor_index());
}
if (options.has_tensor_name()) {
classification_list->set_tensor_name(options.tensor_name());
}
if (is_binary_classification_) {
Classification* class_first = classification_list->add_classification();
Classification* class_second = classification_list->add_classification();

View File

@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
// that are not in the `allow_classes` field will be completely ignored.
// `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 allow_classes = 8 [packed = true];
// The optional index of the tensor these classifications originate from.
optional int32 tensor_index = 10;
// The optional name of the tensor these classifications originate from.
optional string tensor_name = 11;
}

View File

@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
}
}
TEST_F(TensorsToClassificationCalculatorTest,
CorrectOutputWithTensorNameAndIndex) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
tensor_index: 1
tensor_name: "foo"
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(3, classification_list.classification_size());
// Verify that the tensor_index and tensor_name fields are correctly set.
EXPECT_EQ(classification_list.tensor_index(), 1);
EXPECT_EQ(classification_list.tensor_name(), "foo");
}
TEST_F(TensorsToClassificationCalculatorTest,
ClassNameAllowlistWithLabelItems) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(

View File

@ -532,7 +532,6 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
detection_classes.data(),
output_detections));
#elif MEDIAPIPE_METAL_ENABLED
id<MTLDevice> device = gpu_helper_.mtlDevice;
if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);

View File

@ -67,9 +67,7 @@ absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
"tensor_vec has size $0, expected 1", tensor_vec.size()));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor has element type $0, expected $1", tensor_vec[0].element_type(),
Tensor::ElementType::kChar));
return absl::InvalidArgumentError("Expected tensor element type kChar");
}
const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
return std::string(buffer, text.length());

View File

@ -0,0 +1,167 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <array>
#include <cstring>
#include <string>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
constexpr absl::string_view kQueryTextMetadataName = "inp_text";
constexpr absl::string_view kResponseContextMetadataName = "res_context";
constexpr absl::string_view kResponseTextMetadataName = "res_text";
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
// Preprocesses input text into three kTfLiteString input tensors for a
// Universal Sentence Encoder (USE) model.
//
// The associated USE model is expected to contain input tensors with metadata
// names:
//
// Tensor | Metadata Name
// ---------------- | ------------------
// Query text | "inp_text"
// Response context | "res_context"
// Response text | "res_text"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have metadata names corresponding to the
// above names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// Inputs:
// TEXT - std::string
// The text to be embedded.
// Side Inputs:
// METADATA_EXTRACTOR - ModelMetadataExtractor
// The metadata extractor for the USE model. Used to determine the order of
// the three input Tensors for the USE model.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing the three input Tensors for the USE model. The tensors
// fit a question-answering setting and store a query text, a response
// context, and a response text. This calculator will just be preprocessing
// a single input text that will be stored in the response text tensor. The
// query text and response context tensors will store empty strings.
//
// Example:
// node {
// calculator: "UniversalSentenceEncoderPreprocessorCalculator"
// input_stream: "TEXT:text"
// input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
// output_stream: "TENSORS:tensors"
// }
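//
// For instance, for the input text "hello world", TENSORS holds three kChar
// tensors whose contents are "", "" and "hello world", ordered according to
// the tensor indices found in the model metadata.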
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
public:
static constexpr Input<std::string> kTextIn{"TEXT"};
static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
"METADATA_EXTRACTOR"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
// Indices of the three input tensors for the USE model. They should form the
// set {0, 1, 2}.
int query_text_tensor_index_ = 0;
int response_context_tensor_index_ = 1;
int response_text_tensor_index_ = 2;
// Tensor shapes for the model's input tensors.
// The query text and response context tensors will only hold the empty
// string, so their tensors will have shape [0], but the Universal Sentence
// Encoder model's input signature requires them to be present. The response
// text tensor will store the embedding text and have shape
// [embedding_text_len].
std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
};
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
CalculatorContext* cc) {
const ModelMetadataExtractor* metadata_extractor =
&kMetadataExtractorSideIn(cc).Get();
auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
query_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kQueryTextMetadataName);
response_context_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseContextMetadataName);
response_text_tensor_index_ = FindTensorIndexByMetadataName(
input_tensors_metadata, kResponseTextMetadataName);
absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
{query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_});
if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
return absl::InvalidArgumentError(absl::Substitute(
"Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
query_text_tensor_index_, response_context_tensor_index_,
response_text_tensor_index_));
}
return absl::OkStatus();
}
absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
CalculatorContext* cc) {
absl::string_view text = kTextIn(cc).Get();
const int text_len = static_cast<int>(text.length());
tensor_shapes_[response_text_tensor_index_] = text_len;
std::vector<Tensor> input_tensors;
input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
input_tensors.push_back(
{Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
}
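// The zero-length copies below leave the query-text and response-context
// tensors empty; only the response-text tensor receives the input text.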
std::memcpy(
input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_context_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
"", 0);
std::memcpy(input_tensors[response_text_tensor_index_]
.GetCpuWriteView()
.buffer<char>(),
text.data(), text_len * sizeof(char));
kTensorsOut(cc).Send(std::move(input_tensors));
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,109 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
namespace mediapipe {
namespace {
using ::mediapipe::IsOkAndHolds;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;
constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;
constexpr absl::string_view kTestModelPath =
"mediapipe/tasks/testdata/text/"
"universal_sentence_encoder_qa_with_metadata.tflite";
absl::StatusOr<std::vector<std::string>>
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "text"
output_stream: "tensors"
node {
calculator: "UniversalSentenceEncoderPreprocessorCalculator"
input_stream: "TEXT:text"
input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
output_stream: "TENSORS:tensors"
}
)pb");
std::vector<Packet> output_packets;
tool::AddVectorSink("tensors", &graph_config, &output_packets);
std::string model_buffer =
tasks::core::LoadBinaryContent(kTestModelPath.data());
ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
ModelMetadataExtractor::CreateFromModelBuffer(
model_buffer.data(), model_buffer.size()));
// Run the graph.
CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(
graph_config,
{{"metadata_extractor",
MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
MP_RETURN_IF_ERROR(graph.StartRun({}));
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
"text", MakePacket<std::string>(text).At(Timestamp(0))));
MP_RETURN_IF_ERROR(graph.WaitUntilIdle());
if (output_packets.size() != 1) {
return absl::InvalidArgumentError(absl::Substitute(
"output_packets has size $0, expected 1", output_packets.size()));
}
const std::vector<Tensor>& tensor_vec =
output_packets[0].Get<std::vector<Tensor>>();
if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
return absl::InvalidArgumentError(absl::Substitute(
"tensor_vec has size $0, expected $1", tensor_vec.size(),
kNumInputTensorsForUniversalSentenceEncoder));
}
if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
return absl::InvalidArgumentError("Expected tensor element type kChar");
}
std::vector<std::string> results;
for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
results.push_back(
{tensor_vec[i].GetCpuReadView().buffer<char>(),
static_cast<size_t>(tensor_vec[i].shape().num_elements())});
}
return results;
}
TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
ASSERT_THAT(
RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
}
} // namespace
} // namespace mediapipe

View File

@ -331,6 +331,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework",
],
alwayslink = 1,

View File

@ -499,7 +499,6 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
gpu_data_out_ = absl::make_unique<GPUData>();
gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
const bool include_alpha = (max_num_channels_ == 4);
const bool single_channel = (max_num_channels_ == 1);
if (!(format == mediapipe::ImageFormat::GRAY8 ||
format == mediapipe::ImageFormat::SRGB ||
format == mediapipe::ImageFormat::SRGBA))
@ -509,6 +508,7 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
#if MEDIAPIPE_TFLITE_GL_INFERENCE
const bool single_channel = (max_num_channels_ == 1);
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &include_alpha, &input, &single_channel]() -> absl::Status {
// Device memory.

View File

@ -16,6 +16,7 @@
#include <memory>
#include <string>
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
@ -81,6 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
}
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>();
@ -89,6 +91,10 @@ class TfLiteModelCalculator : public CalculatorBase {
tflite::DefaultErrorReporter());
model = tflite::FlatBufferModel::BuildFromAllocation(
std::move(model_allocation), tflite::DefaultErrorReporter());
#else
return absl::FailedPreconditionError(
"Loading by file descriptor is not supported on this platform.");
#endif
}
RET_CHECK(model) << "Failed to load TfLite model from blob.";

View File

@ -143,9 +143,7 @@ mediapipe_proto_library(
cc_library(
name = "packet_frequency_calculator",
srcs = ["packet_frequency_calculator.cc"],
visibility = [
"//visibility:public",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
"//mediapipe/calculators/util:packet_frequency_cc_proto",
@ -190,9 +188,7 @@ cc_test(
cc_library(
name = "packet_latency_calculator",
srcs = ["packet_latency_calculator.cc"],
visibility = [
"//visibility:public",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/calculators/util:latency_cc_proto",
"//mediapipe/calculators/util:packet_latency_calculator_cc_proto",

View File

@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
text->set_left(label_left_px_);
text->set_baseline(label_baseline_px + i * label_height_px_);
text->set_font_face(options_.font_face());
if (options_.outline_thickness() > 0) {
text->set_outline_thickness(options_.outline_thickness());
if (options_.outline_color_size() > 0) {
*(text->mutable_outline_color()) =
options_.outline_color(i % options_.outline_color_size());
} else {
text->mutable_outline_color()->set_r(0);
text->mutable_outline_color()->set_g(0);
text->mutable_outline_color()->set_b(0);
}
}
}
cc->Outputs()
.Tag(kRenderDataTag)

View File

@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
// Thickness for drawing the label(s).
optional double thickness = 2 [default = 2];
// Color of the outline around each character, if any. One per label, as with
// the color attribute.
repeated Color outline_color = 12;
// Thickness of outline around each character.
optional double outline_thickness = 11;
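//
// Illustrative sketch of how these fields might be set (the field values and
// the extension path follow the usual calculator-options convention and are
// assumptions, not requirements):
//   options {
//     [mediapipe.LabelsToRenderDataCalculatorOptions.ext] {
//       outline_thickness: 1.0
//       outline_color { r: 255 g: 255 b: 255 }
//     }
//   }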
// The font height in absolute pixels.
optional int32 font_height_px = 3 [default = 50];

View File

@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText;
import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;

View File

@ -1685,10 +1685,3 @@ cc_test(
"@com_google_absl//absl/strings:str_format",
],
)
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -14,15 +14,10 @@ cc_library(
name = "builder",
hdrs = ["builder.h"],
deps = [
":const_str",
":contract",
":node",
":packet",
":port",
"//mediapipe/framework:calculator_base",
"//mediapipe/framework:calculator_contract",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)

View File

@ -5,12 +5,7 @@
#include <type_traits>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h"
@ -112,6 +107,17 @@ class MultiPort : public Single {
std::vector<std::unique_ptr<Base>>& vec_;
};
namespace internal_builder {
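// Cast() is only offered when converting from an untyped (AnyType) port to a
// port of a different type; the enable_if guards below drop the overload for
// every other combination.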
template <typename T, typename U>
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>>;
} // namespace internal_builder
template <bool IsSide, typename T = internal::Generic>
class SourceImpl;
// These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic>
@ -122,16 +128,21 @@ class DestinationImpl {
explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
: DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
explicit DestinationImpl(DestinationBase* base) : base_(*base) {}
template <typename U,
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
DestinationImpl<IsSide, U> Cast() {
return DestinationImpl<IsSide, U>(&base_);
}
private:
DestinationBase& base_;
template <bool Source_IsSide, typename Source_T>
friend class SourceImpl;
};
template <bool IsSide, typename T>
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
public:
using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
};
template <bool IsSide, typename T = internal::Generic>
class SourceImpl {
public:
using Base = SourceBase;
@ -171,12 +182,8 @@ class SourceImpl {
return AddTarget(dest);
}
template <typename U>
struct AllowCast
: public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
!std::is_same_v<T, U>> {};
template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
template <typename U,
std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
SourceImpl<IsSide, U> Cast() {
return SourceImpl<IsSide, U>(base_);
}
@ -186,12 +193,6 @@ class SourceImpl {
SourceBase* base_;
};
template <bool IsSide, typename T>
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
public:
using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
};
// A source and a destination correspond to an output/input stream on a node,
// and a side source and side destination correspond to an output/input side
// packet.
@ -201,20 +202,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
template <typename T = internal::Generic>
using Source = SourceImpl<false, T>;
template <typename T = internal::Generic>
using MultiSource = MultiSourceImpl<false, T>;
using MultiSource = MultiPort<Source<T>>;
template <typename T = internal::Generic>
using SideSource = SourceImpl<true, T>;
template <typename T = internal::Generic>
using MultiSideSource = MultiSourceImpl<true, T>;
using MultiSideSource = MultiPort<SideSource<T>>;
template <typename T = internal::Generic>
using Destination = DestinationImpl<false, T>;
template <typename T = internal::Generic>
using SideDestination = DestinationImpl<true, T>;
template <typename T = internal::Generic>
using MultiDestination = MultiDestinationImpl<false, T>;
using MultiDestination = MultiPort<Destination<T>>;
template <typename T = internal::Generic>
using MultiSideDestination = MultiDestinationImpl<true, T>;
using MultiSideDestination = MultiPort<SideDestination<T>>;
class NodeBase {
public:
@ -439,8 +440,9 @@ class Graph {
// Creates a node of a specific type. Should be used for pure interfaces,
// which do not have a built-in type string.
template <class Calc>
Node<Calc>& AddNode(const std::string& type) {
auto node = std::make_unique<Node<Calc>>(type);
Node<Calc>& AddNode(absl::string_view type) {
auto node =
std::make_unique<Node<Calc>>(std::string(type.data(), type.size()));
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
@ -448,16 +450,18 @@ class Graph {
// Creates a generic node, with no compile-time checking of inputs and
// outputs. This can be used for calculators whose contract is not visible.
GenericNode& AddNode(const std::string& type) {
auto node = std::make_unique<GenericNode>(type);
GenericNode& AddNode(absl::string_view type) {
auto node =
std::make_unique<GenericNode>(std::string(type.data(), type.size()));
auto node_p = node.get();
nodes_.emplace_back(std::move(node));
return *node_p;
}
// For legacy PacketGenerators.
PacketGenerator& AddPacketGenerator(const std::string& type) {
auto node = std::make_unique<PacketGenerator>(type);
PacketGenerator& AddPacketGenerator(absl::string_view type) {
auto node = std::make_unique<PacketGenerator>(
std::string(type.data(), type.size()));
auto node_p = node.get();
packet_gens_.emplace_back(std::move(node));
return *node_p;

View File

@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output.SetName("any_type_output");
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
output_stream: "ANY_OUTPUT:any_type_output"
}
input_stream: "GRAPH_ANY_INPUT:__stream_0"
output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
)pb");
EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}

View File

@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
typedef absl::Status (*GetContractType)(CalculatorContract * cc);
typedef absl::Status (*GetContractType)(CalculatorContract* cc);
return std::is_same<decltype(&T::GetContract), GetContractType>::value;
}
template <class T>

View File

@ -133,7 +133,12 @@ message GraphTrace {
TPU_TASK = 13;
GPU_CALIBRATION = 14;
PACKET_QUEUED = 15;
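// Delegate invocation inside an inference calculator: GPU_TASK_INVOKE marks
// the scoped MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc) region around the GPU
// delegate's Invoke() call; TPU_TASK_INVOKE is presumably the TPU analogue.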
GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17;
}
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
// )
// The timing for one packet set being processed at one calculator node.
message CalculatorTrace {

View File

@ -334,13 +334,6 @@ mediapipe_register_type(
deps = [":landmark_cc_proto"],
)
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)
cc_library(
name = "image",
srcs = ["image.cc"],
@ -469,6 +462,10 @@ cc_library(
],
"//conditions:default": [],
}),
defines = select({
"//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"],
"//conditions:default": [],
}),
linkopts = select({
"//mediapipe:ios": [
"-framework CoreVideo",

View File

@ -33,10 +33,3 @@ mediapipe_proto_library(
srcs = ["rasterization.proto"],
visibility = ["//visibility:public"],
)
# Expose the proto source files for building mediapipe AAR.
filegroup(
name = "protos_src",
srcs = glob(["*.proto"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -37,6 +37,10 @@ message Classification {
// Group of Classification protos.
message ClassificationList {
repeated Classification classification = 1;
// Optional index of the tensor that produced these classifications.
optional int32 tensor_index = 2;
// Optional name of the tensor that produced these classifications.
optional string tensor_name = 3;
}
// Group of ClassificationList protos.

View File

@ -31,11 +31,12 @@
#if MEDIAPIPE_METAL_ENABLED
#import <Metal/Metal.h>
#endif // MEDIAPIPE_METAL_ENABLED
#ifndef MEDIAPIPE_NO_JNI
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#define MEDIAPIPE_TENSOR_USE_AHWB 1
#endif // __ANDROID_API__ >= 26 ||
// defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#endif // MEDIAPIPE_NO_JNI
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <android/hardware_buffer.h>
@ -43,7 +44,6 @@
#include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h"
@ -97,8 +97,8 @@ class Tensor {
kUInt8,
kInt8,
kInt32,
// TODO: Update the inference runner to handle kTfLiteString.
kChar
kChar,
kBool
};
struct Shape {
Shape() = default;
@ -330,6 +330,8 @@ class Tensor {
return sizeof(int32_t);
case ElementType::kChar:
return sizeof(char);
case ElementType::kBool:
return sizeof(bool);
}
}
int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
glFinish();
gl_context_->Run([]() { glFinish(); });
} else if (valid_ & kValidAHardwareBuffer) {
CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
"completion function to be set";

View File

@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {
Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));
Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
}
TEST(Cpu, TestMemoryAllocation) {

View File

@ -109,6 +109,11 @@ struct TraceEvent {
static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
// //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
// //depot/mediapipe/framework/calculator_profile.proto:event_type,
// )
};
// Packet trace log buffer.

View File

@ -64,7 +64,7 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
std::vector<TraceEventType> basic_types = {
{TraceEvent::UNKNOWN, "An uninitialized trace-event."},
{TraceEvent::OPEN, "A call to Calculator::Open.", true, true},
{TraceEvent::PROCESS, "A call to Calculator::Open.", true, true},
{TraceEvent::PROCESS, "A call to Calculator::Process.", true, true},
{TraceEvent::CLOSE, "A call to Calculator::Close.", true, true},
{TraceEvent::NOT_READY, "A calculator cannot process packets yet."},

View File

@ -150,7 +150,7 @@ cc_library(
name = "executor_util",
srcs = ["executor_util.cc"],
hdrs = ["executor_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:mediapipe_options_cc_proto",

View File

@ -378,8 +378,11 @@ cc_library(
],
}),
deps = [
":gl_texture_buffer",
":gpu_buffer_format",
":gpu_buffer_storage",
":image_frame_view",
"//mediapipe/framework/formats:image_frame",
"@com_google_absl//absl/strings:str_format",
],
)
@ -1050,7 +1053,7 @@ objc_library(
alwayslink = 1,
)
MIN_IOS_VERSION = "9.0" # For thread_local.
MIN_IOS_VERSION = "11.0"
test_suite(
name = "ios",

View File

@ -111,7 +111,8 @@ typedef CVOpenGLESTextureCacheRef CVTextureCacheType;
- (CVMetalTextureCacheRef)mtlTextureCache {
@synchronized(self) {
if (!_mtlTextureCache) {
CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
CVReturn __unused err =
CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err);
// TODO: register and flush metal caches too.
}

View File

@ -47,6 +47,8 @@ static void EglThreadExitCallback(void* key_value) {
// implementations, and should be considered as an undocumented vendor
// extension.
// https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
//
// NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif

View File

@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) {
16,
0};
pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs];
pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1];
}
if (!pixel_format_) {
// On several Forge machines, the default config fails. For now let's do

View File

@ -144,14 +144,23 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
context](std::shared_ptr<GlSyncPoint> sync_token) {
CHECK_NE(name_, 0);
GLuint name_to_delete = name_;
context->RunWithoutWaiting([name_to_delete, sync_token]() {
if (sync_token) {
// TODO: maybe we do not actually have to wait for the
// consumer sync here. Check docs.
sync_token->WaitOnGpu();
} else {
LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback";
}
context->RunWithoutWaiting([name_to_delete]() {
// Note that we do not wait for consumers to be done before deleting the
// texture. Based on a reading of the GLES 3.0 spec, appendix D:
// - when a texture is deleted, it is _not_ automatically unbound from
// bind points in other contexts;
// - when a texture is deleted, its name becomes immediately invalid, but
// the actual object is not deleted until it is no longer in use, i.e.
// attached to a container object or bound to a context;
// - deleting an object is not an operation that changes its contents;
// - within each context, commands are executed sequentially, so it seems
// like an unbind that follows a command that reads a texture should not
// take effect until the GPU has actually finished executing the
// previous commands.
// The final point is the least explicit in the docs, but it is implied by
// normal single-context behavior. E.g. if you do bind, delete, render,
// unbind, the object is not deleted until the unbind, and it waits for
// the render to finish.
DLOG_IF(ERROR, !glIsTexture(name_to_delete))
<< "Deleting invalid texture id: " << name_to_delete;
glDeleteTextures(1, &name_to_delete);
@ -185,7 +194,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
<< "Updated existing texture which had not been marked for reuse!";
CHECK(prod_token);
producer_sync_ = std::move(prod_token);
producer_context_ = producer_sync_->GetContext();
const auto& synced_context = producer_sync_->GetContext();
if (synced_context) {
producer_context_ = synced_context;
}
}
void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {

View File

@ -65,6 +65,7 @@ class GlTextureView {
friend class GpuBuffer;
friend class GlTextureBuffer;
friend class GpuBufferStorageCvPixelBuffer;
friend class GpuBufferStorageAhwb;
GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
DetachFn detach, DoneWritingFn done_writing)

View File

@ -18,6 +18,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "stb_image.h"

View File

@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.framework.image.ImageProperties;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer;
// TODO: use Preconditions in this file.
@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
}
/**
* Creates an Image packet from an {@link Image}.
* Creates a MediaPipe Image packet from a {@link MPImage}.
*
* <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
*/
public Packet createImage(Image image) {
public Packet createImage(MPImage image) {
// TODO: Choose the best storage from multiple containers.
ImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
MPImageProperties properties = image.getContainedImageProperties().get(0);
if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
ByteBuffer buffer = ByteBufferExtractor.extract(image);
int numChannels = 0;
switch (properties.getImageFormat()) {
case Image.IMAGE_FORMAT_RGBA:
case MPImage.IMAGE_FORMAT_RGBA:
numChannels = 4;
break;
case Image.IMAGE_FORMAT_RGB:
case MPImage.IMAGE_FORMAT_RGB:
numChannels = 3;
break;
case Image.IMAGE_FORMAT_ALPHA:
case MPImage.IMAGE_FORMAT_ALPHA:
numChannels = 1;
break;
default: // fall out
@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
int height = image.getHeight();
return createImage(buffer, width, height, numChannels);
}
if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) {
Bitmap bitmap = BitmapExtractor.extract(image);
if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");
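
To illustrate the renamed API, a hedged usage sketch: the builder constructor and `createImage(MPImage)` signature come from the hunks in this commit, while the wrapper class, buffer contents, and packet-creator instance are assumptions for illustration:

```
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import java.nio.ByteBuffer;

final class MPImagePacketExample {
  // Wraps tightly packed RGBA pixels in an MPImage and converts it to a packet.
  static Packet toPacket(
      AndroidPacketCreator creator, ByteBuffer rgbaPixels, int width, int height) {
    MPImage image =
        new ByteBufferImageBuilder(rgbaPixels, width, height, MPImage.IMAGE_FORMAT_RGBA).build();
    // createImage(MPImage) dispatches on the contained storage type,
    // STORAGE_TYPE_BYTEBUFFER in this case.
    return creator.createImage(image);
  }
}
```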

View File

@ -30,3 +30,10 @@ android_library(
"@maven//:com_google_guava_guava",
],
)
# Expose the java source files for building mediapipe AAR.
filegroup(
name = "java_src",
srcs = glob(["*.java"]),
visibility = ["//mediapipe:__subpackages__"],
)

View File

@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;
/**
* Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
* Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
* {@link IllegalArgumentException} will be thrown.
*/
public final class BitmapExtractor {
/**
* Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
* Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
*
* @param image the image to extract {@link android.graphics.Bitmap} from.
* @return the {@link android.graphics.Bitmap} stored in {@link Image}
* @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
public static Bitmap extract(Image image) {
ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
public static Bitmap extract(MPImage image) {
MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
if (imageContainer != null) {
return ((BitmapImageContainer) imageContainer).getBitmap();
} else {
// TODO: Support ByteBuffer -> Bitmap conversion.
throw new IllegalArgumentException(
"Extracting Bitmap from an Image created by objects other than Bitmap is not"
"Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
+ " supported");
}
}
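
As a quick reference, a sketch of the corresponding round trip through the Bitmap-backed container; the wrapper class is illustrative, the builder and extractor calls are the ones renamed above:

```
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;

final class BitmapRoundTrip {
  // Stores the Bitmap as STORAGE_TYPE_BITMAP and reads it back without conversion.
  static Bitmap roundTrip(Bitmap bitmap) {
    MPImage image = new BitmapImageBuilder(bitmap).build();
    return BitmapExtractor.extract(image);
  }
}
```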

View File

@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException;
/**
* Builds {@link Image} from {@link android.graphics.Bitmap}.
* Builds {@link MPImage} from {@link android.graphics.Bitmap}.
*
* <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
* {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content
@ -49,7 +49,7 @@ public class BitmapImageBuilder {
}
/**
* Creates the builder to build {@link Image} from a file.
* Creates the builder to build {@link MPImage} from a file.
*
* @param context the application context.
* @param uri the path to the resource file.
@ -58,15 +58,15 @@ public class BitmapImageBuilder {
this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
}
/** Sets value for {@link Image#getTimestamp()}. */
/** Sets value for {@link MPImage#getTimestamp()}. */
BitmapImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(
/** Builds a {@link MPImage} instance. */
public MPImage build() {
return new MPImage(
new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
}
}

View File

@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
class BitmapImageContainer implements ImageContainer {
class BitmapImageContainer implements MPImageContainer {
private final Bitmap bitmap;
private final ImageProperties properties;
private final MPImageProperties properties;
public BitmapImageContainer(Bitmap bitmap) {
this.bitmap = bitmap;
this.properties =
ImageProperties.builder()
MPImageProperties.builder()
.setImageFormat(convertFormatCode(bitmap.getConfig()))
.setStorageType(Image.STORAGE_TYPE_BITMAP)
.setStorageType(MPImage.STORAGE_TYPE_BITMAP)
.build();
}
@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
}
@Override
public ImageProperties getImageProperties() {
public MPImageProperties getImageProperties() {
return properties;
}
@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
bitmap.recycle();
}
@ImageFormat
@MPImageFormat
static int convertFormatCode(Bitmap.Config config) {
switch (config) {
case ALPHA_8:
return Image.IMAGE_FORMAT_ALPHA;
return MPImage.IMAGE_FORMAT_ALPHA;
case ARGB_8888:
return Image.IMAGE_FORMAT_RGBA;
return MPImage.IMAGE_FORMAT_RGBA;
default:
return Image.IMAGE_FORMAT_UNKNOWN;
return MPImage.IMAGE_FORMAT_UNKNOWN;
}
}
}

View File

@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Locale;
/**
* Utility for extracting {@link ByteBuffer} from {@link Image}.
* Utility for extracting {@link ByteBuffer} from {@link MPImage}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
* {@link IllegalArgumentException} will be thrown.
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
* otherwise {@link IllegalArgumentException} will be thrown.
*/
public class ByteBufferExtractor {
/**
* Extracts a {@link ByteBuffer} from an {@link Image}.
* Extracts a {@link ByteBuffer} from a {@link MPImage}.
*
* <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
* ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
* MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
*
* @see Image#getContainedImageProperties()
* @see MPImage#getContainedImageProperties()
* @return A read-only {@link ByteBuffer}.
* @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
*/
@SuppressLint("SwitchIntDef")
public static ByteBuffer extract(Image image) {
ImageContainer container = image.getContainer();
public static ByteBuffer extract(MPImage image) {
MPImageContainer container = image.getContainer();
switch (container.getImageProperties().getStorageType()) {
case Image.STORAGE_TYPE_BYTEBUFFER:
case MPImage.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
default:
throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
"Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
+ " supported");
}
}
/**
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
* Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
*
* <p>Format conversion spec:
*
@ -70,26 +70,26 @@ public class ByteBufferExtractor {
*
* @param image the image to extract buffer from.
* @param targetFormat the image format of the result bytebuffer.
* @return the readonly {@link ByteBuffer} stored in {@link Image}
* @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
ImageContainer container;
ImageProperties byteBufferProperties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
MPImageContainer container;
MPImageProperties byteBufferProperties =
MPImageProperties.builder()
.setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(targetFormat)
.build();
if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
@MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer();
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer =
extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)
@ -98,85 +98,89 @@ public class ByteBufferExtractor {
return byteBuffer;
} else {
throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by objects other than Bitmap or"
"Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
+ " Bytebuffer is not supported");
}
}
/** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
/** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
@AutoValue
abstract static class Result {
/** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
/**
* Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
public abstract ByteBuffer buffer();
/** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
@ImageFormat
/**
* Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
*/
@MPImageFormat
public abstract int format();
static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
}
}
/**
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
* Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
*
* <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy.
*
* @return the readonly {@link ByteBuffer} stored in {@link Image}
* @return the readonly {@link ByteBuffer} stored in {@link MPImage}
* @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
* given {@code imageFormat}
*/
static Result extractInRecommendedFormat(Image image) {
ImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
static Result extractInRecommendedFormat(MPImage image) {
MPImageContainer container;
if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
@ImageFormat int format = adviseImageFormat(bitmap);
@MPImageFormat int format = adviseImageFormat(bitmap);
Result result =
Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);
boolean unused =
image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
return result;
} else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return Result.create(
byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
byteBufferImageContainer.getImageFormat());
} else {
throw new IllegalArgumentException(
"Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
"Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
+ " is not supported");
}
}
@ImageFormat
@MPImageFormat
private static int adviseImageFormat(Bitmap bitmap) {
if (bitmap.getConfig() == Config.ARGB_8888) {
return Image.IMAGE_FORMAT_RGBA;
return MPImage.IMAGE_FORMAT_RGBA;
} else {
throw new IllegalArgumentException(
String.format(
"Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
"Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
+ " supported",
bitmap.getConfig()));
}
}
private static ByteBuffer extractByteBufferFromBitmap(
Bitmap bitmap, @ImageFormat int imageFormat) {
Bitmap bitmap, @MPImageFormat int imageFormat) {
if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
throw new IllegalArgumentException(
"Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
"Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
+ " supported");
}
if (bitmap.getConfig() == Config.ARGB_8888) {
if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
bitmap.copyPixelsToBuffer(buffer);
buffer.rewind();
return buffer;
} else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
} else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
// TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
int w = bitmap.getWidth();
int h = bitmap.getHeight();
@ -196,14 +200,14 @@ public class ByteBufferExtractor {
}
throw new IllegalArgumentException(
String.format(
"Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
"Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
+ " %d is not supported",
bitmap.getConfig(), imageFormat));
}
private static ByteBuffer convertByteBuffer(
ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
// Extend the buffer when the target is longer than the source. Use two cursors and sweep the
// array reversely to convert in-place.
@ -221,7 +225,8 @@ public class ByteBufferExtractor {
target.put(array, 0, target.capacity());
target.rewind();
return target;
} else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
} else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
&& targetFormat == MPImage.IMAGE_FORMAT_RGB) {
ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
// Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
// array to convert in-place.
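
A minimal sketch of the extractor after the rename, assuming an RGB pixel buffer supplied by the caller; only the builder and the public `extract(MPImage)` overload shown above are used:

```
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import java.nio.ByteBuffer;

final class ByteBufferReadBack {
  // Returns a read-only view of the buffer already stored in the image; no format
  // conversion happens because the storage type is already STORAGE_TYPE_BYTEBUFFER.
  static ByteBuffer readPixels(ByteBuffer rgbPixels, int width, int height) {
    MPImage image =
        new ByteBufferImageBuilder(rgbPixels, width, height, MPImage.IMAGE_FORMAT_RGB).build();
    return ByteBufferExtractor.extract(image);
  }
}
```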

View File

@ -15,11 +15,11 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
/**
* Builds a {@link Image} from a {@link ByteBuffer}.
* Builds a {@link MPImage} from a {@link ByteBuffer}.
*
* <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
* ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.
@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
private final ByteBuffer buffer;
private final int width;
private final int height;
@ImageFormat private final int imageFormat;
@MPImageFormat private final int imageFormat;
// Optional fields.
private long timestamp;
@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
* @param imageFormat how the data encode the image.
*/
public ByteBufferImageBuilder(
ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
this.buffer = byteBuffer;
this.width = width;
this.height = height;
@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
this.timestamp = 0;
}
/** Sets value for {@link Image#getTimestamp()}. */
/** Sets value for {@link MPImage#getTimestamp()}. */
ByteBufferImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
/** Builds a {@link MPImage} instance. */
public MPImage build() {
return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
}
}

View File

@ -15,21 +15,19 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
class ByteBufferImageContainer implements ImageContainer {
class ByteBufferImageContainer implements MPImageContainer {
private final ByteBuffer buffer;
private final ImageProperties properties;
private final MPImageProperties properties;
public ByteBufferImageContainer(
ByteBuffer buffer,
@ImageFormat int imageFormat) {
public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
this.buffer = buffer;
this.properties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
MPImageProperties.builder()
.setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
.setImageFormat(imageFormat)
.build();
}
@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
}
@Override
public ImageProperties getImageProperties() {
public MPImageProperties getImageProperties() {
return properties;
}
/**
* Returns the image format.
*/
@ImageFormat
/** Returns the image format. */
@MPImageFormat
public int getImageFormat() {
return properties.getImageFormat();
}

View File

@ -29,10 +29,10 @@ import java.util.Map.Entry;
/**
* The wrapper class for image objects.
*
* <p>{@link Image} is designed to be an immutable image container, which could be shared
* <p>{@link MPImage} is designed to be an immutable image container, which could be shared
* cross-platforms.
*
* <p>To construct an {@link Image}, use the provided builders:
* <p>To construct a {@link MPImage}, use the provided builders:
*
* <ul>
* <li>{@link ByteBufferImageBuilder}
@ -40,7 +40,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageBuilder}
* </ul>
*
* <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
* <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
* reference count is 1. Developer can call {@link #close()} to reduce reference count to release
* internal storage earlier, otherwise Java garbage collection will release the storage eventually.
*
@ -53,7 +53,7 @@ import java.util.Map.Entry;
* <li>{@link MediaImageExtractor}
* </ul>
*/
public class Image implements Closeable {
public class MPImage implements Closeable {
/** Specifies the image format of an image. */
@IntDef({
@ -69,7 +69,7 @@ public class Image implements Closeable {
IMAGE_FORMAT_JPEG,
})
@Retention(RetentionPolicy.SOURCE)
public @interface ImageFormat {}
public @interface MPImageFormat {}
public static final int IMAGE_FORMAT_UNKNOWN = 0;
public static final int IMAGE_FORMAT_RGBA = 1;
@ -98,14 +98,14 @@ public class Image implements Closeable {
public static final int STORAGE_TYPE_IMAGE_PROXY = 4;
/**
* Returns a list of supported image properties for this {@link Image}.
* Returns a list of supported image properties for this {@link MPImage}.
*
* <p>Currently {@link Image} only support single storage type so the size of return list will
* <p>Currently {@link MPImage} only support single storage type so the size of return list will
* always be 1.
*
* @see ImageProperties
* @see MPImageProperties
*/
public List<ImageProperties> getContainedImageProperties() {
public List<MPImageProperties> getContainedImageProperties() {
return Collections.singletonList(getContainer().getImageProperties());
}
@ -124,7 +124,7 @@ public class Image implements Closeable {
return height;
}
/** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
/** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
private synchronized void acquire() {
referenceCount += 1;
}
@ -132,7 +132,7 @@ public class Image implements Closeable {
/**
* Removes a reference that was previously acquired or init.
*
* <p>When {@link Image} is created, it has 1 reference count.
* <p>When {@link MPImage} is created, it has 1 reference count.
*
* <p>When the reference count becomes 0, it will release the resource under the hood.
*/
@ -141,24 +141,24 @@ public class Image implements Closeable {
public synchronized void close() {
referenceCount -= 1;
if (referenceCount == 0) {
for (ImageContainer imageContainer : containerMap.values()) {
for (MPImageContainer imageContainer : containerMap.values()) {
imageContainer.close();
}
}
}
/** Advanced API access for {@link Image}. */
/** Advanced API access for {@link MPImage}. */
static final class Internal {
/**
* Acquires a reference on this {@link Image}. This will increase the reference count by 1.
* Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
*
* <p>This method is more useful for image consumer to acquire a reference so image resource
* will not be closed accidentally. As image creator, normal developer doesn't need to call this
* method.
*
* <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link Image} anymore.
* <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
* #close()} to indicate it doesn't need this {@link MPImage} anymore.
*
* @see #close()
*/
@ -166,10 +166,10 @@ public class Image implements Closeable {
image.acquire();
}
private final Image image;
private final MPImage image;
// Only Image creates the internal helper.
private Internal(Image image) {
// Only MPImage creates the internal helper.
private Internal(MPImage image) {
this.image = image;
}
}
@ -179,15 +179,15 @@ public class Image implements Closeable {
return new Internal(this);
}
private final Map<ImageProperties, ImageContainer> containerMap;
private final Map<MPImageProperties, MPImageContainer> containerMap;
private final long timestamp;
private final int width;
private final int height;
private int referenceCount;
/** Constructs an {@link Image} with a built container. */
Image(ImageContainer container, long timestamp, int width, int height) {
/** Constructs a {@link MPImage} with a built container. */
MPImage(MPImageContainer container, long timestamp, int width, int height) {
this.containerMap = new HashMap<>();
containerMap.put(container.getImageProperties(), container);
this.timestamp = timestamp;
@ -201,10 +201,10 @@ public class Image implements Closeable {
*
* @return the current container.
*/
ImageContainer getContainer() {
MPImageContainer getContainer() {
// According to the design, in the future we will support multiple containers in one image.
// Currently just return the original container.
// TODO: Cache multiple containers in Image.
// TODO: Cache multiple containers in MPImage.
return containerMap.values().iterator().next();
}
@ -214,8 +214,8 @@ public class Image implements Closeable {
* <p>If there are multiple containers with required {@code storageType}, returns the first one.
*/
@Nullable
ImageContainer getContainer(@StorageType int storageType) {
for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
MPImageContainer getContainer(@StorageType int storageType) {
for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
if (entry.getKey().getStorageType() == storageType) {
return entry.getValue();
}
@ -225,13 +225,13 @@ public class Image implements Closeable {
/** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */
@Nullable
ImageContainer getContainer(ImageProperties imageProperties) {
MPImageContainer getContainer(MPImageProperties imageProperties) {
return containerMap.get(imageProperties);
}
/** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
boolean addContainer(ImageContainer container) {
ImageProperties imageProperties = container.getImageProperties();
boolean addContainer(MPImageContainer container) {
MPImageProperties imageProperties = container.getImageProperties();
if (containerMap.containsKey(imageProperties)) {
return false;
}
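
Since MPImage implements Closeable, the reference-counting contract described in the Javadoc above can be exercised with try-with-resources. This sketch assumes a Bitmap-backed image and uses illustrative logging:

```
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;

final class MPImageLifecycle {
  static void describe(Bitmap bitmap) {
    // The image starts with a reference count of 1; close() drops it to 0 and
    // releases the contained storage (here, the Bitmap container).
    try (MPImage image = new BitmapImageBuilder(bitmap).build()) {
      // Currently a single container is stored, so the list has exactly one entry.
      MPImageProperties properties = image.getContainedImageProperties().get(0);
      System.out.println("storage=" + properties.getStorageType()
          + " format=" + properties.getImageFormat()
          + " size=" + image.getWidth() + "x" + image.getHeight());
    }
  }
}
```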

View File

@ -14,14 +14,14 @@ limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that can receive {@link Image} */
public interface ImageConsumer {
/** Lightweight abstraction for an object that can receive {@link MPImage} */
public interface MPImageConsumer {
/**
* Called when an {@link Image} is available.
* Called when a {@link MPImage} is available.
*
* <p>The argument is only guaranteed to be available until this method returns. If you need to
* extend its lifetime, acquire it, then release it when done.
*/
void onNewImage(Image image);
void onNewMPImage(MPImage image);
}
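
For completeness, a trivial implementation of the renamed consumer interface; the logging body is an illustrative assumption:

```
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageConsumer;

final class LoggingConsumer implements MPImageConsumer {
  @Override
  public void onNewMPImage(MPImage image) {
    // The image is only guaranteed to be valid until this method returns.
    System.out.println("frame " + image.getWidth() + "x" + image.getHeight());
  }
}
```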

View File

@ -16,9 +16,9 @@ limitations under the License.
package com.google.mediapipe.framework.image;
/** Manages internal image data storage. The interface is package-private. */
interface ImageContainer {
interface MPImageContainer {
/** Returns the properties of the contained image. */
ImageProperties getImageProperties();
MPImageProperties getImageProperties();
/** Close the image container and releases the image resource inside. */
void close();

View File

@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
package com.google.mediapipe.framework.image;
/** Lightweight abstraction for an object that produce {@link Image} */
public interface ImageProducer {
/** Lightweight abstraction for an object that produce {@link MPImage} */
public interface MPImageProducer {
/** Sets the consumer that receives the {@link Image}. */
void setImageConsumer(ImageConsumer imageConsumer);
/** Sets the consumer that receives the {@link MPImage}. */
void setMPImageConsumer(MPImageConsumer imageConsumer);
}

View File

@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;
import com.google.auto.value.AutoValue;
import com.google.auto.value.extension.memoized.Memoized;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.Image.StorageType;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import com.google.mediapipe.framework.image.MPImage.StorageType;
/** Groups a set of properties to describe how an image is stored. */
@AutoValue
public abstract class ImageProperties {
public abstract class MPImageProperties {
/**
* Gets the pixel format of the image.
*
* @see Image.ImageFormat
* @see MPImage.MPImageFormat
*/
@ImageFormat
@MPImageFormat
public abstract int getImageFormat();
/**
* Gets the storage type of the image.
*
* @see Image.StorageType
* @see MPImage.StorageType
*/
@StorageType
public abstract int getStorageType();
@ -45,36 +45,36 @@ public abstract class ImageProperties {
public abstract int hashCode();
/**
* Creates a builder of {@link ImageProperties}.
* Creates a builder of {@link MPImageProperties}.
*
* @see ImageProperties.Builder
* @see MPImageProperties.Builder
*/
static Builder builder() {
return new AutoValue_ImageProperties.Builder();
return new AutoValue_MPImageProperties.Builder();
}
/** Builds a {@link ImageProperties}. */
/** Builds a {@link MPImageProperties}. */
@AutoValue.Builder
abstract static class Builder {
/**
* Sets the {@link Image.ImageFormat}.
* Sets the {@link MPImage.MPImageFormat}.
*
* @see ImageProperties#getImageFormat
* @see MPImageProperties#getImageFormat
*/
abstract Builder setImageFormat(@ImageFormat int value);
abstract Builder setImageFormat(@MPImageFormat int value);
/**
* Sets the {@link Image.StorageType}.
* Sets the {@link MPImage.StorageType}.
*
* @see ImageProperties#getStorageType
* @see MPImageProperties#getStorageType
*/
abstract Builder setStorageType(@StorageType int value);
/** Builds the {@link ImageProperties}. */
abstract ImageProperties build();
/** Builds the {@link MPImageProperties}. */
abstract MPImageProperties build();
}
// Hide the constructor.
ImageProperties() {}
MPImageProperties() {}
}

View File

@ -15,11 +15,12 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
/**
* Builds {@link Image} from {@link android.media.Image}.
* Builds {@link MPImage} from {@link android.media.Image}.
*
* <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
* content in it.
@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
public class MediaImageBuilder {
// Mandatory fields.
private final android.media.Image mediaImage;
private final Image mediaImage;
// Optional fields.
private long timestamp;
@ -40,20 +41,20 @@ public class MediaImageBuilder {
*
* @param mediaImage image data object.
*/
public MediaImageBuilder(android.media.Image mediaImage) {
public MediaImageBuilder(Image mediaImage) {
this.mediaImage = mediaImage;
this.timestamp = 0;
}
/** Sets value for {@link Image#getTimestamp()}. */
/** Sets value for {@link MPImage#getTimestamp()}. */
MediaImageBuilder setTimestamp(long timestamp) {
this.timestamp = timestamp;
return this;
}
/** Builds an {@link Image} instance. */
public Image build() {
return new Image(
/** Builds a {@link MPImage} instance. */
public MPImage build() {
return new MPImage(
new MediaImageContainer(mediaImage),
timestamp,
mediaImage.getWidth(),

View File

@ -15,33 +15,34 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
@RequiresApi(VERSION_CODES.KITKAT)
class MediaImageContainer implements ImageContainer {
class MediaImageContainer implements MPImageContainer {
private final android.media.Image mediaImage;
private final ImageProperties properties;
private final Image mediaImage;
private final MPImageProperties properties;
public MediaImageContainer(android.media.Image mediaImage) {
public MediaImageContainer(Image mediaImage) {
this.mediaImage = mediaImage;
this.properties =
ImageProperties.builder()
.setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
MPImageProperties.builder()
.setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
.setImageFormat(convertFormatCode(mediaImage.getFormat()))
.build();
}
public android.media.Image getImage() {
public Image getImage() {
return mediaImage;
}
@Override
public ImageProperties getImageProperties() {
public MPImageProperties getImageProperties() {
return properties;
}
@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
mediaImage.close();
}
@ImageFormat
@MPImageFormat
static int convertFormatCode(int graphicsFormat) {
// We only cover the format mentioned in
// https://developer.android.com/reference/android/media/Image#getFormat()
if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
return Image.IMAGE_FORMAT_RGBA;
return MPImage.IMAGE_FORMAT_RGBA;
} else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
return Image.IMAGE_FORMAT_RGB;
return MPImage.IMAGE_FORMAT_RGB;
}
}
switch (graphicsFormat) {
case android.graphics.ImageFormat.JPEG:
return Image.IMAGE_FORMAT_JPEG;
return MPImage.IMAGE_FORMAT_JPEG;
case android.graphics.ImageFormat.YUV_420_888:
return Image.IMAGE_FORMAT_YUV_420_888;
return MPImage.IMAGE_FORMAT_YUV_420_888;
default:
return Image.IMAGE_FORMAT_UNKNOWN;
return MPImage.IMAGE_FORMAT_UNKNOWN;
}
}
}

View File

@ -15,13 +15,14 @@ limitations under the License.
package com.google.mediapipe.framework.image;
import android.media.Image;
import android.os.Build.VERSION_CODES;
import androidx.annotation.RequiresApi;
/**
* Utility for extracting {@link android.media.Image} from {@link Image}.
* Utility for extracting {@link android.media.Image} from {@link MPImage}.
*
* <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
* <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
* otherwise {@link IllegalArgumentException} will be thrown.
*/
@RequiresApi(VERSION_CODES.KITKAT)
@ -30,20 +31,20 @@ public class MediaImageExtractor {
private MediaImageExtractor() {}
/**
* Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
* {@link Image} that built from {@link MediaImageBuilder}.
* Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
* {@link MPImage} that built from {@link MediaImageBuilder}.
*
* @param image the image to extract {@link android.media.Image} from.
* @return {@link android.media.Image} that stored in {@link Image}.
* @return {@link android.media.Image} that stored in {@link MPImage}.
* @throws IllegalArgumentException if the extraction failed.
*/
public static android.media.Image extract(Image image) {
ImageContainer container;
if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
public static Image extract(MPImage image) {
MPImageContainer container;
if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
return ((MediaImageContainer) container).getImage();
}
throw new IllegalArgumentException(
"Extract Media Image from an Image created by objects other than Media Image"
"Extract Media Image from a MPImage created by objects other than Media Image"
+ " is not supported");
}
}
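
A matching sketch for the media-image path: the builder and extractor signatures are the renamed ones above, while the wrapper class and the source of the android.media.Image (e.g. an ImageReader) are assumptions:

```
import android.media.Image;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MediaImageBuilder;
import com.google.mediapipe.framework.image.MediaImageExtractor;

final class MediaImageRoundTrip {
  // Wraps an android.media.Image and reads it back; extraction only works for
  // MPImage instances built with MediaImageBuilder.
  static Image roundTrip(Image mediaImage) {
    MPImage image = new MediaImageBuilder(mediaImage).build();
    return MediaImageExtractor.extract(image);
  }
}
```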

View File

@ -1,4 +1,4 @@
# Copyright 2019-2020 The MediaPipe Authors.
# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -89,10 +89,6 @@ def mediapipe_aar(
calculators = calculators,
)
_mediapipe_proto(
name = name + "_proto",
)
native.genrule(
name = name + "_aar_manifest_generator",
outs = ["AndroidManifest.xml"],
@ -115,19 +111,10 @@ EOF
"//mediapipe/java/com/google/mediapipe/components:java_src",
"//mediapipe/java/com/google/mediapipe/framework:java_src",
"//mediapipe/java/com/google/mediapipe/glutil:java_src",
"com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
"com/google/mediapipe/formats/proto/ClassificationProto.java",
"com/google/mediapipe/formats/proto/DetectionProto.java",
"com/google/mediapipe/formats/proto/LandmarkProto.java",
"com/google/mediapipe/formats/proto/LocationDataProto.java",
"com/google/mediapipe/proto/CalculatorProto.java",
] +
] + mediapipe_java_proto_srcs() +
select({
"//conditions:default": [],
"enable_stats_logging": [
"com/google/mediapipe/proto/MediaPipeLoggingProto.java",
"com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
],
"enable_stats_logging": mediapipe_logging_java_proto_srcs(),
}),
manifest = "AndroidManifest.xml",
proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@ -177,93 +164,9 @@ EOF
assets_dir = assets_dir,
)
_aar_with_jni(name, name + "_android_lib")
def _mediapipe_proto(name):
"""Generates MediaPipe java proto libraries.
Args:
name: the name of the target.
"""
_proto_java_src_generator(
name = "mediapipe_log_extension_proto",
proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "mediapipe_logging_enums_proto",
proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
srcs = ["//mediapipe/util/analytics:protos_src"],
)
_proto_java_src_generator(
name = "calculator_proto",
proto_src = "mediapipe/framework/calculator.proto",
java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
srcs = ["//mediapipe/framework:protos_src"],
)
_proto_java_src_generator(
name = "landmark_proto",
proto_src = "mediapipe/framework/formats/landmark.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
srcs = ["//mediapipe/framework/formats:protos_src"],
)
_proto_java_src_generator(
name = "rasterization_proto",
proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
)
_proto_java_src_generator(
name = "location_data_proto",
proto_src = "mediapipe/framework/formats/location_data.proto",
java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "detection_proto",
proto_src = "mediapipe/framework/formats/detection.proto",
java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
"//mediapipe/framework/formats/annotation:protos_src",
],
)
_proto_java_src_generator(
name = "classification_proto",
proto_src = "mediapipe/framework/formats/classification.proto",
java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
srcs = [
"//mediapipe/framework/formats:protos_src",
],
)
def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
native.genrule(
name = name + "_proto_java_src_generator",
srcs = srcs + [
"@com_google_protobuf//:lite_well_known_protos",
],
outs = [java_lite_out],
cmd = "$(location @com_google_protobuf//:protoc) " +
"--proto_path=. --proto_path=$(GENDIR) " +
"--proto_path=$$(pwd)/external/com_google_protobuf/src " +
"--java_out=lite:$(GENDIR) " + proto_src + " && " +
"mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
tools = [
"@com_google_protobuf//:protoc",
],
mediapipe_build_aar_with_jni(
name = name,
android_library = name + "_android_lib",
)
def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
@ -303,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
alwayslink = 1,
)
def _aar_with_jni(name, android_library):
def mediapipe_build_aar_with_jni(name, android_library):
"""Builds MediaPipe AAR with jni.
Args:
name: The bazel target name.
android_library: the android library that contains jni.
"""
# Generates dummy AndroidManifest.xml for dummy apk usage
# (dummy apk is generated by <name>_dummy_app target below)
native.genrule(
@ -314,7 +224,7 @@ cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
package="dummy.package.for.so">
<uses-sdk android:minSdkVersion="21"/>
<uses-sdk android:minSdkVersion="24"/>
</manifest>
EOF
""",
@ -341,7 +251,133 @@ chmod +w $(location :{}.aar)
origdir=$$PWD
cd $$(mktemp -d)
unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
find lib -name *_dummy_app.so -delete
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
)
def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
"""Extracts the generated MediaPipe java proto source code from the target.
Args:
target: The java proto lite target to be built and extracted.
src_out: The output java proto src code path.
name: The optional bazel target name.
Returns:
The output java proto src code path.
"""
if not name:
name = target.split(":")[-1] + "_proto_java_src_extractor"
src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
native.genrule(
name = name + "_proto_java_src_extractor",
srcs = [target],
outs = [src_out],
cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
src_out + " $$(dirname $(location " + src_out + "))",
)
return src_out
def mediapipe_java_proto_srcs(name = ""):
"""Extracts the generated MediaPipe framework java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:calculator_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:stream_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StreamHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_factory_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketFactoryProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:packet_generator_java_proto_lite",
src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:status_handler_java_proto_lite",
src_out = "com/google/mediapipe/proto/StatusHandlerProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework:mediapipe_options_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:classification_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:detection_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:landmark_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:location_data_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
))
return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""):
"""Extracts the generated logging-related MediaPipe java proto source code.
Args:
name: The optional bazel target name.
Returns:
The list of the extracted MediaPipe logging-related java proto source code.
"""
proto_src_list = []
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
))
return proto_src_list

View File

@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
package_group(
name = "internal",
packages = [
"//mediapipe/model_maker/...",
],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
licenses(["notice"])
py_library(
name = "hyperparameters",
srcs = ["hyperparameters.py"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,61 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.
licenses(["notice"])
package(
default_visibility = ["//mediapipe:__subpackages__"],
)
py_library(
name = "data_util",
srcs = ["data_util.py"],
)
py_test(
name = "data_util_test",
srcs = ["data_util_test.py"],
data = ["//mediapipe/model_maker/python/core/data/testdata"],
deps = [":data_util"],
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
srcs_version = "PY3",
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
deps = [
":dataset",
"//mediapipe/model_maker/python/core/utils:test_util",
],
)
py_library(
name = "classification_dataset",
srcs = ["classification_dataset.py"],
deps = [":dataset"],
)
py_test(
name = "classification_dataset_test",
srcs = ["classification_dataset_test.py"],
deps = [":classification_dataset"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classification dataset library."""
from typing import Any, Tuple
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
"""DataLoader for classification models."""
def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
super().__init__(dataset, size)
self._index_by_label = index_by_label
@property
def num_classes(self: ds._DatasetT) -> int:
return len(self._index_by_label)
@property
def index_by_label(self: ds._DatasetT) -> Any:
return self._index_by_label
def split(self: ds._DatasetT,
fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the dataset into training and testing sets.
Args:
fraction: A float value that defines the fraction of the original data that
goes into the first returned sub-dataset.
Returns:
The two split sub-datasets.
"""
return self._split(fraction, self._index_by_label)
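
A minimal usage sketch of the class above; the toy four-element dataset and the label names are illustrative and not part of this change:

```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Illustrative two-class toy data: each element is a (feature, label) pair.
raw = tf.data.Dataset.from_tensor_slices(
    ([[0.0], [1.0], [0.5], [0.9]], [0, 1, 0, 1]))
data = classification_dataset.ClassificationDataset(
    dataset=raw, size=4, index_by_label=['negative', 'positive'])

train_data, test_data = data.split(0.5)  # two elements each
print(data.num_classes)                  # 2
```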

View File

@ -0,0 +1,82 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
_DatasetT = TypeVar(
'_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
class ClassificationDatasetTest(tf.test.TestCase):
def test_split(self):
class MagicClassificationDataset(
classification_dataset.ClassificationDataset):
"""A mock classification dataset class for testing purpose.
Attributes:
value: A value variable stored by the mock dataset class for testing.
"""
def __init__(self, dataset: tf.data.Dataset, size: int,
index_by_label: Any, value: Any):
super().__init__(
dataset=dataset, size=size, index_by_label=index_by_label)
self.value = value
def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
return self._split(fraction, self.index_by_label, self.value)
# Some dummy inputs.
magic_value = 42
num_classes = 2
index_by_label = (False, True)
# Create data loader from sample data.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = MagicClassificationDataset(
dataset=ds,
size=len(ds),
index_by_label=index_by_label,
value=magic_value)
# Train/Test data split.
fraction = .25
train_data, test_data = data.split(fraction=fraction)
# `split` should return instances of child DataLoader.
self.assertIsInstance(train_data, MagicClassificationDataset)
self.assertIsInstance(test_data, MagicClassificationDataset)
# Make sure number of entries are right.
self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
self.assertLen(train_data, fraction * len(ds))
self.assertLen(test_data, len(ds) - len(train_data))
# Make sure attributes propagated correctly.
self.assertEqual(train_data.num_classes, num_classes)
self.assertEqual(test_data.index_by_label, index_by_label)
self.assertEqual(train_data.value, magic_value)
self.assertEqual(test_data.value, magic_value)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utility library."""
import cv2
import numpy as np
import tensorflow as tf
def load_image(path: str) -> np.ndarray:
"""Loads an image as an RGB numpy array.
Args:
path: input image file absolute path.
Returns:
An RGB image in numpy.ndarray.
"""
tf.compat.v1.logging.info('Loading RGB image %s', path)
# TODO: Replace the OpenCV image loading and conversion with the MediaPipe
# image utility library once it is ready.
image = cv2.imread(path)
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
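
A short usage sketch of `load_image`; the file path below is hypothetical:

```
from mediapipe.model_maker.python.core.data import data_util

# Hypothetical path; load_image expects a path to an existing image file.
rgb = data_util.load_image('/tmp/example.jpg')
print(rgb.shape)  # (height, width, 3), channels in RGB order
```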

View File

@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Dependency imports
from absl import flags
import tensorflow as tf
from mediapipe.model_maker.python.core.data import data_util
_WORKSPACE = "mediapipe"
_TEST_DATA_DIR = os.path.join(
_WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')
FLAGS = flags.FLAGS
class DataUtilTest(tf.test.TestCase):
def test_load_rgb_image(self):
image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
image_data = data_util.load_image(image_path)
self.assertEqual(image_data.shape, (5184, 3456, 3))
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common dataset for model training and evaluation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from typing import Any, Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
_DatasetT = TypeVar('_DatasetT', bound='Dataset')
class Dataset(object):
"""A generic dataset class for loading model training and evaluation dataset.
For each ML task, such as image classification, text classification etc., a
subclass can be derived from this class to provide task-specific data loading
utilities.
"""
def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
"""Initializes Dataset class.
To build a dataset from raw data, consider using the task-specific utilities,
e.g. from_folder().
Args:
tf_dataset: A tf.data.Dataset object that contains a potentially large set
of elements, where each element is a pair of (input_data, target). The
`input_data` means the raw input data, like an image, a text etc., while
the `target` means the ground truth of the raw input data, e.g. the
classification label of the image etc.
size: The size of the dataset. tf.data.Dataset doesn't support a function
to get the length directly since it's lazy-loaded and may be infinite.
"""
self._dataset = tf_dataset
self._size = size
@property
def size(self) -> Optional[int]:
"""Returns the size of the dataset.
Note that this function may return None because the exact size of the
dataset isn't a necessary parameter to create an instance of this class,
and tf.data.Dataset doesn't support a function to get the length directly
since it's lazy-loaded and may be infinite.
In most cases, however, when an instance of this class is created by helper
functions like 'from_folder', the size of the dataset is computed during
preprocessing, and this function can return an int representing the size of
the dataset.
"""
return self._size
def gen_tf_dataset(self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., Any]] = None,
drop_remainder: bool = False) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation.
Args:
batch_size: An integer, the returned dataset will be batched by this size.
is_training: A boolean, when True, the returned dataset will be optionally
shuffled and repeated as an endless dataset.
shuffle: A boolean, when True, the returned dataset will be shuffled to
create randomness during model training.
preprocess: A function taking three arguments in order, feature, label and
boolean is_training.
drop_remainder: A boolean, whether to drop the final batch when it has
fewer than batch_size elements.
Returns:
A TF dataset ready to be consumed by a Keras model.
"""
dataset = self._dataset
if preprocess:
preprocess = functools.partial(preprocess, is_training=is_training)
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
if is_training:
if shuffle:
# The shuffle buffer size should be bigger than the batch_size. Otherwise
# it only shuffles within the batch, which is equivalent to not shuffling
# at all.
buffer_size = 3 * batch_size
# But since we are doing shuffle before repeat, it doesn't make sense to
# shuffle more than total available entries.
# TODO: Investigate whether shuffling before or after repeat gives better
# performance. Shuffling after repeat gives a more randomized dataset and
# mixes elements across the epoch boundary:
# https://www.tensorflow.org/guide/data
if self._size:
buffer_size = min(self._size, buffer_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
# TODO: Consider converting dataset to distributed dataset
# here.
return dataset
def __len__(self):
"""Returns the number of element of the dataset."""
if self._size is not None:
return self._size
else:
return len(self._dataset)
def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
"""Splits dataset into two sub-datasets with the given fraction.
Primarily used for splitting the dataset into training and testing sets.
Args:
fraction: A float value that defines the fraction of the original data that
goes into the first returned sub-dataset.
Returns:
The two split sub-datasets.
"""
return self._split(fraction)
def _split(self: _DatasetT, fraction: float,
*args) -> Tuple[_DatasetT, _DatasetT]:
"""Implementation for `split` method and returns sub-class instances.
Child DataLoader classes that require additional constructor arguments
should implement their own `split` method by calling `_split` with all the
arguments expected by their constructor.
Args:
fraction: A float value that defines the fraction of the original data that
goes into the first returned sub-dataset.
*args: Additional arguments passed to the sub-class constructor.
Returns:
The two split sub-datasets.
"""
assert 0 < fraction < 1
dataset = self._dataset
train_size = int(self._size * fraction)
trainset = self.__class__(dataset.take(train_size), train_size, *args)
test_size = self._size - train_size
testset = self.__class__(dataset.skip(train_size), test_size, *args)
return trainset, testset
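
A hedged usage sketch of the `Dataset` class above; the toy data and the `double_features` preprocess function are illustrative and not part of this change:

```
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds

# Illustrative toy data: four (feature, label) pairs.
raw = tf.data.Dataset.from_tensor_slices(
    ([[1.0], [2.0], [3.0], [4.0]], [0, 1, 0, 1]))
data = ds.Dataset(raw, size=4)

# Illustrative preprocess fn: receives (feature, label) plus is_training.
def double_features(feature, label, is_training):
  return feature * 2.0, label

train_data, test_data = data.split(0.5)
batched = train_data.gen_tf_dataset(
    batch_size=2, is_training=True, shuffle=True, preprocess=double_features)
```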

View File

@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
import tensorflow as tf
from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import test_util
class DatasetTest(tf.test.TestCase):
def test_split(self):
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, 4)
train_data, test_data = data.split(0.5)
self.assertLen(train_data, 2)
self.assertIsInstance(train_data, ds.Dataset)
self.assertIsInstance(test_data, ds.Dataset)
for i, elem in enumerate(train_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 1])).all())
self.assertLen(test_data, 2)
for i, elem in enumerate(test_data.gen_tf_dataset()):
self.assertTrue((elem.numpy() == np.array([i, 0])).all())
def test_len(self):
size = 4
dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
[1, 0]])
data = ds.Dataset(dataset, size)
self.assertLen(data, size)
def test_gen_tf_dataset(self):
input_dim = 8
data = test_util.create_dataset(
data_size=2, input_shape=[input_dim], num_classes=2)
dataset = data.gen_tf_dataset()
self.assertLen(dataset, 2)
for (feature, label) in dataset:
self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())
dataset2 = data.gen_tf_dataset(batch_size=2)
self.assertLen(dataset2, 1)
for (feature, label) in dataset2:
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
self.assertEqual(dataset3.cardinality(), 1)
for (feature, label) in dataset3.take(10):
self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//mediapipe/framework/tool:mediapipe_files.bzl",
"mediapipe_files",
)
package(
default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
licenses = ["notice"], # Apache 2.0
)
mediapipe_files(srcs = ["test.jpg"])
filegroup(
name = "testdata",
srcs = ["test.jpg"],
)

View File

@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training models. Shared across tasks."""
import dataclasses
import tempfile
from typing import Optional
# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass
class BaseHParams:
"""Hyperparameters used for training models.
A common set of hyperparameters shared by the training jobs of all model
maker tasks.
Attributes:
learning_rate: The learning rate to use for gradient descent training.
batch_size: Batch size for training.
epochs: Number of training iterations over the dataset.
steps_per_epoch: An optional integer indicating the number of training steps
per epoch. If not set, the training pipeline calculates the default steps
per epoch as the training dataset size divided by batch size.
shuffle: True if the dataset is shuffled before training.
export_dir: The location of the model checkpoint files.
distribution_strategy: A string specifying which Distribution Strategy to
use. Accepted values are 'off', 'one_device', 'mirrored',
'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
insensitive. 'off' means not to use Distribution Strategy; 'tpu' means to
use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy
documentation for more details:
https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
num_gpus: How many GPUs to use at each worker with the
DistributionStrategies API. The default is -1, which means utilize all
available GPUs.
tpu: The Cloud TPU to use for training. This should be either the name used
when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.
"""
# Parameters for train configuration
learning_rate: float
batch_size: int
epochs: int
steps_per_epoch: Optional[int] = None
# Dataset-related parameters
shuffle: bool = False
# Parameters for model / checkpoint files
export_dir: str = tempfile.mkdtemp()
# Parameters for hardware acceleration
distribution_strategy: str = 'off'
num_gpus: int = -1 # default value of -1 means use all available GPUs
tpu: str = ''
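
A hedged usage sketch of `BaseHParams`; the field values and the import path (assumed to be `mediapipe.model_maker.python.core.hyperparameters`, which is not shown in this diff) are illustrative:

```
# Assumed module path for illustration only.
from mediapipe.model_maker.python.core import hyperparameters as hp

# Illustrative values; each task can layer its own defaults on top of these.
hparams = hp.BaseHParams(
    learning_rate=0.001,
    batch_size=32,
    epochs=10,
    shuffle=True,
    export_dir='/tmp/model_maker_export',  # hypothetical output directory
    distribution_strategy='off')
```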

Some files were not shown because too many files have changed in this diff.