Merge 6b0a7fb281 into 1d6750d240
This commit is contained in: commit 0f4379cd64
@@ -53,7 +53,7 @@ RUN pip3 install wheel
RUN pip3 install future
RUN pip3 install absl-py numpy opencv-contrib-python protobuf==3.20.1
RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==2.2.0
RUN pip3 install tensorflow
RUN pip3 install tf_slim

RUN ln -s /usr/bin/python3 /usr/bin/python
@@ -143,6 +143,98 @@ Below is an example of how to create a subgraph named `TwoPassThroughSubgraph`.
}
```

## Graph Options

It is possible to specify a "graph options" protobuf for a MediaPipe graph
similar to the [`Calculator Options`](calculators.md#calculator-options)
protobuf specified for a MediaPipe calculator. These "graph options" can be
specified where a graph is invoked, and used to populate calculator options and
subgraph options within the graph.

In a CalculatorGraphConfig, graph options can be specified for a subgraph
exactly like calculator options, as shown below:

```
node {
  calculator: "FlowLimiterCalculator"
  input_stream: "image"
  output_stream: "throttled_image"
  node_options: {
    [type.googleapis.com/mediapipe.FlowLimiterCalculatorOptions] {
      max_in_flight: 1
    }
  }
}

node {
  calculator: "FaceDetectionSubgraph"
  input_stream: "IMAGE:throttled_image"
  node_options: {
    [type.googleapis.com/mediapipe.FaceDetectionOptions] {
      tensor_width: 192
      tensor_height: 192
    }
  }
}
```

In a CalculatorGraphConfig, graph options can be accepted and used to populate
calculator options, as shown below:

```
graph_options: {
  [type.googleapis.com/mediapipe.FaceDetectionOptions] {}
}

node: {
  calculator: "ImageToTensorCalculator"
  input_stream: "IMAGE:multi_backend_image"
  node_options: {
    [type.googleapis.com/mediapipe.ImageToTensorCalculatorOptions] {
      keep_aspect_ratio: true
      border_mode: BORDER_ZERO
    }
  }
  option_value: "output_tensor_width:options/tensor_width"
  option_value: "output_tensor_height:options/tensor_height"
}

node {
  calculator: "InferenceCalculator"
  node_options: {
    [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {}
  }
  option_value: "delegate:options/delegate"
  option_value: "model_path:options/model_path"
}
```

In this example, the `FaceDetectionSubgraph` accepts the graph option protobuf
`FaceDetectionOptions`. The `FaceDetectionOptions` is used to define some field
values in the calculator options `ImageToTensorCalculatorOptions` and some field
values in the subgraph options `InferenceCalculatorOptions`. The field values
are defined using the `option_value:` syntax.

In the `CalculatorGraphConfig::Node` protobuf, the fields `node_options:` and
`option_value:` together define the option values for a calculator such as
`ImageToTensorCalculator`. The `node_options:` field defines a set of literal
constant values using the text protobuf syntax. Each `option_value:` field
defines the value for one protobuf field using information from the enclosing
graph, specifically from field values of the graph options of the enclosing
graph. In the example above, the `option_value:`
`"output_tensor_width:options/tensor_width"` defines the field
`ImageToTensorCalculatorOptions.output_tensor_width` using the value of
`FaceDetectionOptions.tensor_width`.

The syntax of `option_value:` is similar to the syntax of `input_stream:`. The
syntax is `option_value: "LHS:RHS"`. The LHS identifies a calculator option
field and the RHS identifies a graph option field. More specifically, the LHS
and RHS each consist of a series of protobuf field names identifying nested
protobuf messages and fields separated by '/'. This is known as the "ProtoPath"
syntax. Nested messages that are referenced in the LHS or RHS must already be
defined in the enclosing protobuf in order to be traversed using
`option_value:`.
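
Both the LHS and the RHS may descend into nested messages. Below is a hedged
sketch of that case; the `delegate` sub-message and the `model_token` fields
are hypothetical names used only to illustrate the nested ProtoPath syntax,
not fields defined by the examples above:

```
node {
  calculator: "InferenceCalculator"
  node_options: {
    [type.googleapis.com/mediapipe.InferenceCalculatorOptions] {
      # The nested message must be defined here as a literal before an
      # option_value ProtoPath can traverse into it.
      delegate: {}
    }
  }
  # Hypothetical fields: copies the enclosing graph's
  # "options/model_token" into the nested "delegate/model_token" field.
  option_value: "delegate/model_token:options/model_token"
}
```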

## Cycles

<!-- TODO: add discussion of PreviousLoopbackCalculator -->
@@ -54,7 +54,7 @@ Note: This currently works only on Linux, and please first follow

```bash
GLOG_logtostderr=1 bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
  --calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live_gpu.pbtxt
```

This will open up your webcam as long as it is connected and on. Any errors
@@ -1410,3 +1410,45 @@ cc_library(
    ],
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "bypass_calculator_proto",
    srcs = ["bypass_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "bypass_calculator",
    srcs = ["bypass_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":bypass_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

cc_test(
    name = "bypass_calculator_test",
    srcs = ["bypass_calculator_test.cc"],
    deps = [
        ":bypass_calculator",
        ":pass_through_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:switch_container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
mediapipe/calculators/core/bypass_calculator.cc (new file, 161 lines)
@@ -0,0 +1,161 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "mediapipe/calculators/core/bypass_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {
namespace api2 {

using mediapipe::BypassCalculatorOptions;

// Defines a "bypass" channel to use in place of a disabled feature subgraph.
// By default, all inputs are discarded and all outputs are ignored.
// Certain input streams can be passed to corresponding output streams
// by specifying them in "pass_input_stream" and "pass_output_stream" options.
// All output streams are updated with timestamp bounds indicating completed
// output.
//
// Note that this calculator is designed for use as a contained_node in a
// SwitchContainer. For this reason, any input and output tags are accepted,
// and stream semantics are specified through BypassCalculatorOptions.
//
// Example config:
// node {
//   calculator: "BypassCalculator"
//   input_stream: "APPEARANCES:appearances_post_facenet"
//   input_stream: "VIDEO:video_frame"
//   input_stream: "FEATURE_CONFIG:feature_config"
//   input_stream: "ENABLE:gaze_enabled"
//   output_stream: "APPEARANCES:analyzed_appearances"
//   output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
//   node_options: {
//     [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
//       pass_input_stream: "APPEARANCES"
//       pass_output_stream: "APPEARANCES"
//     }
//   }
// }
//
class BypassCalculator : public Node {
 public:
  static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
  MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
  using IdMap = std::map<CollectionItemId, CollectionItemId>;

  // Returns the map of passthrough input and output stream ids.
  static absl::StatusOr<IdMap> GetPassMap(
      const BypassCalculatorOptions& options, const tool::TagMap& input_map,
      const tool::TagMap& output_map) {
    IdMap result;
    auto& input_streams = options.pass_input_stream();
    auto& output_streams = options.pass_output_stream();
    int size = std::min(input_streams.size(), output_streams.size());
    for (int i = 0; i < size; ++i) {
      std::pair<std::string, int> in_tag, out_tag;
      MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
                                             &in_tag.first, &in_tag.second));
      MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
                                             &out_tag.first, &out_tag.second));
      auto input_id = input_map.GetId(in_tag.first, in_tag.second);
      auto output_id = output_map.GetId(out_tag.first, out_tag.second);
      result[input_id] = output_id;
    }
    return result;
  }

  // Identifies all specified streams as "Any" packet type.
  // Identifies passthrough streams as "Same" packet type.
  static absl::Status UpdateContract(CalculatorContract* cc) {
    auto options = cc->Options<BypassCalculatorOptions>();
    RET_CHECK_EQ(options.pass_input_stream().size(),
                 options.pass_output_stream().size());
    ASSIGN_OR_RETURN(
        auto pass_streams,
        GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
    std::set<CollectionItemId> pass_out;
    for (auto entry : pass_streams) {
      pass_out.insert(entry.second);
      cc->Inputs().Get(entry.first).SetAny();
      cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
    }
    for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
      if (pass_streams.count(id) == 0) {
        cc->Inputs().Get(id).SetAny();
      }
    }
    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
      if (pass_out.count(id) == 0) {
        cc->Outputs().Get(id).SetAny();
      }
    }
    return absl::OkStatus();
  }

  // Saves the map of passthrough input and output stream ids.
  absl::Status Open(CalculatorContext* cc) override {
    auto options = cc->Options<BypassCalculatorOptions>();
    ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
                                               *cc->Outputs().TagMap()));
    return absl::OkStatus();
  }

  // Copies packets between passthrough input and output streams.
  // Updates timestamp bounds on all output streams.
  absl::Status Process(CalculatorContext* cc) override {
    std::set<CollectionItemId> pass_out;
    for (auto entry : pass_streams_) {
      pass_out.insert(entry.second);
      auto& packet = cc->Inputs().Get(entry.first).Value();
      if (packet.Timestamp() == cc->InputTimestamp()) {
        cc->Outputs().Get(entry.first).AddPacket(packet);
      }
    }
    Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
      if (pass_out.count(id) == 0) {
        cc->Outputs().Get(id).SetNextTimestampBound(
            std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
      }
    }
    return absl::OkStatus();
  }

  // Close all output streams.
  absl::Status Close(CalculatorContext* cc) override {
    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
      cc->Outputs().Get(id).Close();
    }
    return absl::OkStatus();
  }

 private:
  IdMap pass_streams_;
};

MEDIAPIPE_REGISTER_NODE(BypassCalculator);

}  // namespace api2
}  // namespace mediapipe
mediapipe/calculators/core/bypass_calculator.proto (new file, 31 lines)
@@ -0,0 +1,31 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message BypassCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional BypassCalculatorOptions ext = 481259677;
  }

  // Names an input stream or streams to pass through, by "TAG:index".
  repeated string pass_input_stream = 1;

  // Names an output stream or streams to pass through, by "TAG:index".
  repeated string pass_output_stream = 2;
}
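
A minimal usage sketch of these options (stream and tag names are ours, chosen
for illustration; the semantics follow the calculator comment above and the
test graphs later in this commit). Streams named by `pass_input_stream` /
`pass_output_stream` are copied through; all other outputs only receive
timestamp bounds:

```
node {
  calculator: "BypassCalculator"
  input_stream: "PASS:frames"
  input_stream: "IGNORE:detections"  # discarded; only a timestamp bound is emitted
  output_stream: "PASS:passthrough_frames"
  output_stream: "IGNORE:ignored_detections"
  node_options: {
    [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
      pass_input_stream: "PASS"
      pass_output_stream: "PASS"
    }
  }
}
```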
mediapipe/calculators/core/bypass_calculator_test.cc (new file, 302 lines)
@@ -0,0 +1,302 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {

namespace {

// A graph using a BypassCalculator to pass through and ignore
// most of its inputs and outputs.
constexpr char kTestGraphConfig1[] = R"pb(
  type: "AppearancesPassThroughSubgraph"
  input_stream: "APPEARANCES:appearances"
  input_stream: "VIDEO:video_frame"
  input_stream: "FEATURE_CONFIG:feature_config"
  output_stream: "APPEARANCES:passthrough_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output"

  node {
    calculator: "BypassCalculator"
    input_stream: "PASS:appearances"
    input_stream: "TRUNCATE:0:video_frame"
    input_stream: "TRUNCATE:1:feature_config"
    output_stream: "PASS:passthrough_appearances"
    output_stream: "TRUNCATE:passthrough_federated_gaze_output"
    node_options: {
      [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
        pass_input_stream: "PASS"
        pass_output_stream: "PASS"
      }
    }
  }
)pb";

// A graph using AppearancesPassThroughSubgraph as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig2[] = R"pb(
  input_stream: "VIDEO_FULL_RES:video_frame"
  input_stream: "APPEARANCES:input_appearances"
  input_stream: "FEATURE_CONFIG:feature_config"
  input_stream: "GAZE_ENABLED:gaze_enabled"
  output_stream: "APPEARANCES:analyzed_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"

  node {
    calculator: "SwitchContainer"
    input_stream: "VIDEO:video_frame"
    input_stream: "APPEARANCES:input_appearances"
    input_stream: "FEATURE_CONFIG:feature_config"
    input_stream: "ENABLE:gaze_enabled"
    output_stream: "APPEARANCES:analyzed_appearances"
    output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: { calculator: "AppearancesPassThroughSubgraph" }
      }
    }
  }
)pb";

// A graph using BypassCalculator as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig3[] = R"pb(
  input_stream: "VIDEO_FULL_RES:video_frame"
  input_stream: "APPEARANCES:input_appearances"
  input_stream: "FEATURE_CONFIG:feature_config"
  input_stream: "GAZE_ENABLED:gaze_enabled"
  output_stream: "APPEARANCES:analyzed_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"

  node {
    calculator: "SwitchContainer"
    input_stream: "VIDEO:video_frame"
    input_stream: "APPEARANCES:input_appearances"
    input_stream: "FEATURE_CONFIG:feature_config"
    input_stream: "ENABLE:gaze_enabled"
    output_stream: "APPEARANCES:analyzed_appearances"
    output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: {
          calculator: "BypassCalculator"
          node_options: {
            [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
              pass_input_stream: "APPEARANCES"
              pass_output_stream: "APPEARANCES"
            }
          }
        }
      }
    }
  }
)pb";

// A graph using BypassCalculator as a disabled gate
// for input frames and appearances.
constexpr char kTestGraphConfig4[] = R"pb(
  input_stream: "VIDEO_FULL_RES:video_frame"
  input_stream: "APPEARANCES:input_appearances"
  input_stream: "FEATURE_CONFIG:feature_config"
  input_stream: "GAZE_ENABLED:gaze_enabled"
  output_stream: "APPEARANCES:analyzed_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"

  node {
    calculator: "SwitchContainer"
    input_stream: "ENABLE:gaze_enabled"
    input_stream: "VIDEO:video_frame"
    input_stream: "APPEARANCES:input_appearances"
    input_stream: "FEATURE_CONFIG:feature_config"
    output_stream: "VIDEO:video_frame_out"
    output_stream: "APPEARANCES:analyzed_appearances"
    output_stream: "FEATURE_CONFIG:feature_config_out"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: { calculator: "BypassCalculator" }
        contained_node: { calculator: "PassThroughCalculator" }
      }
    }
  }
)pb";

// Reports packet timestamp and string contents, or "<empty>".
std::string DebugString(Packet p) {
  return absl::StrCat(p.Timestamp().DebugString(), ":",
                      p.IsEmpty() ? "<empty>" : p.Get<std::string>());
}

// Shows a bypass subgraph that passes through one stream.
TEST(BypassCalculatorTest, SubgraphChannel) {
  CalculatorGraphConfig config_1 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1);
  CalculatorGraphConfig config_2 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {}));

  std::vector<std::string> analyzed_appearances;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "analyzed_appearances",
      [&](const Packet& p) {
        analyzed_appearances.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  std::vector<std::string> federated_gaze_output;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "federated_gaze_output",
      [&](const Packet& p) {
        federated_gaze_output.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  MP_ASSERT_OK(graph.StartRun({}));

  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
  EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

// Shows a BypassCalculator that passes through one stream.
TEST(BypassCalculatorTest, CalculatorChannel) {
  CalculatorGraphConfig config_3 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize({config_3}, {}));

  std::vector<std::string> analyzed_appearances;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "analyzed_appearances",
      [&](const Packet& p) {
        analyzed_appearances.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  std::vector<std::string> federated_gaze_output;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "federated_gaze_output",
      [&](const Packet& p) {
        federated_gaze_output.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  MP_ASSERT_OK(graph.StartRun({}));

  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
  EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

// Shows a BypassCalculator that discards all inputs when ENABLED is false.
TEST(BypassCalculatorTest, GatedChannel) {
  CalculatorGraphConfig config_3 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize({config_3}, {}));

  std::vector<std::string> analyzed_appearances;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "analyzed_appearances",
      [&](const Packet& p) {
        analyzed_appearances.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  std::vector<std::string> video_frame;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "video_frame_out",
      [&](const Packet& p) {
        video_frame.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  MP_ASSERT_OK(graph.StartRun({}));

  // Close the gate.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "gaze_enabled", MakePacket<bool>(false).At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Send packets at timestamp 200.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Only timestamps arrive from the BypassCalculator.
  EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:<empty>"));
  EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>"));

  // Open the gate.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "gaze_enabled", MakePacket<bool>(true).At(Timestamp(300))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Send packets at timestamp 300.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a2").At(Timestamp(300))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v2").At(Timestamp(300))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f2").At(Timestamp(300))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Packets arrive from the PassThroughCalculator.
  EXPECT_THAT(analyzed_appearances,
              testing::ElementsAre("200:<empty>", "300:a2"));
  EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>", "300:v2"));

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

}  // namespace

}  // namespace mediapipe
@@ -209,11 +209,18 @@ cc_library(
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "rotation_mode_proto",
    srcs = ["rotation_mode.proto"],
    visibility = ["//visibility:public"],
)

mediapipe_proto_library(
    name = "image_transformation_calculator_proto",
    srcs = ["image_transformation_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        ":rotation_mode_proto",
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/gpu:scale_mode_proto",

@@ -238,6 +245,7 @@ cc_library(
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":rotation_mode_cc_proto",
        ":image_transformation_calculator_cc_proto",
        "//mediapipe/framework:packet",
        "//mediapipe/framework:timestamp",
@@ -13,6 +13,7 @@
// limitations under the License.

#include "mediapipe/calculators/image/image_transformation_calculator.pb.h"
#include "mediapipe/calculators/image/rotation_mode.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
@@ -16,20 +16,10 @@ syntax = "proto2";

package mediapipe;

import "mediapipe/calculators/image/rotation_mode.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/gpu/scale_mode.proto";

// Counterclockwise rotation.
message RotationMode {
  enum Mode {
    UNKNOWN = 0;
    ROTATION_0 = 1;
    ROTATION_90 = 2;
    ROTATION_180 = 3;
    ROTATION_270 = 4;
  }
}

message ImageTransformationCalculatorOptions {
  extend CalculatorOptions {
    optional ImageTransformationCalculatorOptions ext = 251952830;
mediapipe/calculators/image/rotation_mode.proto (new file, 31 lines)
@@ -0,0 +1,31 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "RotationModeProto";

// Counterclockwise rotation.
message RotationMode {
  enum Mode {
    UNKNOWN = 0;
    ROTATION_0 = 1;
    ROTATION_90 = 2;
    ROTATION_180 = 3;
    ROTATION_270 = 4;
  }
}
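
The `rotation_mode.proto` import added to `image_transformation_calculator.proto`
above suggests these enum values are consumed through
`ImageTransformationCalculatorOptions`. A minimal sketch of that usage,
assuming the options proto exposes a `rotation_mode` field (an assumption;
this diff shows the import but not the field itself):

```
node {
  calculator: "ImageTransformationCalculator"
  input_stream: "IMAGE:input_image"
  output_stream: "IMAGE:rotated_image"
  node_options: {
    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
      rotation_mode: ROTATION_90  # one counterclockwise quarter-turn
    }
  }
}
```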

@@ -161,6 +161,193 @@ cc_test(
    ],
)

mediapipe_proto_library(
    name = "bert_preprocessor_calculator_proto",
    srcs = ["bert_preprocessor_calculator.proto"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "bert_preprocessor_calculator",
    srcs = ["bert_preprocessor_calculator.cc"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        ":bert_preprocessor_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer",
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "bert_preprocessor_calculator_test",
    srcs = ["bert_preprocessor_calculator_test.cc"],
    data = ["//mediapipe/tasks/testdata/text:bert_text_classifier_models"],
    linkopts = ["-ldl"],
    deps = [
        ":bert_preprocessor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

mediapipe_proto_library(
    name = "regex_preprocessor_calculator_proto",
    srcs = ["regex_preprocessor_calculator.proto"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "regex_preprocessor_calculator",
    srcs = ["regex_preprocessor_calculator.cc"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        ":regex_preprocessor_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "//mediapipe/tasks/cc/text/tokenizers:regex_tokenizer",
        "//mediapipe/tasks/cc/text/tokenizers:tokenizer_utils",
        "//mediapipe/tasks/metadata:metadata_schema_cc",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_test(
    name = "regex_preprocessor_calculator_test",
    srcs = ["regex_preprocessor_calculator_test.cc"],
    data = ["//mediapipe/tasks/testdata/text:text_classifier_models"],
    linkopts = ["-ldl"],
    deps = [
        ":regex_preprocessor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:sink",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "text_to_tensor_calculator",
    srcs = ["text_to_tensor_calculator.cc"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "text_to_tensor_calculator_test",
    srcs = ["text_to_tensor_calculator_test.cc"],
    deps = [
        ":text_to_tensor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:options_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "universal_sentence_encoder_preprocessor_calculator",
    srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"],
    deps = [
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "universal_sentence_encoder_preprocessor_calculator_test",
    srcs = ["universal_sentence_encoder_preprocessor_calculator_test.cc"],
    data = ["//mediapipe/tasks/testdata/text:universal_sentence_encoder_qa"],
    deps = [
        ":universal_sentence_encoder_preprocessor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:options_map",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

mediapipe_proto_library(
    name = "inference_calculator_proto",
    srcs = ["inference_calculator.proto"],

@@ -320,6 +507,8 @@ cc_library(
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@org_tensorflow//tensorflow/lite:framework_stable",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/c:c_api_types",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)
mediapipe/calculators/tensor/bert_preprocessor_calculator.cc (new file, 251 lines)
@@ -0,0 +1,251 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/ascii.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/calculators/tensor/bert_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;

constexpr int kNumInputTensorsForBert = 3;
constexpr int kTokenizerProcessUnitIndex = 0;
constexpr absl::string_view kInputIdsTensorName = "ids";
constexpr absl::string_view kInputMasksTensorName = "mask";
constexpr absl::string_view kSegmentIdsTensorName = "segment_ids";
constexpr absl::string_view kClassifierToken = "[CLS]";
constexpr absl::string_view kSeparatorToken = "[SEP]";

// Preprocesses input text into three int32 input tensors for a BERT model
// using a tokenizer.
// The associated BERT model is expected to contain input tensors with names:
//
// Tensor           | Metadata Name
// ---------------- | --------------
// IDs              | "ids"
// Segment IDs      | "segment_ids"
// Mask             | "mask"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have names corresponding to the above
// metadata names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// This calculator is currently configured for the TextClassifier Task but it
// will eventually be generalized for other Text Tasks.
// TODO: Handle preprocessing for other Text Tasks too.
//
// Inputs:
//   TEXT - std::string
//     The input text.
// Side Inputs:
//   METADATA_EXTRACTOR - ModelMetadataExtractor
//     The metadata extractor for the BERT model. Used to determine the order
//     of the three input Tensors for the BERT model and to extract the
//     metadata to construct the tokenizer.
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing the three input Tensors for the BERT model:
//       (1): the token ids of the tokenized input string. A classifier token
//            ("[CLS]") will be prepended to the input tokens and a separator
//            token ("[SEP]") will be appended to the input tokens.
//       (2): the segment ids, which are all 0 for now but will have different
//            values to distinguish between different sentences in the input
//            text for other Text tasks.
//       (3): the input mask ids, which are 1 at each of the input token
//            indices and 0 elsewhere.
//     The Tensors will have size equal to the max sequence length for the
//     BERT model.
//
// Example:
// node {
//   calculator: "BertPreprocessorCalculator"
//   input_stream: "TEXT:text"
//   input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
//   output_stream: "TENSORS:tensors"
//   options {
//     [mediapipe.BertPreprocessorCalculatorOptions.ext] {
//       bert_max_seq_len: 128
//     }
//   }
// }
class BertPreprocessorCalculator : public Node {
 public:
  static constexpr Input<std::string> kTextIn{"TEXT"};
  static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
      "METADATA_EXTRACTOR"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};

  MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);

  static absl::Status UpdateContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  std::unique_ptr<tasks::text::tokenizers::Tokenizer> tokenizer_;
  // The max sequence length accepted by the BERT model.
  int bert_max_seq_len_ = 2;
  // Indices of the three input tensors for the BERT model. They should form
  // the set {0, 1, 2}.
  int input_ids_tensor_index_ = 0;
  int segment_ids_tensor_index_ = 1;
  int input_masks_tensor_index_ = 2;

  // Applies `tokenizer_` to the `input_text` to generate a vector of tokens.
  // This util prepends "[CLS]" and appends "[SEP]" to the input tokens and
  // clips the vector of tokens to have length at most `bert_max_seq_len_`.
  std::vector<std::string> TokenizeInputText(absl::string_view input_text);
  // Processes the `input_tokens` to generate the three input tensors for the
  // BERT model.
  std::vector<Tensor> GenerateInputTensors(
      const std::vector<std::string>& input_tokens);
};

absl::Status BertPreprocessorCalculator::UpdateContract(
    CalculatorContract* cc) {
  const auto& options =
      cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
  RET_CHECK(options.has_bert_max_seq_len()) << "bert_max_seq_len is required";
  RET_CHECK_GE(options.bert_max_seq_len(), 2)
      << "bert_max_seq_len must be at least 2";
  return absl::OkStatus();
}

absl::Status BertPreprocessorCalculator::Open(CalculatorContext* cc) {
  const ModelMetadataExtractor* metadata_extractor =
      &kMetadataExtractorSideIn(cc).Get();
  const tflite::ProcessUnit* tokenizer_metadata =
      metadata_extractor->GetInputProcessUnit(kTokenizerProcessUnitIndex);
  ASSIGN_OR_RETURN(tokenizer_,
                   tasks::text::tokenizers::CreateTokenizerFromProcessUnit(
                       tokenizer_metadata, metadata_extractor));

  auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
  input_ids_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kInputIdsTensorName);
  segment_ids_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kSegmentIdsTensorName);
  input_masks_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kInputMasksTensorName);
  absl::flat_hash_set<int> tensor_indices = {input_ids_tensor_index_,
                                             segment_ids_tensor_index_,
                                             input_masks_tensor_index_};
  if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
        input_ids_tensor_index_, segment_ids_tensor_index_,
        input_masks_tensor_index_));
  }

  const auto& options =
      cc->Options<mediapipe::BertPreprocessorCalculatorOptions>();
  bert_max_seq_len_ = options.bert_max_seq_len();
  return absl::OkStatus();
}

absl::Status BertPreprocessorCalculator::Process(CalculatorContext* cc) {
  kTensorsOut(cc).Send(
      GenerateInputTensors(TokenizeInputText(kTextIn(cc).Get())));
  return absl::OkStatus();
}

std::vector<std::string> BertPreprocessorCalculator::TokenizeInputText(
    absl::string_view input_text) {
  std::string processed_input = std::string(input_text);
  absl::AsciiStrToLower(&processed_input);

  tasks::text::tokenizers::TokenizerResult tokenizer_result =
      tokenizer_->Tokenize(processed_input);

  // Offset by 2 to account for [CLS] and [SEP].
  int input_tokens_size =
      std::min(bert_max_seq_len_,
               static_cast<int>(tokenizer_result.subwords.size()) + 2);
  std::vector<std::string> input_tokens;
  input_tokens.reserve(input_tokens_size);
  input_tokens.push_back(std::string(kClassifierToken));
  for (int i = 0; i < input_tokens_size - 2; ++i) {
    input_tokens.push_back(std::move(tokenizer_result.subwords[i]));
  }
  input_tokens.push_back(std::string(kSeparatorToken));
  return input_tokens;
}

std::vector<Tensor> BertPreprocessorCalculator::GenerateInputTensors(
    const std::vector<std::string>& input_tokens) {
  std::vector<int32_t> input_ids(bert_max_seq_len_, 0);
  std::vector<int32_t> segment_ids(bert_max_seq_len_, 0);
  std::vector<int32_t> input_masks(bert_max_seq_len_, 0);
  // Convert tokens back into ids and set the mask.
  for (int i = 0; i < input_tokens.size(); ++i) {
    tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
    input_masks[i] = 1;
  }
  //              |<--------bert_max_seq_len_--------->|
  // input_ids    [CLS] s1  s2...  sn [SEP]  0  0...  0
  // segment_ids    0   0   0...   0    0    0  0...  0
  // input_masks    1   1   1...   1    1    0  0...  0

  std::vector<Tensor> input_tensors;
  input_tensors.reserve(kNumInputTensorsForBert);
  for (int i = 0; i < kNumInputTensorsForBert; ++i) {
    input_tensors.push_back(
        {Tensor::ElementType::kInt32, Tensor::Shape({bert_max_seq_len_})});
  }
  std::memcpy(input_tensors[input_ids_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<int32_t>(),
              input_ids.data(), input_ids.size() * sizeof(int32_t));
  std::memcpy(input_tensors[segment_ids_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<int32_t>(),
              segment_ids.data(), segment_ids.size() * sizeof(int32_t));
  std::memcpy(input_tensors[input_masks_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<int32_t>(),
              input_masks.data(), input_masks.size() * sizeof(int32_t));
  return input_tensors;
}

MEDIAPIPE_REGISTER_NODE(BertPreprocessorCalculator);

}  // namespace api2
}  // namespace mediapipe
@@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message BertPreprocessorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional BertPreprocessorCalculatorOptions ext = 462509271;
  }

  // The maximum input sequence length for the calculator's BERT model.
  optional int32 bert_max_seq_len = 1;
}
@@ -0,0 +1,154 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace {

using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;

constexpr int kNumInputTensorsForBert = 3;
constexpr int kBertMaxSeqLen = 128;
constexpr absl::string_view kTestModelPath =
    "mediapipe/tasks/testdata/text/bert_text_classifier.tflite";

absl::StatusOr<std::vector<std::vector<int>>> RunBertPreprocessorCalculator(
    absl::string_view text, absl::string_view model_path) {
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
      absl::Substitute(R"(
        input_stream: "text"
        output_stream: "tensors"
        node {
          calculator: "BertPreprocessorCalculator"
          input_stream: "TEXT:text"
          input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
          output_stream: "TENSORS:tensors"
          options {
            [mediapipe.BertPreprocessorCalculatorOptions.ext] {
              bert_max_seq_len: $0
            }
          }
        }
      )",
                       kBertMaxSeqLen));
  std::vector<Packet> output_packets;
  tool::AddVectorSink("tensors", &graph_config, &output_packets);

  std::string model_buffer = tasks::core::LoadBinaryContent(model_path.data());
  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
                   ModelMetadataExtractor::CreateFromModelBuffer(
                       model_buffer.data(), model_buffer.size()));
  // Run the graph.
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      graph_config,
      {{"metadata_extractor",
        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text", MakePacket<std::string>(text).At(Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());

  if (output_packets.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "output_packets has size $0, expected 1", output_packets.size()));
  }
  const std::vector<Tensor>& tensor_vec =
      output_packets[0].Get<std::vector<Tensor>>();
  if (tensor_vec.size() != kNumInputTensorsForBert) {
    return absl::InvalidArgumentError(
        absl::Substitute("tensor_vec has size $0, expected $1",
                         tensor_vec.size(), kNumInputTensorsForBert));
  }

  std::vector<std::vector<int>> results;
  for (int i = 0; i < kNumInputTensorsForBert; i++) {
    const Tensor& tensor = tensor_vec[i];
    if (tensor.element_type() != Tensor::ElementType::kInt32) {
      return absl::InvalidArgumentError("Expected tensor element type kInt32");
    }
    auto* buffer = tensor.GetCpuReadView().buffer<int>();
    std::vector<int> buffer_view(buffer, buffer + kBertMaxSeqLen);
    results.push_back(buffer_view);
  }
  MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
  MP_RETURN_IF_ERROR(graph.WaitUntilDone());
  return results;
}

TEST(BertPreprocessorCalculatorTest, TextClassifierWithBertModel) {
  std::vector<std::vector<int>> expected_result = {
      {101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 102}};
  // segment_ids
  expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
  // input_masks
  expected_result.push_back(std::vector(expected_result[0].size(), 1));
  expected_result[2].resize(kBertMaxSeqLen);
  // padding input_ids
  expected_result[0].resize(kBertMaxSeqLen);

  MP_ASSERT_OK_AND_ASSIGN(
      std::vector<std::vector<int>> processed_tensor_values,
      RunBertPreprocessorCalculator(
          "it's a charming and often affecting journey", kTestModelPath));
  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}

TEST(BertPreprocessorCalculatorTest, LongInput) {
  std::stringstream long_input;
  long_input
      << "it's a charming and often affecting journey and this is a long";
  for (int i = 0; i < kBertMaxSeqLen; ++i) {
    long_input << " long";
  }
  long_input << " movie review";
  std::vector<std::vector<int>> expected_result = {
      {101, 2009, 1005, 1055, 1037, 11951, 1998, 2411, 12473, 4990, 1998, 2023,
       2003, 1037}};
  // "long" id
  expected_result[0].resize(kBertMaxSeqLen - 1, 2146);
  // "[SEP]" id
  expected_result[0].push_back(102);
  // segment_ids
  expected_result.push_back(std::vector(kBertMaxSeqLen, 0));
  // input_masks
  expected_result.push_back(std::vector(kBertMaxSeqLen, 1));

  MP_ASSERT_OK_AND_ASSIGN(
      std::vector<std::vector<int>> processed_tensor_values,
      RunBertPreprocessorCalculator(long_input.str(), kTestModelPath));
  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}

}  // namespace
}  // namespace mediapipe
@@ -243,8 +243,8 @@ class ImageToTensorCalculator : public Node {
    }

    ASSIGN_OR_RETURN(auto image, GetInputImage(cc));
    const Size size{image->width(), image->height()};
    RotatedRect roi = GetRoi(size.width, size.height, norm_rect);

    RotatedRect roi = GetRoi(image->width(), image->height(), norm_rect);
    ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(),
                                          options_.output_tensor_height(),
                                          options_.keep_aspect_ratio(), &roi));

@@ -253,19 +253,22 @@ class ImageToTensorCalculator : public Node {
    }
    if (kOutMatrix(cc).IsConnected()) {
      std::array<float, 16> matrix;
      GetRotatedSubRectToRectTransformMatrix(roi, size.width, size.height,
                                             /*flip_horizontaly=*/false,
                                             &matrix);
      GetRotatedSubRectToRectTransformMatrix(
          roi, image->width(), image->height(),
          /*flip_horizontaly=*/false, &matrix);
      kOutMatrix(cc).Send(std::move(matrix));
    }

    // Lazy initialization of the GPU or CPU converter.
    MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, *image.get()));

    ASSIGN_OR_RETURN(Tensor tensor,
                     (image->UsesGpu() ? gpu_converter_ : cpu_converter_)
                         ->Convert(*image, roi, {output_width_, output_height_},
                                   range_min_, range_max_));
    Tensor::ElementType output_tensor_type =
        GetOutputTensorType(image->UsesGpu());
    Tensor tensor(output_tensor_type, {1, output_height_, output_width_,
                                       GetNumOutputChannels(*image)});
    MP_RETURN_IF_ERROR((image->UsesGpu() ? gpu_converter_ : cpu_converter_)
                           ->Convert(*image, roi, range_min_, range_max_,
                                     /*tensor_buffer_offset=*/0, tensor));

    auto result = std::make_unique<std::vector<Tensor>>();
    result->push_back(std::move(tensor));

@@ -292,15 +295,31 @@ class ImageToTensorCalculator : public Node {
    }
  }

  Tensor::ElementType GetOutputTensorType() {
    if (is_float_output_) {
      return Tensor::ElementType::kFloat32;
  Tensor::ElementType GetOutputTensorType(bool uses_gpu) {
    if (!uses_gpu) {
      if (is_float_output_) {
        return Tensor::ElementType::kFloat32;
      }
      if (range_min_ < 0) {
        return Tensor::ElementType::kInt8;
      } else {
        return Tensor::ElementType::kUInt8;
      }
    }
    if (range_min_ < 0) {
      return Tensor::ElementType::kInt8;
    } else {
      return Tensor::ElementType::kUInt8;
    // Always use float32 when GPU is enabled.
    return Tensor::ElementType::kFloat32;
  }

  int GetNumOutputChannels(const Image& image) {
#if !MEDIAPIPE_DISABLE_GPU
#if MEDIAPIPE_METAL_ENABLED
    if (image.UsesGpu()) {
      return 4;
    }
#endif  // MEDIAPIPE_METAL_ENABLED
#endif  // !MEDIAPIPE_DISABLE_GPU
    // All of the processors except for Metal expect 3 channels.
    return 3;
  }

  absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(

@@ -366,7 +385,8 @@ class ImageToTensorCalculator : public Node {
#if !MEDIAPIPE_DISABLE_OPENCV
    ASSIGN_OR_RETURN(
        cpu_converter_,
        CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
        CreateOpenCvConverter(cc, GetBorderMode(),
                              GetOutputTensorType(/*uses_gpu=*/false)));
#else
    LOG(FATAL) << "Cannot create image to tensor opencv converter since "
                  "MEDIAPIPE_DISABLE_OPENCV is defined.";
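The new `GetOutputTensorType(bool uses_gpu)` above folds the old CPU-only type selection under a `!uses_gpu` branch and pins GPU output to float32. As a reference only, a sketch (not part of the diff) of the effective selection logic written out as a free function:

```
// Sketch: the type mapping implemented by GetOutputTensorType above.
Tensor::ElementType OutputTypeFor(bool uses_gpu, bool is_float_output,
                                  float range_min) {
  if (uses_gpu) return Tensor::ElementType::kFloat32;  // GPU is always float32
  if (is_float_output) return Tensor::ElementType::kFloat32;
  return range_min < 0 ? Tensor::ElementType::kInt8    // signed output range
                       : Tensor::ElementType::kUInt8;  // unsigned output range
}
```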
@@ -42,13 +42,16 @@ class ImageToTensorConverter {
  // @image contains image to extract from.
  // @roi describes region of interest within the image to extract (absolute
  // values).
  // @output_dims dimensions of output tensor.
  // @range_min/max describes the output tensor range image pixels should be
  // converted to.
  virtual absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
                                         const RotatedRect& roi,
                                         const Size& output_dims,
                                         float range_min, float range_max) = 0;
  // @tensor_buffer_offset an integer representing the offset of the tensor
  // buffer the result should be written to.
  // @output_tensor a tensor with pre-defined shape. "Convert" is responsible
  // for populating the content into the output tensor.
  virtual absl::Status Convert(const mediapipe::Image& input,
                               const RotatedRect& roi, float range_min,
                               float range_max, int tensor_buffer_offset,
                               Tensor& output_tensor) = 0;
};

}  // namespace mediapipe
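The interface hunk above changes `Convert` from returning a `Tensor` to filling a caller-allocated one. As a minimal sketch (illustration only, not part of the diff) of how a caller drives the new signature, assuming `converter`, `image`, and `roi` already exist in scope:

```
// The caller now owns allocation: create the tensor with the desired shape
// and element type, then let the converter populate it in place.
Tensor tensor(Tensor::ElementType::kFloat32,
              Tensor::Shape{1, output_height, output_width, /*channels=*/3});
MP_RETURN_IF_ERROR(converter->Convert(*image, roi, /*range_min=*/0.0f,
                                      /*range_max=*/1.0f,
                                      /*tensor_buffer_offset=*/0, tensor));
```

This is the same pattern `ImageToTensorCalculator` follows in the hunks above, and it is what makes the new `tensor_buffer_offset` parameter possible: several converters can write into disjoint slices of one shared buffer.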
@@ -264,57 +264,58 @@ class GlProcessor : public ImageToTensorConverter {
    });
  }

  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
                                 const RotatedRect& roi,
                                 const Size& output_dims, float range_min,
                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
          "Only 4-channel texture input formats are supported, passed format: ",
          static_cast<uint32_t>(input.format())));
          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

    constexpr int kNumChannels = 3;
    Tensor tensor(Tensor::ElementType::kFloat32,
                  {1, output_dims.height, output_dims.width, kNumChannels});
    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
        [this, &output_tensor, &input, &roi, &output_shape, range_min,
         range_max, tensor_buffer_offset]() -> absl::Status {
          const int input_num_channels = input.channels();
          auto source_texture = gl_helper_.CreateSourceTexture(input);
          tflite::gpu::gl::GlTexture input_texture(
              GL_TEXTURE_2D, source_texture.name(),
              input_num_channels == 4 ? GL_RGB : GL_RGBA,
              source_texture.width() * source_texture.height() *
                  input_num_channels * sizeof(uint8_t),
              /*layer=*/0,
              /*owned=*/false);

    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi,
                                                  &output_dims, range_min,
                                                  range_max]() -> absl::Status {
      constexpr int kRgbaNumChannels = 4;
      auto source_texture = gl_helper_.CreateSourceTexture(input);
      tflite::gpu::gl::GlTexture input_texture(
          GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
          source_texture.width() * source_texture.height() * kRgbaNumChannels *
              sizeof(uint8_t),
          /*layer=*/0,
          /*owned=*/false);
      constexpr float kInputImageRangeMin = 0.0f;
      constexpr float kInputImageRangeMax = 1.0f;
      ASSIGN_OR_RETURN(auto transform,
                       GetValueRangeTransformation(kInputImageRangeMin,
                                                   kInputImageRangeMax,
                                                   range_min, range_max));

          constexpr float kInputImageRangeMin = 0.0f;
          constexpr float kInputImageRangeMax = 1.0f;
          ASSIGN_OR_RETURN(
              auto transform,
              GetValueRangeTransformation(kInputImageRangeMin,
                                          kInputImageRangeMax,
                                          range_min, range_max));
          const int output_size = output_tensor.bytes() / output_shape.dims[0];
          auto buffer_view = output_tensor.GetOpenGlBufferWriteView();
          tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
                                           buffer_view.name(), output_size,
                                           /*offset=*/tensor_buffer_offset,
                                           /*has_ownership=*/false);
          MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
              input_texture,
              tflite::gpu::HW(source_texture.height(), source_texture.width()),
              roi,
              /*flip_horizontaly=*/false, transform.scale, transform.offset,
              tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
              command_queue_.get(), &output));

      auto buffer_view = tensor.GetOpenGlBufferWriteView();
      tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER,
                                       buffer_view.name(), tensor.bytes(),
                                       /*offset=*/0,
                                       /*has_ownership=*/false);
      MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer(
          input_texture,
          tflite::gpu::HW(source_texture.height(), source_texture.width()), roi,
          /*flip_horizontaly=*/false, transform.scale, transform.offset,
          tflite::gpu::HW(output_dims.height, output_dims.width),
          command_queue_.get(), &output));
      return absl::OkStatus();
    }));

          return absl::OkStatus();
        }));

    return tensor;
    return absl::OkStatus();
  }

  ~GlProcessor() override {

@@ -326,6 +327,17 @@ class GlProcessor : public ImageToTensorConverter {
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 3)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  std::unique_ptr<tflite::gpu::gl::CommandQueue> command_queue_;
  std::unique_ptr<SubRectExtractorGl> extractor_;
  mediapipe::GlCalculatorHelper gl_helper_;
@@ -168,26 +168,26 @@ class GlProcessor : public ImageToTensorConverter {
    });
  }

  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
                                 const RotatedRect& roi,
                                 const Size& output_dims, float range_min,
                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
          "Only 4-channel texture input formats are supported, passed format: ",
          static_cast<uint32_t>(input.format())));
          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    // TODO: support tensor_buffer_offset > 0 scenario.
    RET_CHECK_EQ(tensor_buffer_offset, 0)
        << "The non-zero tensor_buffer_offset input is not supported yet.";
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

    constexpr int kNumChannels = 3;
    Tensor tensor(
        Tensor::ElementType::kFloat32,
        Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels});

    MP_RETURN_IF_ERROR(
        gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims,
                                   range_min, range_max]() -> absl::Status {
    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
        [this, &output_tensor, &input, &roi, &output_shape, range_min,
         range_max]() -> absl::Status {
          auto input_texture = gl_helper_.CreateSourceTexture(input);

          constexpr float kInputImageRangeMin = 0.0f;

@@ -196,27 +196,29 @@ class GlProcessor : public ImageToTensorConverter {
                           GetValueRangeTransformation(kInputImageRangeMin,
                                                       kInputImageRangeMax,
                                                       range_min, range_max));
          auto tensor_view = tensor.GetOpenGlTexture2dWriteView();
          auto tensor_view = output_tensor.GetOpenGlTexture2dWriteView();
          MP_RETURN_IF_ERROR(ExtractSubRect(input_texture, roi,
                                            /*flip_horizontaly=*/false,
                                            transform.scale, transform.offset,
                                            output_dims, &tensor_view));
                                            output_shape, &tensor_view));
          return absl::OkStatus();
        }));

    return tensor;
    return absl::OkStatus();
  }

  absl::Status ExtractSubRect(const mediapipe::GlTexture& texture,
                              const RotatedRect& sub_rect,
                              bool flip_horizontaly, float alpha, float beta,
                              const Size& output_dims,
                              const Tensor::Shape& output_shape,
                              Tensor::OpenGlTexture2dView* output) {
    const int output_height = output_shape.dims[1];
    const int output_width = output_shape.dims[2];
    std::array<float, 16> transform_mat;

    glDisable(GL_DEPTH_TEST);
    glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_);
    glViewport(0, 0, output_dims.width, output_dims.height);
    glViewport(0, 0, output_width, output_height);

    glActiveTexture(GL_TEXTURE0);
    glBindTexture(GL_TEXTURE_2D, output->name());

@@ -316,6 +318,17 @@ class GlProcessor : public ImageToTensorConverter {
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 3)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  mediapipe::GlCalculatorHelper gl_helper_;
  bool use_custom_zero_border_ = false;
  BorderMode border_mode_ = BorderMode::kReplicate;
@@ -262,7 +262,6 @@ class SubRectExtractorMetal {
    RET_CHECK(pipeline_state != nil);

    std::string output_type_def;
    MTLPixelFormat pixel_format;
    switch (output_format) {
      case OutputFormat::kF16C4:
        output_type_def = R"(

@@ -348,10 +347,10 @@ class MetalProcessor : public ImageToTensorConverter {
    return absl::OkStatus();
  }

  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
                                 const RotatedRect& roi,
                                 const Size& output_dims, float range_min,
                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {

@@ -359,16 +358,15 @@ class MetalProcessor : public ImageToTensorConverter {
          "Only 4-channel texture input formats are supported, passed format: ",
          static_cast<uint32_t>(input.format())));
    }
    RET_CHECK_EQ(tensor_buffer_offset, 0)
        << "The non-zero tensor_buffer_offset input is not supported yet.";
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

    @autoreleasepool {
      id<MTLTexture> texture =
          [metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()];

      constexpr int kNumChannels = 4;
      Tensor tensor(Tensor::ElementType::kFloat32,
                    Tensor::Shape{1, output_dims.height, output_dims.width,
                                  kNumChannels});

      constexpr float kInputImageRangeMin = 0.0f;
      constexpr float kInputImageRangeMax = 1.0f;
      ASSIGN_OR_RETURN(

@@ -377,18 +375,30 @@ class MetalProcessor : public ImageToTensorConverter {
          range_min, range_max));

      id<MTLCommandBuffer> command_buffer = [metal_helper_ commandBuffer];
      const auto& buffer_view = tensor.GetMtlBufferWriteView(command_buffer);
      const auto& buffer_view =
          output_tensor.GetMtlBufferWriteView(command_buffer);
      MP_RETURN_IF_ERROR(extractor_->Execute(
          texture, roi,
          /*flip_horizontaly=*/false, transform.scale, transform.offset,
          tflite::gpu::HW(output_dims.height, output_dims.width),
          tflite::gpu::HW(output_shape.dims[1], output_shape.dims[2]),
          command_buffer, buffer_view.buffer()));
      [command_buffer commit];
      return tensor;
      return absl::OkStatus();
    }
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 4)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  MPPMetalHelper* metal_helper_ = nil;
  std::unique_ptr<SubRectExtractorMetal> extractor_;
};
@@ -60,34 +60,39 @@ class OpenCvProcessor : public ImageToTensorConverter {
    }
  }

  absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
                                 const RotatedRect& roi,
                                 const Size& output_dims, float range_min,
                                 float range_max) override {
  absl::Status Convert(const mediapipe::Image& input, const RotatedRect& roi,
                       float range_min, float range_max,
                       int tensor_buffer_offset,
                       Tensor& output_tensor) override {
    if (input.image_format() != mediapipe::ImageFormat::SRGB &&
        input.image_format() != mediapipe::ImageFormat::SRGBA) {
      return InvalidArgumentError(
          absl::StrCat("Only RGBA/RGB formats are supported, passed format: ",
                       static_cast<uint32_t>(input.image_format())));
    }
    auto src = mediapipe::formats::MatView(&input);
    // TODO: Remove the check once tensor_buffer_offset > 0 is
    // supported.
    RET_CHECK_EQ(tensor_buffer_offset, 0)
        << "The non-zero tensor_buffer_offset input is not supported yet.";
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));

    constexpr int kNumChannels = 3;
    Tensor tensor(tensor_type_, Tensor::Shape{1, output_dims.height,
                                              output_dims.width, kNumChannels});
    auto buffer_view = tensor.GetCpuWriteView();
    const int output_height = output_shape.dims[1];
    const int output_width = output_shape.dims[2];
    const int output_channels = output_shape.dims[3];
    auto buffer_view = output_tensor.GetCpuWriteView();
    cv::Mat dst;
    switch (tensor_type_) {
      case Tensor::ElementType::kInt8:
        dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
        dst = cv::Mat(output_height, output_width, mat_type_,
                      buffer_view.buffer<int8>());
        break;
      case Tensor::ElementType::kFloat32:
        dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
        dst = cv::Mat(output_height, output_width, mat_type_,
                      buffer_view.buffer<float>());
        break;
      case Tensor::ElementType::kUInt8:
        dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
        dst = cv::Mat(output_height, output_width, mat_type_,
                      buffer_view.buffer<uint8>());
        break;
      default:

@@ -101,8 +106,8 @@ class OpenCvProcessor : public ImageToTensorConverter {
    cv::Mat src_points;
    cv::boxPoints(rotated_rect, src_points);

    const float dst_width = output_dims.width;
    const float dst_height = output_dims.height;
    const float dst_width = output_width;
    const float dst_height = output_height;
    /* clang-format off */
    float dst_corners[8] = {0.0f, dst_height,
                            0.0f, 0.0f,

@@ -110,6 +115,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
                            dst_width, dst_height};
    /* clang-format on */

    auto src = mediapipe::formats::MatView(&input);
    cv::Mat dst_points = cv::Mat(4, 2, CV_32F, dst_corners);
    cv::Mat projection_matrix =
        cv::getPerspectiveTransform(src_points, dst_points);

@@ -119,7 +125,7 @@ class OpenCvProcessor : public ImageToTensorConverter {
                   /*flags=*/cv::INTER_LINEAR,
                   /*borderMode=*/border_mode_);

    if (transformed.channels() > kNumChannels) {
    if (transformed.channels() > output_channels) {
      cv::Mat proper_channels_mat;
      cv::cvtColor(transformed, proper_channels_mat, cv::COLOR_RGBA2RGB);
      transformed = proper_channels_mat;

@@ -132,10 +138,21 @@ class OpenCvProcessor : public ImageToTensorConverter {
        GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax,
                                    range_min, range_max));
    transformed.convertTo(dst, mat_type_, transform.scale, transform.offset);
    return tensor;
    return absl::OkStatus();
  }

 private:
  absl::Status ValidateTensorShape(const Tensor::Shape& output_shape) {
    RET_CHECK_EQ(output_shape.dims.size(), 4)
        << "Wrong output dims size: " << output_shape.dims.size();
    RET_CHECK_EQ(output_shape.dims[0], 1)
        << "Handling batch dimension not equal to 1 is not implemented in this "
           "converter.";
    RET_CHECK_EQ(output_shape.dims[3], 3)
        << "Wrong output channel: " << output_shape.dims[3];
    return absl::OkStatus();
  }

  enum cv::BorderTypes border_mode_;
  Tensor::ElementType tensor_type_;
  int mat_type_;
@@ -26,6 +26,8 @@
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"

#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe

namespace mediapipe {
namespace api2 {

@@ -191,7 +193,7 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors,
    std::vector<Tensor>& output_tensors) {
  return gpu_helper_.RunInGlContext(
      [this, &input_tensors, &output_tensors]() -> absl::Status {
      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        // Explicitly copy input.
        for (int i = 0; i < input_tensors.size(); ++i) {
          glBindBuffer(GL_COPY_READ_BUFFER,

@@ -203,7 +205,10 @@ absl::Status InferenceCalculatorGlImpl::GpuInferenceRunner::Process(
        }

        // Run inference.
        RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
        {
          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
          RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
        }

        output_tensors.reserve(output_size_);
        for (int i = 0; i < output_size_; ++i) {
@@ -32,6 +32,8 @@
#include "mediapipe/util/android/file/base/helpers.h"
#endif  // MEDIAPIPE_ANDROID

#define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe

namespace mediapipe {
namespace api2 {

@@ -83,7 +85,7 @@ class InferenceCalculatorGlAdvancedImpl
      const mediapipe::InferenceCalculatorOptions::Delegate& delegate);

  absl::StatusOr<std::vector<Tensor>> Process(
      const std::vector<Tensor>& input_tensors);
      CalculatorContext* cc, const std::vector<Tensor>& input_tensors);

  absl::Status Close();

@@ -121,11 +123,11 @@ absl::Status InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Init(

absl::StatusOr<std::vector<Tensor>>
InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
    const std::vector<Tensor>& input_tensors) {
    CalculatorContext* cc, const std::vector<Tensor>& input_tensors) {
  std::vector<Tensor> output_tensors;

  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &input_tensors, &output_tensors]() -> absl::Status {
      [this, cc, &input_tensors, &output_tensors]() -> absl::Status {
        for (int i = 0; i < input_tensors.size(); ++i) {
          MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
              input_tensors[i].GetOpenGlBufferReadView().name(), i));

@@ -138,7 +140,10 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
              output_tensors.back().GetOpenGlBufferWriteView().name(), i));
        }
        // Run inference.
        return tflite_gpu_runner_->Invoke();
        {
          MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
          return tflite_gpu_runner_->Invoke();
        }
      }));

  return output_tensors;

@@ -354,7 +359,7 @@ absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
  auto output_tensors = absl::make_unique<std::vector<Tensor>>();

  ASSIGN_OR_RETURN(*output_tensors,
                   gpu_inference_runner_->Process(input_tensors));
                   gpu_inference_runner_->Process(cc, input_tensors));

  kOutTensors(cc).Send(std::move(output_tensors));
  return absl::OkStatus();
@@ -224,9 +224,6 @@ absl::Status InferenceCalculatorMetalImpl::InitInterpreter(

void InferenceCalculatorMetalImpl::AddDelegate(
    CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
  const auto& calculator_opts =
      cc->Options<mediapipe::InferenceCalculatorOptions>();

  // Configure and create the delegate.
  TFLGpuDelegateOptions options;
  // `enable_quantization` enables the run of sparse models i.e. the models with
@@ -21,8 +21,10 @@
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/string_util.h"

namespace mediapipe {

@@ -39,6 +41,19 @@ void CopyTensorBufferToInterpreter(const Tensor& input_tensor,
  std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}

template <>
void CopyTensorBufferToInterpreter<char>(const Tensor& input_tensor,
                                         tflite::Interpreter* interpreter,
                                         int input_tensor_index) {
  const char* input_tensor_buffer =
      input_tensor.GetCpuReadView().buffer<char>();
  tflite::DynamicBuffer dynamic_buffer;
  dynamic_buffer.AddString(input_tensor_buffer,
                           input_tensor.shape().num_elements());
  dynamic_buffer.WriteToTensorAsVector(
      interpreter->tensor(interpreter->inputs()[input_tensor_index]));
}

template <typename T>
void CopyTensorBufferFromInterpreter(tflite::Interpreter* interpreter,
                                     int output_tensor_index,

@@ -87,13 +102,13 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
        break;
      }
      case TfLiteType::kTfLiteUInt8: {
        CopyTensorBufferToInterpreter<uint8>(input_tensors[i],
                                             interpreter_.get(), i);
        CopyTensorBufferToInterpreter<uint8_t>(input_tensors[i],
                                               interpreter_.get(), i);
        break;
      }
      case TfLiteType::kTfLiteInt8: {
        CopyTensorBufferToInterpreter<int8>(input_tensors[i],
                                            interpreter_.get(), i);
        CopyTensorBufferToInterpreter<int8_t>(input_tensors[i],
                                              interpreter_.get(), i);
        break;
      }
      case TfLiteType::kTfLiteInt32: {

@@ -101,6 +116,14 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
                                              interpreter_.get(), i);
        break;
      }
      case TfLiteType::kTfLiteString: {
        CopyTensorBufferToInterpreter<char>(input_tensors[i],
                                            interpreter_.get(), i);
        break;
      }
      case TfLiteType::kTfLiteBool:
        // No current use-case for copying MediaPipe Tensors with bool type to
        // TfLiteTensors.
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported input tensor type:", input_tensor_type));

@@ -146,6 +169,15 @@ absl::StatusOr<std::vector<Tensor>> InferenceInterpreterDelegateRunner::Run(
        CopyTensorBufferFromInterpreter<int32_t>(interpreter_.get(), i,
                                                 &output_tensors.back());
        break;
      case TfLiteType::kTfLiteBool:
        output_tensors.emplace_back(Tensor::ElementType::kBool, shape,
                                    Tensor::QuantizationParameters{1.0f, 0});
        CopyTensorBufferFromInterpreter<bool>(interpreter_.get(), i,
                                              &output_tensors.back());
        break;
      case TfLiteType::kTfLiteString:
        // No current use-case for copying TfLiteTensors with string type to
        // MediaPipe Tensors.
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported output tensor type:",
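For context, the `kTfLiteString` path above packs the text through `tflite::DynamicBuffer`, which stores strings in a TFLite tensor with an offset table rather than a flat byte buffer. A minimal sketch (illustration only, not part of the diff) of how such a tensor would be read back with the TFLite string helpers, assuming `interpreter` is a `tflite::Interpreter` whose first input is a populated string tensor:

```
#include "tensorflow/lite/string_util.h"

// Sketch: read the first string out of a kTfLiteString tensor.
const TfLiteTensor* t = interpreter->tensor(interpreter->inputs()[0]);
const int count = tflite::GetStringCount(t);     // number of packed strings
if (count > 0) {
  tflite::StringRef ref = tflite::GetString(t, /*string_index=*/0);
  std::string text(ref.str, ref.len);            // copy out the bytes
}
```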
mediapipe/calculators/tensor/regex_preprocessor_calculator.cc (new file, 174 lines)

@@ -0,0 +1,174 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/calculators/tensor/regex_preprocessor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/cc/text/tokenizers/regex_tokenizer.h"
#include "mediapipe/tasks/cc/text/tokenizers/tokenizer_utils.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::metadata::ModelMetadataExtractor;

// Preprocesses input text into one int32 input tensor for a text model using
// a RegexTokenizer.
//
// Inputs:
//   TEXT - std::string
//     The input text.
// Side Inputs:
//   METADATA_EXTRACTOR - ModelMetadataExtractor
//     The metadata extractor for the text model. Used to extract the metadata
//     to construct the RegexTokenizer.
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing a single Tensor which is the text model's input
//     tensor. Depending on the tokenizer metadata, the tensor may start with
//     the id of the tokenizer's <START> token. The following tensor values
//     will be the ids of the tokens of the input text. Any out-of-vocab tokens
//     will have the id of the <UNKNOWN> token. The tensor will be padded with
//     the <PAD> token id to have size equal to the max sequence length for the
//     text model.
//
// Example:
// node {
//   calculator: "RegexPreprocessorCalculator"
//   input_stream: "TEXT:text"
//   input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
//   output_stream: "TENSORS:tensors"
//   options {
//     [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
//       max_seq_len: 256
//     }
//   }
// }
class RegexPreprocessorCalculator : public Node {
 public:
  static constexpr Input<std::string> kTextIn{"TEXT"};
  static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
      "METADATA_EXTRACTOR"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};

  MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);

  static absl::Status UpdateContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  std::unique_ptr<tasks::text::tokenizers::RegexTokenizer> tokenizer_;
  // The max sequence length accepted by the text model.
  int max_seq_len_ = 0;
};

absl::Status RegexPreprocessorCalculator::UpdateContract(
    CalculatorContract* cc) {
  const auto& options =
      cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
  RET_CHECK(options.has_max_seq_len()) << "max_seq_len is required";
  RET_CHECK_GT(options.max_seq_len(), 0) << "max_seq_len must be positive";
  return absl::OkStatus();
}

absl::Status RegexPreprocessorCalculator::Open(CalculatorContext* cc) {
  const ModelMetadataExtractor* metadata_extractor =
      &kMetadataExtractorSideIn(cc).Get();
  const tflite::TensorMetadata* tensor_metadata =
      metadata_extractor->GetInputTensorMetadata(0);
  if (tensor_metadata == nullptr) {
    return absl::InvalidArgumentError("No tensor metadata found");
  }

  ASSIGN_OR_RETURN(
      const auto* tokenizer_metadata,
      metadata_extractor->FindFirstProcessUnit(
          *tensor_metadata, tflite::ProcessUnitOptions_RegexTokenizerOptions));
  if (tokenizer_metadata == nullptr) {
    return absl::InvalidArgumentError("No tokenizer metadata found");
  }
  const tflite::RegexTokenizerOptions* regex_tokenizer_options =
      tokenizer_metadata->options_as<tflite::RegexTokenizerOptions>();
  ASSIGN_OR_RETURN(tokenizer_,
                   tasks::text::tokenizers::CreateRegexTokenizerFromOptions(
                       regex_tokenizer_options, metadata_extractor));

  const auto& options =
      cc->Options<mediapipe::RegexPreprocessorCalculatorOptions>();
  max_seq_len_ = options.max_seq_len();
  return absl::OkStatus();
}

absl::Status RegexPreprocessorCalculator::Process(CalculatorContext* cc) {
  tasks::text::tokenizers::TokenizerResult tokenizer_result =
      tokenizer_->Tokenize(kTextIn(cc).Get());

  int unknown_token_id = 0;
  tokenizer_->GetUnknownToken(&unknown_token_id);
  int pad_token_id = 0;
  tokenizer_->GetPadToken(&pad_token_id);

  std::vector<int> input_tokens(max_seq_len_, pad_token_id);
  int start_token_id = 0;
  int input_token_index = 0;
  if (tokenizer_->GetStartToken(&start_token_id)) {
    input_tokens[0] = start_token_id;
    input_token_index = 1;
  }

  for (int i = 0; (i < tokenizer_result.subwords.size()) &&
                  (input_token_index < max_seq_len_);
       ++i, ++input_token_index) {
    const std::string& token = tokenizer_result.subwords[i];
    int token_id = 0;
    if (tokenizer_->LookupId(token, &token_id)) {
      input_tokens[input_token_index] = token_id;
    } else {
      input_tokens[input_token_index] = unknown_token_id;
    }
  }

  //              |<-------sentence_length-------->|
  // input_tensor <START>, t1, t2... <PAD>, <PAD>...
  // <START> is optional, t1, t2... will be replaced by <UNKNOWN> if it's
  // not found in the tokenizer vocab.
  std::vector<Tensor> result;
  result.push_back(
      {Tensor::ElementType::kInt32, Tensor::Shape({max_seq_len_})});
  std::memcpy(result[0].GetCpuWriteView().buffer<int32_t>(),
              input_tokens.data(), input_tokens.size() * sizeof(int32_t));
  kTensorsOut(cc).Send(std::move(result));
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(RegexPreprocessorCalculator);

}  // namespace api2
}  // namespace mediapipe
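To make the token layout described in the calculator's `Process` comment concrete, here is a standalone sketch (hypothetical toy vocabulary and token ids, for illustration only) of the same padding logic:

```
// Sketch: ids for "the best movie ever" with max_seq_len = 8, assuming
// <PAD>=0, <START>=1, <UNKNOWN>=2 and a toy vocabulary.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  const std::unordered_map<std::string, int> vocab = {
      {"the", 4}, {"best", 118}, {"movie", 153}};
  const int kPad = 0, kStart = 1, kUnknown = 2, kMaxSeqLen = 8;
  const std::vector<std::string> tokens = {"the", "best", "movie", "ever"};

  std::vector<int> input_tokens(kMaxSeqLen, kPad);  // pre-fill with <PAD>
  int index = 0;
  input_tokens[index++] = kStart;                   // optional <START> token
  for (size_t i = 0; i < tokens.size() && index < kMaxSeqLen; ++i, ++index) {
    auto it = vocab.find(tokens[i]);
    input_tokens[index] = (it != vocab.end()) ? it->second : kUnknown;
  }
  for (int id : input_tokens) std::cout << id << " ";
  // Prints: 1 4 118 153 2 0 0 0  ("ever" is out-of-vocab -> <UNKNOWN>)
}
```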
@@ -0,0 +1,29 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message RegexPreprocessorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional RegexPreprocessorCalculatorOptions ext = 463716697;
  }

  // The maximum input sequence length for the calculator's text model.
  optional int32 max_seq_len = 1;
}
@@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace {

using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;

constexpr int kMaxSeqLen = 256;
constexpr char kTestModelPath[] =
    "mediapipe/tasks/testdata/text/"
    "test_model_text_classifier_with_regex_tokenizer.tflite";

absl::StatusOr<std::vector<int>> RunRegexPreprocessorCalculator(
    absl::string_view text) {
  auto graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
          R"pb(
            input_stream: "text"
            output_stream: "tensors"
            node {
              calculator: "RegexPreprocessorCalculator"
              input_stream: "TEXT:text"
              input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
              output_stream: "TENSORS:tensors"
              options {
                [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
                  max_seq_len: $0
                }
              }
            }
          )pb",
          kMaxSeqLen));
  std::vector<Packet> output_packets;
  tool::AddVectorSink("tensors", &graph_config, &output_packets);

  std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath);
  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
                   ModelMetadataExtractor::CreateFromModelBuffer(
                       model_buffer.data(), model_buffer.size()));
  // Run the graph.
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      graph_config,
      {{"metadata_extractor",
        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text", MakePacket<std::string>(text).At(Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());

  if (output_packets.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "output_packets has size $0, expected 1", output_packets.size()));
  }
  const std::vector<Tensor>& tensor_vec =
      output_packets[0].Get<std::vector<Tensor>>();
  if (tensor_vec.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "tensor_vec has size $0, expected $1", tensor_vec.size(), 1));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) {
    return absl::InvalidArgumentError("Expected tensor element type kInt32");
  }
  auto* buffer = tensor_vec[0].GetCpuReadView().buffer<int>();
  std::vector<int> result(buffer, buffer + kMaxSeqLen);
  MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
  MP_RETURN_IF_ERROR(graph.WaitUntilDone());
  return result;
}

TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) {
  MP_ASSERT_OK_AND_ASSIGN(
      std::vector<int> processed_tensor_values,
      RunRegexPreprocessorCalculator("This is the best movie I’ve seen in "
                                     "recent years. Strongly recommend it!"));
  static const int expected_result[kMaxSeqLen] = {
      1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12};
  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}

TEST(RegexPreprocessorCalculatorTest, LongInput) {
  std::stringstream long_input;
  long_input << "This is the best";
  for (int i = 0; i < kMaxSeqLen; ++i) {
    long_input << " best";
  }
  long_input << "movie I’ve seen in recent years. Strongly recommend it!";
  MP_ASSERT_OK_AND_ASSIGN(std::vector<int> processed_tensor_values,
                          RunRegexPreprocessorCalculator(long_input.str()));
  std::vector<int> expected_result = {1, 2, 9, 4, 118};
  // "best" id
  expected_result.resize(kMaxSeqLen, 118);
  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}

}  // namespace
}  // namespace mediapipe
@@ -296,7 +296,6 @@ absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) {
  output_tensors->emplace_back(Tensor::ElementType::kFloat32,
                               Tensor::Shape{1, height, width, channels});
#if MEDIAPIPE_METAL_ENABLED
  id<MTLDevice> device = gpu_helper_.mtlDevice;
  id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
  command_buffer.label = @"TensorConverterCalculatorConvert";
  id<MTLComputeCommandEncoder> compute_encoder =
@@ -87,6 +87,9 @@ absl::Status TensorsDequantizationCalculator::Process(CalculatorContext* cc) {
    case Tensor::ElementType::kInt8:
      Dequantize<int8>(input_tensor, &output_tensors->back());
      break;
    case Tensor::ElementType::kBool:
      Dequantize<bool>(input_tensor, &output_tensors->back());
      break;
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "Unsupported input tensor type: ", input_tensor.element_type()));

@@ -124,5 +124,15 @@ TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithInt8Tensors) {
  ValidateResult(GetOutput(), {-1.007874, 0, 1});
}

TEST_F(TensorsDequantizationCalculatorTest, SucceedsWithBoolTensors) {
  std::vector<bool> tensor = {true, false, true};
  PushTensor(Tensor::ElementType::kBool, tensor,
             Tensor::QuantizationParameters{1.0f, 0});

  MP_ASSERT_OK(runner_.Run());

  ValidateResult(GetOutput(), {1, 0, 1});
}

}  // namespace
}  // namespace mediapipe
@@ -163,6 +163,7 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
}

absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
  const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
  const auto& input_tensors = *kInTensors(cc);
  RET_CHECK_EQ(input_tensors.size(), 1);
  RET_CHECK(input_tensors[0].element_type() == Tensor::ElementType::kFloat32);

@@ -181,6 +182,12 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
  auto raw_scores = view.buffer<float>();

  auto classification_list = absl::make_unique<ClassificationList>();
  if (options.has_tensor_index()) {
    classification_list->set_tensor_index(options.tensor_index());
  }
  if (options.has_tensor_name()) {
    classification_list->set_tensor_name(options.tensor_name());
  }
  if (is_binary_classification_) {
    Classification* class_first = classification_list->add_classification();
    Classification* class_second = classification_list->add_classification();

@@ -72,4 +72,9 @@ message TensorsToClassificationCalculatorOptions {
  // that are not in the `allow_classes` field will be completely ignored.
  // `ignore_classes` and `allow_classes` are mutually exclusive.
  repeated int32 allow_classes = 8 [packed = true];

  // The optional index of the tensor these classifications originate from.
  optional int32 tensor_index = 10;
  // The optional name of the tensor these classifications originate from.
  optional string tensor_name = 11;
}
@@ -240,6 +240,36 @@ TEST_F(TensorsToClassificationCalculatorTest,
  }
}

TEST_F(TensorsToClassificationCalculatorTest,
       CorrectOutputWithTensorNameAndIndex) {
  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
    calculator: "TensorsToClassificationCalculator"
    input_stream: "TENSORS:tensors"
    output_stream: "CLASSIFICATIONS:classifications"
    options {
      [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
        tensor_index: 1
        tensor_name: "foo"
      }
    }
  )pb"));

  BuildGraph(&runner, {0, 0.5, 1});
  MP_ASSERT_OK(runner.Run());

  const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;

  EXPECT_EQ(1, output_packets_.size());

  const auto& classification_list =
      output_packets_[0].Get<ClassificationList>();
  EXPECT_EQ(3, classification_list.classification_size());

  // Verify that the tensor_index and tensor_name fields are correctly set.
  EXPECT_EQ(classification_list.tensor_index(), 1);
  EXPECT_EQ(classification_list.tensor_name(), "foo");
}

TEST_F(TensorsToClassificationCalculatorTest,
       ClassNameAllowlistWithLabelItems) {
  mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
@@ -532,7 +532,6 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
                                       detection_classes.data(),
                                       output_detections));
#elif MEDIAPIPE_METAL_ENABLED
  id<MTLDevice> device = gpu_helper_.mtlDevice;
  if (!anchors_init_) {
    if (input_tensors.size() == kNumInputTensorsWithAnchors) {
      RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);

@@ -67,9 +67,7 @@ absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
        "tensor_vec has size $0, expected 1", tensor_vec.size()));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
    return absl::InvalidArgumentError(absl::Substitute(
        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
        Tensor::ElementType::kChar));
    return absl::InvalidArgumentError("Expected tensor element type kChar");
  }
  const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
  return std::string(buffer, text.length());
@@ -0,0 +1,167 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <array>
#include <cstring>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace api2 {

using ::mediapipe::tasks::core::FindTensorIndexByMetadataName;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;

constexpr absl::string_view kQueryTextMetadataName = "inp_text";
constexpr absl::string_view kResponseContextMetadataName = "res_context";
constexpr absl::string_view kResponseTextMetadataName = "res_text";

constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;

// Preprocesses input text into three kTfLiteString input tensors for a
// Universal Sentence Encoder (USE) model.
//
// The associated USE model is expected to contain input tensors with metadata
// names:
//
//   Tensor           | Metadata Name
//   ---------------- | ------------------
//   Query text       | "inp_text"
//   Response context | "res_context"
//   Response text    | "res_text"
//
// This calculator will return an error if the model does not have three input
// tensors or if the tensors do not have metadata names corresponding to the
// above names in some order. Additional details regarding these input
// tensors are given in the Calculator "Outputs" section below.
//
// Inputs:
//   TEXT - std::string
//     The text to be embedded.
// Side Inputs:
//   METADATA_EXTRACTOR - ModelMetadataExtractor
//     The metadata extractor for the USE model. Used to determine the order of
//     the three input Tensors for the USE model.
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing the three input Tensors for the USE model. The tensors
//     fit a question-answering setting and store a query text, a response
//     context, and a response text. This calculator will just be preprocessing
//     a single input text that will be stored in the response text tensor. The
//     query text and response context tensors will store empty strings.
//
// Example:
// node {
//   calculator: "UniversalSentenceEncoderPreprocessorCalculator"
//   input_stream: "TEXT:text"
//   input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
//   output_stream: "TENSORS:tensors"
// }
class UniversalSentenceEncoderPreprocessorCalculator : public Node {
 public:
  static constexpr Input<std::string> kTextIn{"TEXT"};
  static constexpr SideInput<ModelMetadataExtractor> kMetadataExtractorSideIn{
      "METADATA_EXTRACTOR"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};

  MEDIAPIPE_NODE_CONTRACT(kTextIn, kMetadataExtractorSideIn, kTensorsOut);

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  // Indices of the three input tensors for the USE model. They should form the
  // set {0, 1, 2}.
  int query_text_tensor_index_ = 0;
  int response_context_tensor_index_ = 1;
  int response_text_tensor_index_ = 2;

  // Tensor shapes for the model's input tensors.
  // The query text and response context tensors will only hold the empty
  // string, so their tensors will have shape [0], but the Universal Sentence
  // Encoder model's input signature requires them to be present. The response
  // text tensor will store the embedding text and have shape
  // [embedding_text_len].
  std::array<int, kNumInputTensorsForUniversalSentenceEncoder> tensor_shapes_;
};

absl::Status UniversalSentenceEncoderPreprocessorCalculator::Open(
    CalculatorContext* cc) {
  const ModelMetadataExtractor* metadata_extractor =
      &kMetadataExtractorSideIn(cc).Get();
  auto* input_tensors_metadata = metadata_extractor->GetInputTensorMetadata();
  query_text_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kQueryTextMetadataName);
  response_context_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kResponseContextMetadataName);
  response_text_tensor_index_ = FindTensorIndexByMetadataName(
      input_tensors_metadata, kResponseTextMetadataName);

  absl::flat_hash_set<int> tensor_indices = absl::flat_hash_set<int>(
      {query_text_tensor_index_, response_context_tensor_index_,
       response_text_tensor_index_});
  if (tensor_indices != absl::flat_hash_set<int>({0, 1, 2})) {
    return absl::InvalidArgumentError(absl::Substitute(
        "Input tensor indices form the set {$0, $1, $2} rather than {0, 1, 2}",
        query_text_tensor_index_, response_context_tensor_index_,
        response_text_tensor_index_));
  }
  return absl::OkStatus();
}

absl::Status UniversalSentenceEncoderPreprocessorCalculator::Process(
    CalculatorContext* cc) {
  absl::string_view text = kTextIn(cc).Get();
  const int text_len = static_cast<int>(text.length());
  tensor_shapes_[response_text_tensor_index_] = text_len;

  std::vector<Tensor> input_tensors;
  input_tensors.reserve(kNumInputTensorsForUniversalSentenceEncoder);
  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
    input_tensors.push_back(
        {Tensor::ElementType::kChar, Tensor::Shape({tensor_shapes_[i]})});
  }

  std::memcpy(
      input_tensors[query_text_tensor_index_].GetCpuWriteView().buffer<char>(),
      "", 0);
  std::memcpy(input_tensors[response_context_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<char>(),
              "", 0);
  std::memcpy(input_tensors[response_text_tensor_index_]
                  .GetCpuWriteView()
                  .buffer<char>(),
              text.data(), text_len * sizeof(char));
  kTensorsOut(cc).Send(std::move(input_tensors));
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(UniversalSentenceEncoderPreprocessorCalculator);

}  // namespace api2
}  // namespace mediapipe
@@ -0,0 +1,109 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/options_map.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace {

using ::mediapipe::IsOkAndHolds;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;

constexpr int kNumInputTensorsForUniversalSentenceEncoder = 3;

constexpr absl::string_view kTestModelPath =
    "mediapipe/tasks/testdata/text/"
    "universal_sentence_encoder_qa_with_metadata.tflite";

absl::StatusOr<std::vector<std::string>>
RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
  auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    input_stream: "text"
    output_stream: "tensors"
    node {
      calculator: "UniversalSentenceEncoderPreprocessorCalculator"
      input_stream: "TEXT:text"
      input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
      output_stream: "TENSORS:tensors"
    }
  )pb");
  std::vector<Packet> output_packets;
  tool::AddVectorSink("tensors", &graph_config, &output_packets);

  std::string model_buffer =
      tasks::core::LoadBinaryContent(kTestModelPath.data());
  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
                   ModelMetadataExtractor::CreateFromModelBuffer(
                       model_buffer.data(), model_buffer.size()));
  // Run the graph.
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      graph_config,
      {{"metadata_extractor",
        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text", MakePacket<std::string>(text).At(Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());

  if (output_packets.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "output_packets has size $0, expected 1", output_packets.size()));
  }

  const std::vector<Tensor>& tensor_vec =
      output_packets[0].Get<std::vector<Tensor>>();
  if (tensor_vec.size() != kNumInputTensorsForUniversalSentenceEncoder) {
    return absl::InvalidArgumentError(absl::Substitute(
        "tensor_vec has size $0, expected $1", tensor_vec.size(),
        kNumInputTensorsForUniversalSentenceEncoder));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
    return absl::InvalidArgumentError("Expected tensor element type kChar");
  }
  std::vector<std::string> results;
  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
    results.push_back(
        {tensor_vec[i].GetCpuReadView().buffer<char>(),
         static_cast<size_t>(tensor_vec[i].shape().num_elements())});
  }
  return results;
}

TEST(UniversalSentenceEncoderPreprocessorCalculatorTest, TestUSE) {
  ASSERT_THAT(
      RunUniversalSentenceEncoderPreprocessorCalculator("test_input_text"),
      IsOkAndHolds(ElementsAreArray({"", "", "test_input_text"})));
}

}  // namespace
}  // namespace mediapipe
@@ -331,6 +331,7 @@ cc_library(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/port:ret_check",
        "@com_google_absl//absl/status",
        "@org_tensorflow//tensorflow/lite:framework",
    ],
    alwayslink = 1,
@@ -499,7 +499,6 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
  gpu_data_out_ = absl::make_unique<GPUData>();
  gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
  const bool include_alpha = (max_num_channels_ == 4);
  const bool single_channel = (max_num_channels_ == 1);
  if (!(format == mediapipe::ImageFormat::GRAY8 ||
        format == mediapipe::ImageFormat::SRGB ||
        format == mediapipe::ImageFormat::SRGBA))

@@ -509,6 +508,7 @@ absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED

#if MEDIAPIPE_TFLITE_GL_INFERENCE
  const bool single_channel = (max_num_channels_ == 1);
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &include_alpha, &input, &single_channel]() -> absl::Status {
        // Device memory.
@@ -16,6 +16,7 @@
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"

@@ -81,6 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
    }

    if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP
      model_packet = cc->InputSidePackets().Tag("MODEL_FD");
      const auto& model_fd =
          model_packet.Get<std::tuple<int, size_t, size_t>>();

@@ -89,6 +91,10 @@ class TfLiteModelCalculator : public CalculatorBase {
          tflite::DefaultErrorReporter());
      model = tflite::FlatBufferModel::BuildFromAllocation(
          std::move(model_allocation), tflite::DefaultErrorReporter());
#else
      return absl::FailedPreconditionError(
          "Loading by file descriptor is not supported on this platform.");
#endif
    }

    RET_CHECK(model) << "Failed to load TfLite model from blob.";
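For context, a minimal sketch of how a caller could supply the `MODEL_FD` input side packet consumed above. Only the `std::tuple<int, size_t, size_t>` (fd, offset, length) layout comes from this hunk; the file path, side-packet name, and helper function are hypothetical.

```
#include <fcntl.h>
#include <sys/stat.h>

#include <tuple>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"

// Opens a model file and hands its descriptor to the graph as the side
// packet backing "MODEL_FD" (the side-packet name "model_fd" is assumed).
absl::Status StartRunWithModelFd(mediapipe::CalculatorGraph& graph) {
  const int fd = open("/data/local/tmp/model.tflite", O_RDONLY);  // hypothetical path
  RET_CHECK_GE(fd, 0) << "Failed to open model file";
  struct stat stat_buf;
  RET_CHECK_EQ(fstat(fd, &stat_buf), 0) << "fstat failed";
  // (fd, offset, length): map the whole file starting at offset 0.
  return graph.StartRun(
      {{"model_fd", mediapipe::MakePacket<std::tuple<int, size_t, size_t>>(
                        fd, size_t{0}, static_cast<size_t>(stat_buf.st_size))}});
}
```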
@@ -143,9 +143,7 @@ mediapipe_proto_library(
cc_library(
    name = "packet_frequency_calculator",
    srcs = ["packet_frequency_calculator.cc"],
    visibility = [
        "//visibility:public",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
        "//mediapipe/calculators/util:packet_frequency_cc_proto",

@@ -190,9 +188,7 @@ cc_test(
cc_library(
    name = "packet_latency_calculator",
    srcs = ["packet_latency_calculator.cc"],
    visibility = [
        "//visibility:public",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/calculators/util:latency_cc_proto",
        "//mediapipe/calculators/util:packet_latency_calculator_cc_proto",
@@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
    text->set_left(label_left_px_);
    text->set_baseline(label_baseline_px + i * label_height_px_);
    text->set_font_face(options_.font_face());
    if (options_.outline_thickness() > 0) {
      text->set_outline_thickness(options_.outline_thickness());
      if (options_.outline_color_size() > 0) {
        *(text->mutable_outline_color()) =
            options_.outline_color(i % options_.outline_color_size());
      } else {
        text->mutable_outline_color()->set_r(0);
        text->mutable_outline_color()->set_g(0);
        text->mutable_outline_color()->set_b(0);
      }
    }
  }
  cc->Outputs()
      .Tag(kRenderDataTag)
@@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
  // Thickness for drawing the label(s).
  optional double thickness = 2 [default = 2];

  // Color of outline around each character, if any. One per label, as with
  // color attribute.
  repeated Color outline_color = 12;

  // Thickness of outline around each character.
  optional double outline_thickness = 11;

  // The font height in absolute pixels.
  optional int32 font_height_px = 3 [default = 50];
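A hedged sketch (not part of this change) of populating the new outline fields through the generated C++ proto API; the color, thickness value, and header path are assumptions for illustration.

```
#include "mediapipe/calculators/util/labels_to_render_data_calculator.pb.h"

// Outline every label in green, 3px thick. With a single entry in
// outline_color, the i % outline_color_size() lookup in the calculator
// above reuses it for all labels.
mediapipe::LabelsToRenderDataCalculatorOptions MakeOutlinedLabelOptions() {
  mediapipe::LabelsToRenderDataCalculatorOptions options;
  options.set_outline_thickness(3.0);
  mediapipe::Color* outline = options.add_outline_color();
  outline->set_r(0);
  outline->set_g(255);
  outline->set_b(0);
  return options;
}
```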
@@ -18,7 +18,7 @@ import android.content.ClipDescription;
import android.content.Context;
import android.net.Uri;
import android.os.Bundle;
import androidx.appcompat.widget.AppCompatEditText;
import android.support.v7.widget.AppCompatEditText;
import android.util.AttributeSet;
import android.util.Log;
import android.view.inputmethod.EditorInfo;
@@ -1685,10 +1685,3 @@ cc_test(
        "@com_google_absl//absl/strings:str_format",
    ],
)

# Expose the proto source files for building mediapipe AAR.
filegroup(
    name = "protos_src",
    srcs = glob(["*.proto"]),
    visibility = ["//mediapipe:__subpackages__"],
)
@@ -14,15 +14,10 @@ cc_library(
    name = "builder",
    hdrs = ["builder.h"],
    deps = [
        ":const_str",
        ":contract",
        ":node",
        ":packet",
        ":port",
        "//mediapipe/framework:calculator_base",
        "//mediapipe/framework:calculator_contract",
        "@com_google_absl//absl/container:btree",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)
@@ -5,12 +5,7 @@
#include <type_traits>

#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/const_str.h"
#include "mediapipe/framework/api2/contract.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_base.h"
#include "mediapipe/framework/calculator_contract.h"

@@ -112,6 +107,17 @@ class MultiPort : public Single {
  std::vector<std::unique_ptr<Base>>& vec_;
};

namespace internal_builder {

template <typename T, typename U>
using AllowCast = std::integral_constant<bool, std::is_same_v<T, AnyType> &&
                                                   !std::is_same_v<T, U>>;

}  // namespace internal_builder

template <bool IsSide, typename T = internal::Generic>
class SourceImpl;

// These classes wrap references to the underlying source/destination
// endpoints, adding type information and the user-visible API.
template <bool IsSide, typename T = internal::Generic>

@@ -122,16 +128,21 @@ class DestinationImpl {
  explicit DestinationImpl(std::vector<std::unique_ptr<Base>>* vec)
      : DestinationImpl(&GetWithAutoGrow(vec, 0)) {}
  explicit DestinationImpl(DestinationBase* base) : base_(*base) {}

  template <typename U,
            std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
  DestinationImpl<IsSide, U> Cast() {
    return DestinationImpl<IsSide, U>(&base_);
  }

 private:
  DestinationBase& base_;

  template <bool Source_IsSide, typename Source_T>
  friend class SourceImpl;
};

template <bool IsSide, typename T>
class MultiDestinationImpl : public MultiPort<DestinationImpl<IsSide, T>> {
 public:
  using MultiPort<DestinationImpl<IsSide, T>>::MultiPort;
};

template <bool IsSide, typename T = internal::Generic>
class SourceImpl {
 public:
  using Base = SourceBase;

@@ -171,12 +182,8 @@ class SourceImpl {
    return AddTarget(dest);
  }

  template <typename U>
  struct AllowCast
      : public std::integral_constant<bool, std::is_same_v<T, AnyType> &&
                                                !std::is_same_v<T, U>> {};

  template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0>
  template <typename U,
            std::enable_if_t<internal_builder::AllowCast<T, U>{}, int> = 0>
  SourceImpl<IsSide, U> Cast() {
    return SourceImpl<IsSide, U>(base_);
  }

@@ -186,12 +193,6 @@ class SourceImpl {
  SourceBase* base_;
};

template <bool IsSide, typename T>
class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
 public:
  using MultiPort<SourceImpl<IsSide, T>>::MultiPort;
};

// A source and a destination correspond to an output/input stream on a node,
// and a side source and side destination correspond to an output/input side
// packet.

@@ -201,20 +202,20 @@ class MultiSourceImpl : public MultiPort<SourceImpl<IsSide, T>> {
template <typename T = internal::Generic>
using Source = SourceImpl<false, T>;
template <typename T = internal::Generic>
using MultiSource = MultiSourceImpl<false, T>;
using MultiSource = MultiPort<Source<T>>;
template <typename T = internal::Generic>
using SideSource = SourceImpl<true, T>;
template <typename T = internal::Generic>
using MultiSideSource = MultiSourceImpl<true, T>;
using MultiSideSource = MultiPort<SideSource<T>>;

template <typename T = internal::Generic>
using Destination = DestinationImpl<false, T>;
template <typename T = internal::Generic>
using SideDestination = DestinationImpl<true, T>;
template <typename T = internal::Generic>
using MultiDestination = MultiDestinationImpl<false, T>;
using MultiDestination = MultiPort<Destination<T>>;
template <typename T = internal::Generic>
using MultiSideDestination = MultiDestinationImpl<true, T>;
using MultiSideDestination = MultiPort<SideDestination<T>>;

class NodeBase {
 public:

@@ -439,8 +440,9 @@ class Graph {
  // Creates a node of a specific type. Should be used for pure interfaces,
  // which do not have a built-in type string.
  template <class Calc>
  Node<Calc>& AddNode(const std::string& type) {
    auto node = std::make_unique<Node<Calc>>(type);
  Node<Calc>& AddNode(absl::string_view type) {
    auto node =
        std::make_unique<Node<Calc>>(std::string(type.data(), type.size()));
    auto node_p = node.get();
    nodes_.emplace_back(std::move(node));
    return *node_p;

@@ -448,16 +450,18 @@ class Graph {

  // Creates a generic node, with no compile-time checking of inputs and
  // outputs. This can be used for calculators whose contract is not visible.
  GenericNode& AddNode(const std::string& type) {
    auto node = std::make_unique<GenericNode>(type);
  GenericNode& AddNode(absl::string_view type) {
    auto node =
        std::make_unique<GenericNode>(std::string(type.data(), type.size()));
    auto node_p = node.get();
    nodes_.emplace_back(std::move(node));
    return *node_p;
  }

  // For legacy PacketGenerators.
  PacketGenerator& AddPacketGenerator(const std::string& type) {
    auto node = std::make_unique<PacketGenerator>(type);
  PacketGenerator& AddPacketGenerator(absl::string_view type) {
    auto node = std::make_unique<PacketGenerator>(
        std::string(type.data(), type.size()));
    auto node_p = node.get();
    packet_gens_.emplace_back(std::move(node));
    return *node_p;
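A small usage sketch of the two changes above: `AddNode` now accepts an `absl::string_view`, and an untyped (`AnyType`) endpoint can be narrowed with `Cast<T>()`. The calculator name and stream tags here are hypothetical; only the builder calls themselves come from this file and the test below.

```
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"

mediapipe::CalculatorGraphConfig BuildCastGraph() {
  mediapipe::api2::builder::Graph graph;
  // absl::string_view overload added above; a string literal still works.
  auto& node = graph.AddNode("SomeCalculator");  // hypothetical type name
  graph.In("INPUT") >> node.In("INPUT");
  // Narrow the generic (AnyType) graph output to double via the new
  // DestinationImpl::Cast<>().
  node.Out("OUTPUT") >> graph.Out("OUTPUT").Cast<double>();
  return graph.GetConfig();
}
```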
@@ -430,6 +430,8 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
      node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
  any_type_output.SetName("any_type_output");

  any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();

  CalculatorGraphConfig expected =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {

@@ -438,6 +440,7 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
          output_stream: "ANY_OUTPUT:any_type_output"
        }
        input_stream: "GRAPH_ANY_INPUT:__stream_0"
        output_stream: "GRAPH_ANY_OUTPUT:any_type_output"
      )pb");
  EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
}
@@ -185,7 +185,7 @@ class CalculatorBaseFactory {
// Functions for checking that the calculator has the required GetContract.
template <class T>
constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) {
  typedef absl::Status (*GetContractType)(CalculatorContract * cc);
  typedef absl::Status (*GetContractType)(CalculatorContract* cc);
  return std::is_same<decltype(&T::GetContract), GetContractType>::value;
}
template <class T>
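For reference, a minimal calculator sketch whose `GetContract` matches the `GetContractType` signature being checked above; the class is illustrative, not part of the change.

```
#include "absl/status/status.h"
#include "mediapipe/framework/calculator_framework.h"

// Passes its single input through. GetContract has exactly the
// absl::Status(CalculatorContract*) shape that CalculatorHasGetContract
// verifies at compile time.
class MinimalContractCalculator : public mediapipe::CalculatorBase {
 public:
  static absl::Status GetContract(mediapipe::CalculatorContract* cc) {
    cc->Inputs().Index(0).SetAny();
    cc->Outputs().Index(0).SetAny();
    return absl::OkStatus();
  }
  absl::Status Process(mediapipe::CalculatorContext* cc) override {
    cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
    return absl::OkStatus();
  }
};
```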
@@ -133,7 +133,12 @@ message GraphTrace {
    TPU_TASK = 13;
    GPU_CALIBRATION = 14;
    PACKET_QUEUED = 15;
    GPU_TASK_INVOKE = 16;
    TPU_TASK_INVOKE = 17;
  }
  // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
  // //depot/mediapipe/framework/profiler/trace_buffer.h:event_type_list,
  // )

  // The timing for one packet set being processed at one calculator node.
  message CalculatorTrace {
@@ -334,13 +334,6 @@ mediapipe_register_type(
    deps = [":landmark_cc_proto"],
)

# Expose the proto source files for building mediapipe AAR.
filegroup(
    name = "protos_src",
    srcs = glob(["*.proto"]),
    visibility = ["//mediapipe:__subpackages__"],
)

cc_library(
    name = "image",
    srcs = ["image.cc"],

@@ -469,6 +462,10 @@ cc_library(
        ],
        "//conditions:default": [],
    }),
    defines = select({
        "//mediapipe/framework:android_no_jni": ["MEDIAPIPE_NO_JNI"],
        "//conditions:default": [],
    }),
    linkopts = select({
        "//mediapipe:ios": [
            "-framework CoreVideo",
@@ -33,10 +33,3 @@ mediapipe_proto_library(
    srcs = ["rasterization.proto"],
    visibility = ["//visibility:public"],
)

# Expose the proto source files for building mediapipe AAR.
filegroup(
    name = "protos_src",
    srcs = glob(["*.proto"]),
    visibility = ["//mediapipe:__subpackages__"],
)
@@ -37,6 +37,10 @@ message Classification {
// Group of Classification protos.
message ClassificationList {
  repeated Classification classification = 1;
  // Optional index of the tensor that produced these classifications.
  optional int32 tensor_index = 2;
  // Optional name of the tensor that produced these classifications.
  optional string tensor_name = 3;
}

// Group of ClassificationList protos.
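A brief sketch of how the new fields might be populated; the scores, label, and tensor name are invented for illustration.

```
#include "mediapipe/framework/formats/classification.pb.h"

mediapipe::ClassificationList MakeScoredList() {
  mediapipe::ClassificationList list;
  list.set_tensor_index(0);        // which model output produced these scores
  list.set_tensor_name("scores");  // hypothetical tensor name
  mediapipe::Classification* c = list.add_classification();
  c->set_index(7);
  c->set_score(0.92f);
  c->set_label("cat");
  return list;
}
```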
@@ -31,11 +31,12 @@
#if MEDIAPIPE_METAL_ENABLED
#import <Metal/Metal.h>
#endif  // MEDIAPIPE_METAL_ENABLED

#ifndef MEDIAPIPE_NO_JNI
#if __ANDROID_API__ >= 26 || defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#define MEDIAPIPE_TENSOR_USE_AHWB 1
#endif  // __ANDROID_API__ >= 26 ||
        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__)
#endif  // MEDIAPIPE_NO_JNI

#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <android/hardware_buffer.h>

@@ -43,7 +44,6 @@
#include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h"
#endif  // MEDIAPIPE_TENSOR_USE_AHWB

#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h"
@@ -97,8 +97,8 @@ class Tensor {
    kUInt8,
    kInt8,
    kInt32,
    // TODO: Update the inference runner to handle kTfLiteString.
    kChar
    kChar,
    kBool
  };
  struct Shape {
    Shape() = default;

@@ -330,6 +330,8 @@ class Tensor {
        return sizeof(int32_t);
      case ElementType::kChar:
        return sizeof(char);
      case ElementType::kBool:
        return sizeof(bool);
    }
  }
  int bytes() const { return shape_.num_elements() * element_size(); }
@@ -371,7 +371,7 @@ void* Tensor::MapAhwbToCpuRead() const {
    if ((valid_ & kValidOpenGlBuffer) && ssbo_written_ == -1) {
      // EGL sync failed. Use another synchronization method.
      // TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
      glFinish();
      gl_context_->Run([]() { glFinish(); });
    } else if (valid_ & kValidAHardwareBuffer) {
      CHECK(ahwb_written_) << "Ahwb-to-Cpu synchronization requires the "
                              "completion function to be set";
@@ -29,6 +29,9 @@ TEST(General, TestDataTypes) {

  Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4});
  EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char));

  Tensor t_bool(Tensor::ElementType::kBool, Tensor::Shape{2, 3});
  EXPECT_EQ(t_bool.bytes(), t_bool.shape().num_elements() * sizeof(bool));
}

TEST(Cpu, TestMemoryAllocation) {
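Beyond the size checks above, a hedged sketch of actually writing a `kBool` tensor through the CPU views used elsewhere in this change; the mask contents and helper are illustrative.

```
#include "mediapipe/framework/formats/tensor.h"

// Builds a 2x3 boolean mask tensor; GetCpuWriteView()/buffer<bool>() mirror
// the kChar usage in the preprocessor calculator earlier in this diff.
mediapipe::Tensor MakeAlternatingMask() {
  mediapipe::Tensor mask(mediapipe::Tensor::ElementType::kBool,
                         mediapipe::Tensor::Shape{2, 3});
  {
    auto view = mask.GetCpuWriteView();
    bool* data = view.buffer<bool>();
    for (int i = 0; i < mask.shape().num_elements(); ++i) {
      data[i] = (i % 2 == 0);
    }
  }  // Release the write view before handing the tensor off.
  return mask;
}
```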
@@ -109,6 +109,11 @@ struct TraceEvent {
  static constexpr EventType TPU_TASK = GraphTrace::TPU_TASK;
  static constexpr EventType GPU_CALIBRATION = GraphTrace::GPU_CALIBRATION;
  static constexpr EventType PACKET_QUEUED = GraphTrace::PACKET_QUEUED;
  static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
  static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
  // //depot/mediapipe/framework/mediapipe_profiling.h:profiler_census_tags,
  // //depot/mediapipe/framework/calculator_profile.proto:event_type,
  // )
};

// Packet trace log buffer.
@@ -64,7 +64,7 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
  std::vector<TraceEventType> basic_types = {
      {TraceEvent::UNKNOWN, "An uninitialized trace-event."},
      {TraceEvent::OPEN, "A call to Calculator::Open.", true, true},
      {TraceEvent::PROCESS, "A call to Calculator::Open.", true, true},
      {TraceEvent::PROCESS, "A call to Calculator::Process.", true, true},
      {TraceEvent::CLOSE, "A call to Calculator::Close.", true, true},

      {TraceEvent::NOT_READY, "A calculator cannot process packets yet."},
@@ -150,7 +150,7 @@ cc_library(
    name = "executor_util",
    srcs = ["executor_util.cc"],
    hdrs = ["executor_util.h"],
    visibility = ["//mediapipe/framework:mediapipe_internal"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:mediapipe_options_cc_proto",
@@ -378,8 +378,11 @@ cc_library(
        ],
    }),
    deps = [
        ":gl_texture_buffer",
        ":gpu_buffer_format",
        ":gpu_buffer_storage",
        ":image_frame_view",
        "//mediapipe/framework/formats:image_frame",
        "@com_google_absl//absl/strings:str_format",
    ],
)
@@ -1050,7 +1053,7 @@ objc_library(
    alwayslink = 1,
)

MIN_IOS_VERSION = "9.0"  # For thread_local.
MIN_IOS_VERSION = "11.0"

test_suite(
    name = "ios",
@@ -111,7 +111,8 @@ typedef CVOpenGLESTextureCacheRef CVTextureCacheType;
- (CVMetalTextureCacheRef)mtlTextureCache {
  @synchronized(self) {
    if (!_mtlTextureCache) {
      CVReturn err = CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
      CVReturn __unused err =
          CVMetalTextureCacheCreate(NULL, NULL, self.mtlDevice, NULL, &_mtlTextureCache);
      NSAssert(err == kCVReturnSuccess, @"Error at CVMetalTextureCacheCreate %d", err);
      // TODO: register and flush metal caches too.
    }
@@ -47,6 +47,8 @@ static void EglThreadExitCallback(void* key_value) {
  // implementations, and should be considered as an undocumented vendor
  // extension.
  // https://www.khronos.org/registry/EGL/sdk/docs/man/html/eglMakeCurrent.xhtml
  //
  // NOTE: crashes on some Android devices (occurs with libGLES_meow.so).
  eglMakeCurrent(eglGetDisplay(EGL_DEFAULT_DISPLAY), EGL_NO_SURFACE,
                 EGL_NO_SURFACE, EGL_NO_CONTEXT);
#endif
@@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) {
        16,
        0};

    pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs];
    pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1];
  }
  if (!pixel_format_) {
    // On several Forge machines, the default config fails. For now let's do
@@ -144,14 +144,23 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) {
                    context](std::shared_ptr<GlSyncPoint> sync_token) {
    CHECK_NE(name_, 0);
    GLuint name_to_delete = name_;
    context->RunWithoutWaiting([name_to_delete, sync_token]() {
      if (sync_token) {
        // TODO: maybe we do not actually have to wait for the
        // consumer sync here. Check docs.
        sync_token->WaitOnGpu();
      } else {
        LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback";
      }
    context->RunWithoutWaiting([name_to_delete]() {
      // Note that we do not wait for consumers to be done before deleting the
      // texture. Based on a reading of the GLES 3.0 spec, appendix D:
      // - when a texture is deleted, it is _not_ automatically unbound from
      //   bind points in other contexts;
      // - when a texture is deleted, its name becomes immediately invalid, but
      //   the actual object is not deleted until it is no longer in use, i.e.
      //   attached to a container object or bound to a context;
      // - deleting an object is not an operation that changes its contents;
      // - within each context, commands are executed sequentially, so it seems
      //   like an unbind that follows a command that reads a texture should not
      //   take effect until the GPU has actually finished executing the
      //   previous commands.
      // The final point is the least explicit in the docs, but it is implied by
      // normal single-context behavior. E.g. if you do bind, delete, render,
      // unbind, the object is not deleted until the unbind, and it waits for
      // the render to finish.
      DLOG_IF(ERROR, !glIsTexture(name_to_delete))
          << "Deleting invalid texture id: " << name_to_delete;
      glDeleteTextures(1, &name_to_delete);

@@ -185,7 +194,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
      << "Updated existing texture which had not been marked for reuse!";
  CHECK(prod_token);
  producer_sync_ = std::move(prod_token);
  producer_context_ = producer_sync_->GetContext();
  const auto& synced_context = producer_sync_->GetContext();
  if (synced_context) {
    producer_context_ = synced_context;
  }
}

void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {
@@ -65,6 +65,7 @@ class GlTextureView {
  friend class GpuBuffer;
  friend class GlTextureBuffer;
  friend class GpuBufferStorageCvPixelBuffer;
  friend class GpuBufferStorageAhwb;
  GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
                int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
                DetachFn detach, DoneWritingFn done_writing)
@@ -18,6 +18,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "stb_image.h"
@@ -17,8 +17,8 @@ package com.google.mediapipe.framework;
import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.framework.image.ImageProperties;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import java.nio.ByteBuffer;

// TODO: use Preconditions in this file.

@@ -60,24 +60,24 @@ public class AndroidPacketCreator extends PacketCreator {
  }

  /**
   * Creates an Image packet from an {@link Image}.
   * Creates a MediaPipe Image packet from a {@link MPImage}.
   *
   * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP.
   */
  public Packet createImage(Image image) {
  public Packet createImage(MPImage image) {
    // TODO: Choose the best storage from multiple containers.
    ImageProperties properties = image.getContainedImageProperties().get(0);
    if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) {
    MPImageProperties properties = image.getContainedImageProperties().get(0);
    if (properties.getStorageType() == MPImage.STORAGE_TYPE_BYTEBUFFER) {
      ByteBuffer buffer = ByteBufferExtractor.extract(image);
      int numChannels = 0;
      switch (properties.getImageFormat()) {
        case Image.IMAGE_FORMAT_RGBA:
        case MPImage.IMAGE_FORMAT_RGBA:
          numChannels = 4;
          break;
        case Image.IMAGE_FORMAT_RGB:
        case MPImage.IMAGE_FORMAT_RGB:
          numChannels = 3;
          break;
        case Image.IMAGE_FORMAT_ALPHA:
        case MPImage.IMAGE_FORMAT_ALPHA:
          numChannels = 1;
          break;
        default: // fall out

@@ -90,7 +90,7 @@ public class AndroidPacketCreator extends PacketCreator {
      int height = image.getHeight();
      return createImage(buffer, width, height, numChannels);
    }
    if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) {
    if (properties.getImageFormat() == MPImage.STORAGE_TYPE_BITMAP) {
      Bitmap bitmap = BitmapExtractor.extract(image);
      if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) {
        throw new UnsupportedOperationException("bitmap must use ARGB_8888 config.");

@@ -30,3 +30,10 @@ android_library(
        "@maven//:com_google_guava_guava",
    ],
)

# Expose the java source files for building mediapipe AAR.
filegroup(
    name = "java_src",
    srcs = glob(["*.java"]),
    visibility = ["//mediapipe:__subpackages__"],
)
@@ -18,29 +18,29 @@ package com.google.mediapipe.framework.image;
import android.graphics.Bitmap;

/**
 * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}.
 * Utility for extracting {@link android.graphics.Bitmap} from {@link MPImage}.
 *
 * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise
 * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BITMAP}, otherwise
 * {@link IllegalArgumentException} will be thrown.
 */
public final class BitmapExtractor {

  /**
   * Extracts a {@link android.graphics.Bitmap} from an {@link Image}.
   * Extracts a {@link android.graphics.Bitmap} from a {@link MPImage}.
   *
   * @param image the image to extract {@link android.graphics.Bitmap} from.
   * @return the {@link android.graphics.Bitmap} stored in {@link Image}
   * @return the {@link android.graphics.Bitmap} stored in {@link MPImage}
   * @throws IllegalArgumentException when the extraction requires unsupported format or data type
   *     conversions.
   */
  public static Bitmap extract(Image image) {
    ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP);
  public static Bitmap extract(MPImage image) {
    MPImageContainer imageContainer = image.getContainer(MPImage.STORAGE_TYPE_BITMAP);
    if (imageContainer != null) {
      return ((BitmapImageContainer) imageContainer).getBitmap();
    } else {
      // TODO: Support ByteBuffer -> Bitmap conversion.
      throw new IllegalArgumentException(
          "Extracting Bitmap from an Image created by objects other than Bitmap is not"
          "Extracting Bitmap from a MPImage created by objects other than Bitmap is not"
              + " supported");
    }
  }
@@ -22,7 +22,7 @@ import android.provider.MediaStore;
import java.io.IOException;

/**
 * Builds {@link Image} from {@link android.graphics.Bitmap}.
 * Builds {@link MPImage} from {@link android.graphics.Bitmap}.
 *
 * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once
 * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content

@@ -49,7 +49,7 @@ public class BitmapImageBuilder {
  }

  /**
   * Creates the builder to build {@link Image} from a file.
   * Creates the builder to build {@link MPImage} from a file.
   *
   * @param context the application context.
   * @param uri the path to the resource file.

@@ -58,15 +58,15 @@ public class BitmapImageBuilder {
    this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri));
  }

  /** Sets value for {@link Image#getTimestamp()}. */
  /** Sets value for {@link MPImage#getTimestamp()}. */
  BitmapImageBuilder setTimestamp(long timestamp) {
    this.timestamp = timestamp;
    return this;
  }

  /** Builds an {@link Image} instance. */
  public Image build() {
    return new Image(
  /** Builds a {@link MPImage} instance. */
  public MPImage build() {
    return new MPImage(
        new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight());
  }
}
@@ -16,19 +16,19 @@ limitations under the License.
package com.google.mediapipe.framework.image;

import android.graphics.Bitmap;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;

class BitmapImageContainer implements ImageContainer {
class BitmapImageContainer implements MPImageContainer {

  private final Bitmap bitmap;
  private final ImageProperties properties;
  private final MPImageProperties properties;

  public BitmapImageContainer(Bitmap bitmap) {
    this.bitmap = bitmap;
    this.properties =
        ImageProperties.builder()
        MPImageProperties.builder()
            .setImageFormat(convertFormatCode(bitmap.getConfig()))
            .setStorageType(Image.STORAGE_TYPE_BITMAP)
            .setStorageType(MPImage.STORAGE_TYPE_BITMAP)
            .build();
  }

@@ -37,7 +37,7 @@ class BitmapImageContainer implements ImageContainer {
  }

  @Override
  public ImageProperties getImageProperties() {
  public MPImageProperties getImageProperties() {
    return properties;
  }

@@ -46,15 +46,15 @@ class BitmapImageContainer implements ImageContainer {
    bitmap.recycle();
  }

  @ImageFormat
  @MPImageFormat
  static int convertFormatCode(Bitmap.Config config) {
    switch (config) {
      case ALPHA_8:
        return Image.IMAGE_FORMAT_ALPHA;
        return MPImage.IMAGE_FORMAT_ALPHA;
      case ARGB_8888:
        return Image.IMAGE_FORMAT_RGBA;
        return MPImage.IMAGE_FORMAT_RGBA;
      default:
        return Image.IMAGE_FORMAT_UNKNOWN;
        return MPImage.IMAGE_FORMAT_UNKNOWN;
    }
  }
}
@@ -21,45 +21,45 @@ import android.graphics.Bitmap.Config;
import android.os.Build.VERSION;
import android.os.Build.VERSION_CODES;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Locale;

/**
 * Utility for extracting {@link ByteBuffer} from {@link Image}.
 * Utility for extracting {@link ByteBuffer} from {@link MPImage}.
 *
 * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise
 * {@link IllegalArgumentException} will be thrown.
 * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_BYTEBUFFER},
 * otherwise {@link IllegalArgumentException} will be thrown.
 */
public class ByteBufferExtractor {

  /**
   * Extracts a {@link ByteBuffer} from an {@link Image}.
   * Extracts a {@link ByteBuffer} from a {@link MPImage}.
   *
   * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link
   * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}.
   * MPImageProperties} whose storage type is {@code MPImage.STORAGE_TYPE_BYTEBUFFER}.
   *
   * @see Image#getContainedImageProperties()
   * @see MPImage#getContainedImageProperties()
   * @return A read-only {@link ByteBuffer}.
   * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage.
   */
  @SuppressLint("SwitchIntDef")
  public static ByteBuffer extract(Image image) {
    ImageContainer container = image.getContainer();
  public static ByteBuffer extract(MPImage image) {
    MPImageContainer container = image.getContainer();
    switch (container.getImageProperties().getStorageType()) {
      case Image.STORAGE_TYPE_BYTEBUFFER:
      case MPImage.STORAGE_TYPE_BYTEBUFFER:
        ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
        return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
      default:
        throw new IllegalArgumentException(
            "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not"
            "Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
                + " supported");
    }
  }

  /**
   * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}.
   * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from a {@link MPImage}.
   *
   * <p>Format conversion spec:
   *

@@ -70,26 +70,26 @@ public class ByteBufferExtractor {
   *
   * @param image the image to extract buffer from.
   * @param targetFormat the image format of the result bytebuffer.
   * @return the readonly {@link ByteBuffer} stored in {@link Image}
   * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
   * @throws IllegalArgumentException when the extraction requires unsupported format or data type
   *     conversions.
   */
  static ByteBuffer extract(Image image, @ImageFormat int targetFormat) {
    ImageContainer container;
    ImageProperties byteBufferProperties =
        ImageProperties.builder()
            .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
  static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
    MPImageContainer container;
    MPImageProperties byteBufferProperties =
        MPImageProperties.builder()
            .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
            .setImageFormat(targetFormat)
            .build();
    if ((container = image.getContainer(byteBufferProperties)) != null) {
      ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
      return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
      ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
      @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
      @MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
      return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
          .asReadOnlyBuffer();
    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
      BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
      ByteBuffer byteBuffer =
          extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat)

@@ -98,85 +98,89 @@ public class ByteBufferExtractor {
      return byteBuffer;
    } else {
      throw new IllegalArgumentException(
          "Extracting ByteBuffer from an Image created by objects other than Bitmap or"
          "Extracting ByteBuffer from a MPImage created by objects other than Bitmap or"
              + " Bytebuffer is not supported");
    }
  }

  /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */
  /** A wrapper for a {@link ByteBuffer} and its {@link MPImageFormat}. */
  @AutoValue
  abstract static class Result {
    /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */
    /**
     * Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
     */
    public abstract ByteBuffer buffer();

    /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */
    @ImageFormat
    /**
     * Gets the {@link MPImageFormat} in the result of {@link ByteBufferExtractor#extract(MPImage)}.
     */
    @MPImageFormat
    public abstract int format();

    static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) {
    static Result create(ByteBuffer buffer, @MPImageFormat int imageFormat) {
      return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat);
    }
  }

  /**
   * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}.
   * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from a {@link MPImage}.
   *
   * <p>It will make a best effort to return an already existing {@link ByteBuffer} to avoid a copy.
   *
   * @return the readonly {@link ByteBuffer} stored in {@link Image}
   * @return the readonly {@link ByteBuffer} stored in {@link MPImage}
   * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with
   *     given {@code imageFormat}
   */
  static Result extractInRecommendedFormat(Image image) {
    ImageContainer container;
    if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) {
  static Result extractInRecommendedFormat(MPImage image) {
    MPImageContainer container;
    if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
      Bitmap bitmap = ((BitmapImageContainer) container).getBitmap();
      @ImageFormat int format = adviseImageFormat(bitmap);
      @MPImageFormat int format = adviseImageFormat(bitmap);
      Result result =
          Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format);

      boolean unused =
          image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format()));
      return result;
    } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) {
    } else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
      ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
      return Result.create(
          byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(),
          byteBufferImageContainer.getImageFormat());
    } else {
      throw new IllegalArgumentException(
          "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer"
          "Extract ByteBuffer from a MPImage created by objects other than Bitmap or Bytebuffer"
              + " is not supported");
    }
  }

  @ImageFormat
  @MPImageFormat
  private static int adviseImageFormat(Bitmap bitmap) {
    if (bitmap.getConfig() == Config.ARGB_8888) {
      return Image.IMAGE_FORMAT_RGBA;
      return MPImage.IMAGE_FORMAT_RGBA;
    } else {
      throw new IllegalArgumentException(
          String.format(
              "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not"
              "Extracting ByteBuffer from a MPImage created by a Bitmap in config %s is not"
                  + " supported",
              bitmap.getConfig()));
    }
  }

  private static ByteBuffer extractByteBufferFromBitmap(
      Bitmap bitmap, @ImageFormat int imageFormat) {
      Bitmap bitmap, @MPImageFormat int imageFormat) {
    if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) {
      throw new IllegalArgumentException(
          "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not"
          "Extracting ByteBuffer from a MPImage created by a premultiplied Bitmap is not"
              + " supported");
    }
    if (bitmap.getConfig() == Config.ARGB_8888) {
      if (imageFormat == Image.IMAGE_FORMAT_RGBA) {
      if (imageFormat == MPImage.IMAGE_FORMAT_RGBA) {
        ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount());
        bitmap.copyPixelsToBuffer(buffer);
        buffer.rewind();
        return buffer;
      } else if (imageFormat == Image.IMAGE_FORMAT_RGB) {
      } else if (imageFormat == MPImage.IMAGE_FORMAT_RGB) {
        // TODO: Try Use RGBA buffer to create RGB buffer which might be faster.
        int w = bitmap.getWidth();
        int h = bitmap.getHeight();

@@ -196,14 +200,14 @@ public class ByteBufferExtractor {
    }
    throw new IllegalArgumentException(
        String.format(
            "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format"
            "Extracting ByteBuffer from a MPImage created by Bitmap and convert from %s to format"
                + " %d is not supported",
            bitmap.getConfig(), imageFormat));
  }

  private static ByteBuffer convertByteBuffer(
      ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) {
    if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) {
      ByteBuffer source, @MPImageFormat int sourceFormat, @MPImageFormat int targetFormat) {
    if (sourceFormat == MPImage.IMAGE_FORMAT_RGB && targetFormat == MPImage.IMAGE_FORMAT_RGBA) {
      ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4);
      // Extend the buffer when the target is longer than the source. Use two cursors and sweep the
      // array reversely to convert in-place.

@@ -221,7 +225,8 @@ public class ByteBufferExtractor {
      target.put(array, 0, target.capacity());
      target.rewind();
      return target;
    } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) {
    } else if (sourceFormat == MPImage.IMAGE_FORMAT_RGBA
        && targetFormat == MPImage.IMAGE_FORMAT_RGB) {
      ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3);
      // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the
      // array to convert in-place.
@@ -15,11 +15,11 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;

/**
 * Builds a {@link Image} from a {@link ByteBuffer}.
 * Builds a {@link MPImage} from a {@link ByteBuffer}.
 *
 * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link
 * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it.

@@ -32,7 +32,7 @@ public class ByteBufferImageBuilder {
  private final ByteBuffer buffer;
  private final int width;
  private final int height;
  @ImageFormat private final int imageFormat;
  @MPImageFormat private final int imageFormat;

  // Optional fields.
  private long timestamp;

@@ -49,7 +49,7 @@ public class ByteBufferImageBuilder {
   * @param imageFormat how the data encode the image.
   */
  public ByteBufferImageBuilder(
      ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) {
      ByteBuffer byteBuffer, int width, int height, @MPImageFormat int imageFormat) {
    this.buffer = byteBuffer;
    this.width = width;
    this.height = height;

@@ -58,14 +58,14 @@ public class ByteBufferImageBuilder {
    this.timestamp = 0;
  }

  /** Sets value for {@link Image#getTimestamp()}. */
  /** Sets value for {@link MPImage#getTimestamp()}. */
  ByteBufferImageBuilder setTimestamp(long timestamp) {
    this.timestamp = timestamp;
    return this;
  }

  /** Builds an {@link Image} instance. */
  public Image build() {
    return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
  /** Builds a {@link MPImage} instance. */
  public MPImage build() {
    return new MPImage(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height);
  }
}
@@ -15,21 +15,19 @@ limitations under the License.

package com.google.mediapipe.framework.image;

import com.google.mediapipe.framework.image.Image.ImageFormat;
import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
import java.nio.ByteBuffer;

class ByteBufferImageContainer implements ImageContainer {
class ByteBufferImageContainer implements MPImageContainer {

  private final ByteBuffer buffer;
  private final ImageProperties properties;
  private final MPImageProperties properties;

  public ByteBufferImageContainer(
      ByteBuffer buffer,
      @ImageFormat int imageFormat) {
  public ByteBufferImageContainer(ByteBuffer buffer, @MPImageFormat int imageFormat) {
    this.buffer = buffer;
    this.properties =
        ImageProperties.builder()
            .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER)
        MPImageProperties.builder()
            .setStorageType(MPImage.STORAGE_TYPE_BYTEBUFFER)
            .setImageFormat(imageFormat)
            .build();
  }

@@ -39,14 +37,12 @@ class ByteBufferImageContainer implements ImageContainer {
  }

  @Override
  public ImageProperties getImageProperties() {
  public MPImageProperties getImageProperties() {
    return properties;
  }

  /**
   * Returns the image format.
   */
  @ImageFormat
  /** Returns the image format. */
  @MPImageFormat
  public int getImageFormat() {
    return properties.getImageFormat();
  }
@@ -29,10 +29,10 @@ import java.util.Map.Entry;
/**
 * The wrapper class for image objects.
 *
 * <p>{@link Image} is designed to be an immutable image container, which could be shared
 * <p>{@link MPImage} is designed to be an immutable image container, which could be shared
 * cross-platforms.
 *
 * <p>To construct an {@link Image}, use the provided builders:
 * <p>To construct a {@link MPImage}, use the provided builders:
 *
 * <ul>
 *   <li>{@link ByteBufferImageBuilder}

@@ -40,7 +40,7 @@ import java.util.Map.Entry;
 *   <li>{@link MediaImageBuilder}
 * </ul>
 *
 * <p>{@link Image} uses reference counting to maintain internal storage. When it is created the
 * <p>{@link MPImage} uses reference counting to maintain internal storage. When it is created the
 * reference count is 1. Developer can call {@link #close()} to reduce reference count to release
 * internal storage earlier, otherwise Java garbage collection will release the storage eventually.
 *

@@ -53,7 +53,7 @@ import java.util.Map.Entry;
 *   <li>{@link MediaImageExtractor}
 * </ul>
 */
public class Image implements Closeable {
public class MPImage implements Closeable {

  /** Specifies the image format of an image. */
  @IntDef({

@@ -69,7 +69,7 @@ public class Image implements Closeable {
    IMAGE_FORMAT_JPEG,
  })
  @Retention(RetentionPolicy.SOURCE)
  public @interface ImageFormat {}
  public @interface MPImageFormat {}

  public static final int IMAGE_FORMAT_UNKNOWN = 0;
  public static final int IMAGE_FORMAT_RGBA = 1;

@@ -98,14 +98,14 @@ public class Image implements Closeable {
  public static final int STORAGE_TYPE_IMAGE_PROXY = 4;

  /**
   * Returns a list of supported image properties for this {@link Image}.
   * Returns a list of supported image properties for this {@link MPImage}.
   *
   * <p>Currently {@link Image} only supports a single storage type so the size of return list will
   * <p>Currently {@link MPImage} only supports a single storage type so the size of return list will
   * always be 1.
   *
   * @see ImageProperties
   * @see MPImageProperties
   */
  public List<ImageProperties> getContainedImageProperties() {
  public List<MPImageProperties> getContainedImageProperties() {
    return Collections.singletonList(getContainer().getImageProperties());
  }

@@ -124,7 +124,7 @@ public class Image implements Closeable {
    return height;
  }

  /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */
  /** Acquires a reference on this {@link MPImage}. This will increase the reference count by 1. */
  private synchronized void acquire() {
    referenceCount += 1;
  }

@@ -132,7 +132,7 @@ public class Image implements Closeable {
  /**
   * Removes a reference that was previously acquired or init.
   *
   * <p>When {@link Image} is created, it has 1 reference count.
   * <p>When {@link MPImage} is created, it has 1 reference count.
   *
   * <p>When the reference count becomes 0, it will release the resource under the hood.
   */

@@ -141,24 +141,24 @@ public class Image implements Closeable {
  public synchronized void close() {
    referenceCount -= 1;
    if (referenceCount == 0) {
      for (ImageContainer imageContainer : containerMap.values()) {
      for (MPImageContainer imageContainer : containerMap.values()) {
        imageContainer.close();
      }
    }
  }

  /** Advanced API access for {@link Image}. */
  /** Advanced API access for {@link MPImage}. */
  static final class Internal {

    /**
     * Acquires a reference on this {@link Image}. This will increase the reference count by 1.
     * Acquires a reference on this {@link MPImage}. This will increase the reference count by 1.
     *
     * <p>This method is more useful for image consumer to acquire a reference so image resource
     * will not be closed accidentally. As image creator, normal developer doesn't need to call this
     * method.
     *
     * <p>The reference count is 1 when {@link Image} is created. Developer can call {@link
     * #close()} to indicate it doesn't need this {@link Image} anymore.
     * <p>The reference count is 1 when {@link MPImage} is created. Developer can call {@link
     * #close()} to indicate it doesn't need this {@link MPImage} anymore.
     *
     * @see #close()
     */

@@ -166,10 +166,10 @@ public class Image implements Closeable {
      image.acquire();
    }

    private final Image image;
    private final MPImage image;

    // Only Image creates the internal helper.
    private Internal(Image image) {
    // Only MPImage creates the internal helper.
    private Internal(MPImage image) {
      this.image = image;
    }
  }

@@ -179,15 +179,15 @@ public class Image implements Closeable {
    return new Internal(this);
  }

  private final Map<ImageProperties, ImageContainer> containerMap;
  private final Map<MPImageProperties, MPImageContainer> containerMap;
  private final long timestamp;
  private final int width;
  private final int height;

  private int referenceCount;

  /** Constructs an {@link Image} with a built container. */
  Image(ImageContainer container, long timestamp, int width, int height) {
  /** Constructs a {@link MPImage} with a built container. */
  MPImage(MPImageContainer container, long timestamp, int width, int height) {
    this.containerMap = new HashMap<>();
    containerMap.put(container.getImageProperties(), container);
    this.timestamp = timestamp;

@@ -201,10 +201,10 @@ public class Image implements Closeable {
   *
   * @return the current container.
   */
  ImageContainer getContainer() {
  MPImageContainer getContainer() {
    // According to the design, in the future we will support multiple containers in one image.
    // Currently just return the original container.
    // TODO: Cache multiple containers in Image.
    // TODO: Cache multiple containers in MPImage.
    return containerMap.values().iterator().next();
  }

@@ -214,8 +214,8 @@ public class Image implements Closeable {
   * <p>If there are multiple containers with required {@code storageType}, returns the first one.
   */
  @Nullable
  ImageContainer getContainer(@StorageType int storageType) {
  MPImageContainer getContainer(@StorageType int storageType) {
    for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) {
    for (Entry<MPImageProperties, MPImageContainer> entry : containerMap.entrySet()) {
      if (entry.getKey().getStorageType() == storageType) {
        return entry.getValue();
      }

@@ -225,13 +225,13 @@ public class Image implements Closeable {

  /** Gets container from required {@code imageProperties}. Returns {@code null} if it doesn't exist. */
  @Nullable
  ImageContainer getContainer(ImageProperties imageProperties) {
  MPImageContainer getContainer(MPImageProperties imageProperties) {
    return containerMap.get(imageProperties);
  }

  /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */
|
||||
boolean addContainer(ImageContainer container) {
|
||||
ImageProperties imageProperties = container.getImageProperties();
|
||||
boolean addContainer(MPImageContainer container) {
|
||||
MPImageProperties imageProperties = container.getImageProperties();
|
||||
if (containerMap.containsKey(imageProperties)) {
|
||||
return false;
|
||||
}
|
|
@@ -14,14 +14,14 @@ limitations under the License.
 ==============================================================================*/
 package com.google.mediapipe.framework.image;

-/** Lightweight abstraction for an object that can receive {@link Image} */
-public interface ImageConsumer {
+/** Lightweight abstraction for an object that can receive {@link MPImage} */
+public interface MPImageConsumer {

   /**
-   * Called when an {@link Image} is available.
+   * Called when a {@link MPImage} is available.
    *
    * <p>The argument is only guaranteed to be available until this method returns. if you need to
    * extend its life time, acquire it, then release it when done.
    */
-  void onNewImage(Image image);
+  void onNewMPImage(MPImage image);
 }
@@ -16,9 +16,9 @@ limitations under the License.
 package com.google.mediapipe.framework.image;

 /** Manages internal image data storage. The interface is package-private. */
-interface ImageContainer {
+interface MPImageContainer {
   /** Returns the properties of the contained image. */
-  ImageProperties getImageProperties();
+  MPImageProperties getImageProperties();

   /** Close the image container and releases the image resource inside. */
   void close();
@@ -14,9 +14,9 @@ limitations under the License.
 ==============================================================================*/
 package com.google.mediapipe.framework.image;

-/** Lightweight abstraction for an object that produce {@link Image} */
-public interface ImageProducer {
+/** Lightweight abstraction for an object that produce {@link MPImage} */
+public interface MPImageProducer {

-  /** Sets the consumer that receives the {@link Image}. */
-  void setImageConsumer(ImageConsumer imageConsumer);
+  /** Sets the consumer that receives the {@link MPImage}. */
+  void setMPImageConsumer(MPImageConsumer imageConsumer);
 }
@@ -17,25 +17,25 @@ package com.google.mediapipe.framework.image;

 import com.google.auto.value.AutoValue;
 import com.google.auto.value.extension.memoized.Memoized;
-import com.google.mediapipe.framework.image.Image.ImageFormat;
-import com.google.mediapipe.framework.image.Image.StorageType;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;
+import com.google.mediapipe.framework.image.MPImage.StorageType;

 /** Groups a set of properties to describe how an image is stored. */
 @AutoValue
-public abstract class ImageProperties {
+public abstract class MPImageProperties {

   /**
    * Gets the pixel format of the image.
    *
-   * @see Image.ImageFormat
+   * @see MPImage.MPImageFormat
    */
-  @ImageFormat
+  @MPImageFormat
   public abstract int getImageFormat();

   /**
    * Gets the storage type of the image.
    *
-   * @see Image.StorageType
+   * @see MPImage.StorageType
    */
   @StorageType
   public abstract int getStorageType();
@@ -45,36 +45,36 @@ public abstract class ImageProperties {
   public abstract int hashCode();

   /**
-   * Creates a builder of {@link ImageProperties}.
+   * Creates a builder of {@link MPImageProperties}.
    *
-   * @see ImageProperties.Builder
+   * @see MPImageProperties.Builder
    */
   static Builder builder() {
-    return new AutoValue_ImageProperties.Builder();
+    return new AutoValue_MPImageProperties.Builder();
   }

-  /** Builds a {@link ImageProperties}. */
+  /** Builds a {@link MPImageProperties}. */
   @AutoValue.Builder
   abstract static class Builder {

     /**
-     * Sets the {@link Image.ImageFormat}.
+     * Sets the {@link MPImage.MPImageFormat}.
      *
-     * @see ImageProperties#getImageFormat
+     * @see MPImageProperties#getImageFormat
      */
-    abstract Builder setImageFormat(@ImageFormat int value);
+    abstract Builder setImageFormat(@MPImageFormat int value);

     /**
-     * Sets the {@link Image.StorageType}.
+     * Sets the {@link MPImage.StorageType}.
      *
-     * @see ImageProperties#getStorageType
+     * @see MPImageProperties#getStorageType
      */
     abstract Builder setStorageType(@StorageType int value);

-    /** Builds the {@link ImageProperties}. */
-    abstract ImageProperties build();
+    /** Builds the {@link MPImageProperties}. */
+    abstract MPImageProperties build();
   }

   // Hide the constructor.
-  ImageProperties() {}
+  MPImageProperties() {}
 }
@@ -15,11 +15,12 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

+import android.media.Image;
 import android.os.Build.VERSION_CODES;
 import androidx.annotation.RequiresApi;

 /**
- * Builds {@link Image} from {@link android.media.Image}.
+ * Builds {@link MPImage} from {@link android.media.Image}.
  *
  * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify
  * content in it.
@@ -30,7 +31,7 @@ import androidx.annotation.RequiresApi;
 public class MediaImageBuilder {

   // Mandatory fields.
-  private final android.media.Image mediaImage;
+  private final Image mediaImage;

   // Optional fields.
   private long timestamp;
@@ -40,20 +41,20 @@ public class MediaImageBuilder {
    *
    * @param mediaImage image data object.
    */
-  public MediaImageBuilder(android.media.Image mediaImage) {
+  public MediaImageBuilder(Image mediaImage) {
     this.mediaImage = mediaImage;
     this.timestamp = 0;
   }

-  /** Sets value for {@link Image#getTimestamp()}. */
+  /** Sets value for {@link MPImage#getTimestamp()}. */
   MediaImageBuilder setTimestamp(long timestamp) {
     this.timestamp = timestamp;
     return this;
   }

-  /** Builds an {@link Image} instance. */
-  public Image build() {
-    return new Image(
+  /** Builds a {@link MPImage} instance. */
+  public MPImage build() {
+    return new MPImage(
         new MediaImageContainer(mediaImage),
         timestamp,
         mediaImage.getWidth(),
@@ -15,33 +15,34 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

+import android.media.Image;
 import android.os.Build;
 import android.os.Build.VERSION;
 import android.os.Build.VERSION_CODES;
 import androidx.annotation.RequiresApi;
-import com.google.mediapipe.framework.image.Image.ImageFormat;
+import com.google.mediapipe.framework.image.MPImage.MPImageFormat;

 @RequiresApi(VERSION_CODES.KITKAT)
-class MediaImageContainer implements ImageContainer {
+class MediaImageContainer implements MPImageContainer {

-  private final android.media.Image mediaImage;
-  private final ImageProperties properties;
+  private final Image mediaImage;
+  private final MPImageProperties properties;

-  public MediaImageContainer(android.media.Image mediaImage) {
+  public MediaImageContainer(Image mediaImage) {
     this.mediaImage = mediaImage;
     this.properties =
-        ImageProperties.builder()
-            .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE)
+        MPImageProperties.builder()
+            .setStorageType(MPImage.STORAGE_TYPE_MEDIA_IMAGE)
             .setImageFormat(convertFormatCode(mediaImage.getFormat()))
             .build();
   }

-  public android.media.Image getImage() {
+  public Image getImage() {
     return mediaImage;
   }

   @Override
-  public ImageProperties getImageProperties() {
+  public MPImageProperties getImageProperties() {
     return properties;
   }

@@ -50,24 +51,24 @@ class MediaImageContainer implements ImageContainer {
     mediaImage.close();
   }

-  @ImageFormat
+  @MPImageFormat
   static int convertFormatCode(int graphicsFormat) {
     // We only cover the format mentioned in
     // https://developer.android.com/reference/android/media/Image#getFormat()
     if (VERSION.SDK_INT >= Build.VERSION_CODES.M) {
       if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) {
-        return Image.IMAGE_FORMAT_RGBA;
+        return MPImage.IMAGE_FORMAT_RGBA;
       } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) {
-        return Image.IMAGE_FORMAT_RGB;
+        return MPImage.IMAGE_FORMAT_RGB;
       }
     }
     switch (graphicsFormat) {
       case android.graphics.ImageFormat.JPEG:
-        return Image.IMAGE_FORMAT_JPEG;
+        return MPImage.IMAGE_FORMAT_JPEG;
       case android.graphics.ImageFormat.YUV_420_888:
-        return Image.IMAGE_FORMAT_YUV_420_888;
+        return MPImage.IMAGE_FORMAT_YUV_420_888;
       default:
-        return Image.IMAGE_FORMAT_UNKNOWN;
+        return MPImage.IMAGE_FORMAT_UNKNOWN;
     }
   }
 }
@@ -15,13 +15,14 @@ limitations under the License.

 package com.google.mediapipe.framework.image;

+import android.media.Image;
 import android.os.Build.VERSION_CODES;
 import androidx.annotation.RequiresApi;

 /**
- * Utility for extracting {@link android.media.Image} from {@link Image}.
+ * Utility for extracting {@link android.media.Image} from {@link MPImage}.
  *
- * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE},
+ * <p>Currently it only supports {@link MPImage} with {@link MPImage#STORAGE_TYPE_MEDIA_IMAGE},
  * otherwise {@link IllegalArgumentException} will be thrown.
  */
 @RequiresApi(VERSION_CODES.KITKAT)
@@ -30,20 +31,20 @@ public class MediaImageExtractor {
   private MediaImageExtractor() {}

   /**
-   * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for
-   * {@link Image} that built from {@link MediaImageBuilder}.
+   * Extracts a {@link android.media.Image} from a {@link MPImage}. Currently it only works for
+   * {@link MPImage} that built from {@link MediaImageBuilder}.
    *
    * @param image the image to extract {@link android.media.Image} from.
-   * @return {@link android.media.Image} that stored in {@link Image}.
+   * @return {@link android.media.Image} that stored in {@link MPImage}.
    * @throws IllegalArgumentException if the extraction failed.
    */
-  public static android.media.Image extract(Image image) {
-    ImageContainer container;
-    if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
+  public static Image extract(MPImage image) {
+    MPImageContainer container;
+    if ((container = image.getContainer(MPImage.STORAGE_TYPE_MEDIA_IMAGE)) != null) {
       return ((MediaImageContainer) container).getImage();
     }
     throw new IllegalArgumentException(
-        "Extract Media Image from an Image created by objects other than Media Image"
+        "Extract Media Image from a MPImage created by objects other than Media Image"
         + " is not supported");
   }
 }
@@ -1,4 +1,4 @@
-# Copyright 2019-2020 The MediaPipe Authors.
+# Copyright 2019-2022 The MediaPipe Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -89,10 +89,6 @@ def mediapipe_aar(
         calculators = calculators,
     )

-    _mediapipe_proto(
-        name = name + "_proto",
-    )
-
    native.genrule(
        name = name + "_aar_manifest_generator",
        outs = ["AndroidManifest.xml"],
@@ -115,19 +111,10 @@ EOF
            "//mediapipe/java/com/google/mediapipe/components:java_src",
            "//mediapipe/java/com/google/mediapipe/framework:java_src",
            "//mediapipe/java/com/google/mediapipe/glutil:java_src",
-           "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
-           "com/google/mediapipe/formats/proto/ClassificationProto.java",
-           "com/google/mediapipe/formats/proto/DetectionProto.java",
-           "com/google/mediapipe/formats/proto/LandmarkProto.java",
-           "com/google/mediapipe/formats/proto/LocationDataProto.java",
-           "com/google/mediapipe/proto/CalculatorProto.java",
-       ] +
+       ] + mediapipe_java_proto_srcs() +
        select({
            "//conditions:default": [],
-           "enable_stats_logging": [
-               "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
-               "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
-           ],
+           "enable_stats_logging": mediapipe_logging_java_proto_srcs(),
        }),
        manifest = "AndroidManifest.xml",
        proguard_specs = ["//mediapipe/java/com/google/mediapipe/framework:proguard.pgcfg"],
@@ -177,93 +164,9 @@ EOF
        assets_dir = assets_dir,
    )

-   _aar_with_jni(name, name + "_android_lib")
-
-def _mediapipe_proto(name):
-    """Generates MediaPipe java proto libraries.
-
-    Args:
-      name: the name of the target.
-    """
-    _proto_java_src_generator(
-        name = "mediapipe_log_extension_proto",
-        proto_src = "mediapipe/util/analytics/mediapipe_log_extension.proto",
-        java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
-        srcs = ["//mediapipe/util/analytics:protos_src"],
-    )
-
-    _proto_java_src_generator(
-        name = "mediapipe_logging_enums_proto",
-        proto_src = "mediapipe/util/analytics/mediapipe_logging_enums.proto",
-        java_lite_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
-        srcs = ["//mediapipe/util/analytics:protos_src"],
-    )
-
-    _proto_java_src_generator(
-        name = "calculator_proto",
-        proto_src = "mediapipe/framework/calculator.proto",
-        java_lite_out = "com/google/mediapipe/proto/CalculatorProto.java",
-        srcs = ["//mediapipe/framework:protos_src"],
-    )
-
-    _proto_java_src_generator(
-        name = "landmark_proto",
-        proto_src = "mediapipe/framework/formats/landmark.proto",
-        java_lite_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
-        srcs = ["//mediapipe/framework/formats:protos_src"],
-    )
-
-    _proto_java_src_generator(
-        name = "rasterization_proto",
-        proto_src = "mediapipe/framework/formats/annotation/rasterization.proto",
-        java_lite_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
-        srcs = ["//mediapipe/framework/formats/annotation:protos_src"],
-    )
-
-    _proto_java_src_generator(
-        name = "location_data_proto",
-        proto_src = "mediapipe/framework/formats/location_data.proto",
-        java_lite_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
-        srcs = [
-            "//mediapipe/framework/formats:protos_src",
-            "//mediapipe/framework/formats/annotation:protos_src",
-        ],
-    )
-
-    _proto_java_src_generator(
-        name = "detection_proto",
-        proto_src = "mediapipe/framework/formats/detection.proto",
-        java_lite_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
-        srcs = [
-            "//mediapipe/framework/formats:protos_src",
-            "//mediapipe/framework/formats/annotation:protos_src",
-        ],
-    )
-
-    _proto_java_src_generator(
-        name = "classification_proto",
-        proto_src = "mediapipe/framework/formats/classification.proto",
-        java_lite_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
-        srcs = [
-            "//mediapipe/framework/formats:protos_src",
-        ],
-    )
-
-def _proto_java_src_generator(name, proto_src, java_lite_out, srcs = []):
-    native.genrule(
-        name = name + "_proto_java_src_generator",
-        srcs = srcs + [
-            "@com_google_protobuf//:lite_well_known_protos",
-        ],
-        outs = [java_lite_out],
-        cmd = "$(location @com_google_protobuf//:protoc) " +
-              "--proto_path=. --proto_path=$(GENDIR) " +
-              "--proto_path=$$(pwd)/external/com_google_protobuf/src " +
-              "--java_out=lite:$(GENDIR) " + proto_src + " && " +
-              "mv $(GENDIR)/" + java_lite_out + " $$(dirname $(location " + java_lite_out + "))",
-        tools = [
-            "@com_google_protobuf//:protoc",
-        ],
-    )
+   mediapipe_build_aar_with_jni(
+       name = name,
+       android_library = name + "_android_lib",
+   )

 def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
@@ -303,7 +206,14 @@ def _mediapipe_jni(name, gen_libmediapipe, calculators = []):
        alwayslink = 1,
    )

-def _aar_with_jni(name, android_library):
+def mediapipe_build_aar_with_jni(name, android_library):
+    """Builds MediaPipe AAR with jni.
+
+    Args:
+      name: The bazel target name.
+      android_library: the android library that contains jni.
+    """
+
    # Generates dummy AndroidManifest.xml for dummy apk usage
    # (dummy apk is generated by <name>_dummy_app target below)
    native.genrule(
@@ -314,7 +224,7 @@ cat > $(OUTS) <<EOF
 <manifest
     xmlns:android="http://schemas.android.com/apk/res/android"
     package="dummy.package.for.so">
-  <uses-sdk android:minSdkVersion="21"/>
+  <uses-sdk android:minSdkVersion="24"/>
 </manifest>
 EOF
 """,
@@ -341,7 +251,133 @@ chmod +w $(location :{}.aar)
 origdir=$$PWD
 cd $$(mktemp -d)
 unzip $$origdir/$(location :{}_dummy_app_unsigned.apk) "lib/*"
+find lib -name *_dummy_app.so -delete
 cp -r lib jni
 zip -r $$origdir/$(location :{}.aar) jni/*/*.so
 """.format(android_library, name, name, name, name),
    )

+def mediapipe_java_proto_src_extractor(target, src_out, name = ""):
+    """Extracts the generated MediaPipe java proto source code from the target.
+
+    Args:
+      target: The java proto lite target to be built and extracted.
+      src_out: The output java proto src code path.
+      name: The optional bazel target name.
+
+    Returns:
+      The output java proto src code path.
+    """
+
+    if not name:
+        name = target.split(":")[-1] + "_proto_java_src_extractor"
+    src_jar = target.replace("_java_proto_lite", "_proto-lite-src.jar").replace(":", "/").replace("//", "")
+    native.genrule(
+        name = name + "_proto_java_src_extractor",
+        srcs = [target],
+        outs = [src_out],
+        cmd = "unzip $(GENDIR)/" + src_jar + " -d $(GENDIR) && mv $(GENDIR)/" +
+              src_out + " $$(dirname $(location " + src_out + "))",
+    )
+    return src_out
+
+def mediapipe_java_proto_srcs(name = ""):
+    """Extracts the generated MediaPipe framework java proto source code.
+
+    Args:
+      name: The optional bazel target name.
+
+    Returns:
+      The list of the extracted MediaPipe java proto source code.
+    """
+
+    proto_src_list = []
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:calculator_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/CalculatorProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:calculator_options_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/CalculatorOptionsProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:stream_handler_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/StreamHandlerProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:packet_factory_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/PacketFactoryProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:packet_generator_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/PacketGeneratorProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:status_handler_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/StatusHandlerProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework:mediapipe_options_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/MediaPipeOptionsProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats/annotation:rasterization_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/annotation/proto/RasterizationProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:classification_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/ClassificationProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:detection_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/DetectionProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:landmark_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/LandmarkProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:location_data_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/LocationDataProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/framework/formats:rect_java_proto_lite",
+        src_out = "com/google/mediapipe/formats/proto/RectProto.java",
+    ))
+    return proto_src_list
+
+def mediapipe_logging_java_proto_srcs(name = ""):
+    """Extracts the generated logging-related MediaPipe java proto source code.
+
+    Args:
+      name: The optional bazel target name.
+
+    Returns:
+      The list of the extracted MediaPipe logging-related java proto source code.
+    """
+
+    proto_src_list = []
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/util/analytics:mediapipe_log_extension_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/MediaPipeLoggingProto.java",
+    ))
+
+    proto_src_list.append(mediapipe_java_proto_src_extractor(
+        target = "//mediapipe/util/analytics:mediapipe_logging_enums_java_proto_lite",
+        src_out = "com/google/mediapipe/proto/MediaPipeLoggingEnumsProto.java",
+    ))
+    return proto_src_list

new file: mediapipe/model_maker/BUILD
@@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//visibility:public"])

package_group(
    name = "internal",
    packages = [
        "//mediapipe/model_maker/...",
    ],
)

new file: mediapipe/model_maker/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

new file: mediapipe/model_maker/python/BUILD
@@ -0,0 +1,22 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//visibility:public"])

package_group(
    name = "internal",
    packages = [
        "//mediapipe/model_maker/...",
    ],
)

new file: mediapipe/model_maker/python/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

new file: mediapipe/model_maker/python/core/BUILD
@@ -0,0 +1,26 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library and test compatibility macro.

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

licenses(["notice"])

py_library(
    name = "hyperparameters",
    srcs = ["hyperparameters.py"],
)

new file: mediapipe/model_maker/python/core/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

new file: mediapipe/model_maker/python/core/data/BUILD
@@ -0,0 +1,61 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library and test compatibility macro.
# Placeholder for internal Python strict test compatibility macro.

licenses(["notice"])

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

py_library(
    name = "data_util",
    srcs = ["data_util.py"],
)

py_test(
    name = "data_util_test",
    srcs = ["data_util_test.py"],
    data = ["//mediapipe/model_maker/python/core/data/testdata"],
    deps = [":data_util"],
)

py_library(
    name = "dataset",
    srcs = ["dataset.py"],
    srcs_version = "PY3",
)

py_test(
    name = "dataset_test",
    srcs = ["dataset_test.py"],
    deps = [
        ":dataset",
        "//mediapipe/model_maker/python/core/utils:test_util",
    ],
)

py_library(
    name = "classification_dataset",
    srcs = ["classification_dataset.py"],
    deps = [":dataset"],
)

py_test(
    name = "classification_dataset_test",
    srcs = ["classification_dataset_test.py"],
    deps = [":classification_dataset"],
)

new file: mediapipe/model_maker/python/core/data/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
new file: mediapipe/model_maker/python/core/data/classification_dataset.py
@@ -0,0 +1,51 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common classification dataset library."""

from typing import Any, Tuple

import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds


class ClassificationDataset(ds.Dataset):
  """DataLoader for classification models."""

  def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
    super().__init__(dataset, size)
    self._index_by_label = index_by_label

  @property
  def num_classes(self: ds._DatasetT) -> int:
    return len(self._index_by_label)

  @property
  def index_by_label(self: ds._DatasetT) -> Any:
    return self._index_by_label

  def split(self: ds._DatasetT,
            fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:
    """Splits dataset into two sub-datasets with the given fraction.

    Primarily used for splitting the data set into training and testing sets.

    Args:
      fraction: A float, the fraction of the first returned sub-dataset in the
        original data.

    Returns:
      The two split sub-datasets.
    """
    return self._split(fraction, self._index_by_label)
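For reference, a minimal usage sketch (not part of this commit) of the `ClassificationDataset` defined above; the toy tensors and the `'negative'`/`'positive'` labels are illustrative only:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

# Wrap a small tf.data.Dataset with an explicit size and label index.
ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
data = classification_dataset.ClassificationDataset(
    dataset=ds, size=4, index_by_label=['negative', 'positive'])

print(data.num_classes)                  # 2, i.e. len(index_by_label)
train_data, test_data = data.split(0.5)  # index_by_label is propagated
print(len(train_data), len(test_data))   # 2 2
```

Note how `split` forwards `index_by_label` to `_split`, so both halves come back as `ClassificationDataset` instances; the test file that follows exercises the same pattern with an extra constructor argument.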
new file: mediapipe/model_maker/python/core/data/classification_dataset_test.py
@@ -0,0 +1,82 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Tuple, TypeVar

# Dependency imports

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

_DatasetT = TypeVar(
    '_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')


class ClassificationDatasetTest(tf.test.TestCase):

  def test_split(self):

    class MagicClassificationDataset(
        classification_dataset.ClassificationDataset):
      """A mock classification dataset class for testing purposes.

      Attributes:
        value: A value variable stored by the mock dataset class for testing.
      """

      def __init__(self, dataset: tf.data.Dataset, size: int,
                   index_by_label: Any, value: Any):
        super().__init__(
            dataset=dataset, size=size, index_by_label=index_by_label)
        self.value = value

      def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
        return self._split(fraction, self.index_by_label, self.value)

    # Some dummy inputs.
    magic_value = 42
    num_classes = 2
    index_by_label = (False, True)

    # Create data loader from sample data.
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = MagicClassificationDataset(
        dataset=ds,
        size=len(ds),
        index_by_label=index_by_label,
        value=magic_value)

    # Train/Test data split.
    fraction = .25
    train_data, test_data = data.split(fraction=fraction)

    # `split` should return instances of child DataLoader.
    self.assertIsInstance(train_data, MagicClassificationDataset)
    self.assertIsInstance(test_data, MagicClassificationDataset)

    # Make sure number of entries are right.
    self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
    self.assertLen(train_data, fraction * len(ds))
    self.assertLen(test_data, len(ds) - len(train_data))

    # Make sure attributes propagated correctly.
    self.assertEqual(train_data.num_classes, num_classes)
    self.assertEqual(test_data.index_by_label, index_by_label)
    self.assertEqual(train_data.value, magic_value)
    self.assertEqual(test_data.value, magic_value)


if __name__ == '__main__':
  tf.test.main()
new file: mediapipe/model_maker/python/core/data/data_util.py
@@ -0,0 +1,35 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utility library."""

import cv2
import numpy as np
import tensorflow as tf


def load_image(path: str) -> np.ndarray:
  """Loads an image as an RGB numpy array.

  Args:
    path: input image file absolute path.

  Returns:
    An RGB image in numpy.ndarray.
  """
  tf.compat.v1.logging.info('Loading RGB image %s', path)
  # TODO Replace the OpenCV image load and conversion library by
  # MediaPipe image utility library once it is ready.
  image = cv2.imread(path)
  return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
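For reference, a short usage sketch of `data_util.load_image` defined above; the file path is a placeholder, not a file shipped with this commit:

```python
from mediapipe.model_maker.python.core.data import data_util

# cv2.imread returns BGR; load_image converts the result to RGB order.
rgb = data_util.load_image('/tmp/example.jpg')  # hypothetical path
print(rgb.shape)  # (height, width, 3)
```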
new file: mediapipe/model_maker/python/core/data/data_util_test.py
@@ -0,0 +1,44 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from absl import flags
import tensorflow as tf

from mediapipe.model_maker.python.core.data import data_util

_WORKSPACE = "mediapipe"
_TEST_DATA_DIR = os.path.join(
    _WORKSPACE, 'mediapipe/model_maker/python/core/data/testdata')

FLAGS = flags.FLAGS


class DataUtilTest(tf.test.TestCase):

  def test_load_rgb_image(self):
    image_path = os.path.join(FLAGS.test_srcdir, _TEST_DATA_DIR, 'test.jpg')
    image_data = data_util.load_image(image_path)
    self.assertEqual(image_data.shape, (5184, 3456, 3))


if __name__ == '__main__':
  tf.test.main()
new file: mediapipe/model_maker/python/core/data/dataset.py
@@ -0,0 +1,164 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common dataset for model training and evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
from typing import Callable, Optional, Tuple, TypeVar

# Dependency imports
import tensorflow as tf

_DatasetT = TypeVar('_DatasetT', bound='Dataset')


class Dataset(object):
  """A generic dataset class for loading model training and evaluation dataset.

  For each ML task, such as image classification, text classification etc., a
  subclass can be derived from this class to provide task-specific data loading
  utilities.
  """

  def __init__(self, tf_dataset: tf.data.Dataset, size: Optional[int] = None):
    """Initializes Dataset class.

    To build dataset from raw data, consider using the task specific utilities,
    e.g. from_folder().

    Args:
      tf_dataset: A tf.data.Dataset object that contains a potentially large
        set of elements, where each element is a pair of (input_data, target).
        The `input_data` means the raw input data, like an image, a text etc.,
        while the `target` means the ground truth of the raw input data, e.g.
        the classification label of the image etc.
      size: The size of the dataset. tf.data.Dataset doesn't support a function
        to get the length directly since it's lazy-loaded and may be infinite.
    """
    self._dataset = tf_dataset
    self._size = size

  @property
  def size(self) -> Optional[int]:
    """Returns the size of the dataset.

    Note that this function may return None because the exact size of the
    dataset isn't a necessary parameter to create an instance of this class,
    and tf.data.Dataset doesn't support a function to get the length directly
    since it's lazy-loaded and may be infinite.
    In most cases, however, when an instance of this class is created by helper
    functions like 'from_folder', the size of the dataset will be preprocessed,
    and this function can return an int representing the size of the dataset.
    """
    return self._size

  def gen_tf_dataset(self,
                     batch_size: int = 1,
                     is_training: bool = False,
                     shuffle: bool = False,
                     preprocess: Optional[Callable[..., bool]] = None,
                     drop_remainder: bool = False) -> tf.data.Dataset:
    """Generates a batched tf.data.Dataset for training/evaluation.

    Args:
      batch_size: An integer, the returned dataset will be batched by this
        size.
      is_training: A boolean, when True, the returned dataset will be
        optionally shuffled and repeated as an endless dataset.
      shuffle: A boolean, when True, the returned dataset will be shuffled to
        create randomness during model training.
      preprocess: A function taking three arguments in order, feature, label
        and boolean is_training.
      drop_remainder: A boolean, whether the final batch drops remainder.

    Returns:
      A TF dataset ready to be consumed by Keras model.
    """
    dataset = self._dataset

    if preprocess:
      preprocess = functools.partial(preprocess, is_training=is_training)
      dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    if is_training:
      if shuffle:
        # Shuffle size should be bigger than the batch_size. Otherwise it's
        # only shuffling within the batch, which equals to not having shuffle.
        buffer_size = 3 * batch_size
        # But since we are doing shuffle before repeat, it doesn't make sense
        # to shuffle more than total available entries.
        # TODO: Investigate if shuffling before / after repeat
        # dataset can get a better performance?
        # Shuffle after repeat will give a more randomized dataset and mix the
        # epoch boundary: https://www.tensorflow.org/guide/data
        if self._size:
          buffer_size = min(self._size, buffer_size)
        dataset = dataset.shuffle(buffer_size=buffer_size)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    # TODO: Consider converting dataset to distributed dataset
    # here.
    return dataset

  def __len__(self):
    """Returns the number of elements of the dataset."""
    if self._size is not None:
      return self._size
    else:
      return len(self._dataset)

  def split(self: _DatasetT, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
    """Splits dataset into two sub-datasets with the given fraction.

    Primarily used for splitting the data set into training and testing sets.

    Args:
      fraction: A float value defines the fraction of the first returned
        sub-dataset in the original data.

    Returns:
      The two split sub-datasets.
    """
    return self._split(fraction)

  def _split(self: _DatasetT, fraction: float,
             *args) -> Tuple[_DatasetT, _DatasetT]:
    """Implementation for `split` method and returns sub-class instances.

    Child DataLoader classes, if they require additional constructor
    arguments, should implement their own `split` method by calling `_split`
    with all arguments to the constructor.

    Args:
      fraction: A float value defines the fraction of the first returned
        sub-dataset in the original data.
      *args: additional arguments passed to the sub-class constructor.

    Returns:
      The two split sub-datasets.
    """
    assert (fraction > 0 and fraction < 1)

    dataset = self._dataset

    train_size = int(self._size * fraction)
    trainset = self.__class__(dataset.take(train_size), train_size, *args)

    test_size = self._size - train_size
    testset = self.__class__(dataset.skip(train_size), test_size, *args)

    return trainset, testset
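For reference, a minimal sketch (not part of this commit) of the `Dataset` API added above: wrap a `tf.data.Dataset`, split it, and batch it for training. The tensors and split fraction are toy values:

```python
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds

# Four (feature, label) pairs; from_tensor_slices yields them element-wise.
features = tf.data.Dataset.from_tensor_slices(
    ([[0., 1.], [1., 1.], [0., 0.], [1., 0.]], [1, 1, 0, 0]))
data = ds.Dataset(features, size=4)

train_data, test_data = data.split(0.75)  # 3 / 1 examples via take/skip
batched = train_data.gen_tf_dataset(batch_size=2, is_training=True,
                                    shuffle=True)
for feature, label in batched:
  print(feature.shape, label.shape)  # e.g. (2, 2) (2,) then (1, 2) (1,)
```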
new file: mediapipe/model_maker/python/core/data/dataset_test.py
@@ -0,0 +1,78 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset as ds
from mediapipe.model_maker.python.core.utils import test_util


class DatasetTest(tf.test.TestCase):

  def test_split(self):
    dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                  [1, 0]])
    data = ds.Dataset(dataset, 4)
    train_data, test_data = data.split(0.5)

    self.assertLen(train_data, 2)
    self.assertIsInstance(train_data, ds.Dataset)
    self.assertIsInstance(test_data, ds.Dataset)
    for i, elem in enumerate(train_data.gen_tf_dataset()):
      self.assertTrue((elem.numpy() == np.array([i, 1])).all())

    self.assertLen(test_data, 2)
    for i, elem in enumerate(test_data.gen_tf_dataset()):
      self.assertTrue((elem.numpy() == np.array([i, 0])).all())

  def test_len(self):
    size = 4
    dataset = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0],
                                                  [1, 0]])
    data = ds.Dataset(dataset, size)
    self.assertLen(data, size)

  def test_gen_tf_dataset(self):
    input_dim = 8
    data = test_util.create_dataset(
        data_size=2, input_shape=[input_dim], num_classes=2)

    dataset = data.gen_tf_dataset()
    self.assertLen(dataset, 2)
    for (feature, label) in dataset:
      self.assertTrue((tf.shape(feature).numpy() == np.array([1, 8])).all())
      self.assertTrue((tf.shape(label).numpy() == np.array([1])).all())

    dataset2 = data.gen_tf_dataset(batch_size=2)
    self.assertLen(dataset2, 1)
    for (feature, label) in dataset2:
      self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
      self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())

    dataset3 = data.gen_tf_dataset(batch_size=2, is_training=True, shuffle=True)
    self.assertEqual(dataset3.cardinality(), 1)
    for (feature, label) in dataset3.take(10):
      self.assertTrue((tf.shape(feature).numpy() == np.array([2, 8])).all())
      self.assertTrue((tf.shape(label).numpy() == np.array([2])).all())


if __name__ == '__main__':
  tf.test.main()
new file (vendored): mediapipe/model_maker/python/core/data/testdata/BUILD
@@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load(
    "//mediapipe/framework/tool:mediapipe_files.bzl",
    "mediapipe_files",
)

package(
    default_visibility = ["//mediapipe/model_maker/python/core/data:__subpackages__"],
    licenses = ["notice"],  # Apache 2.0
)

mediapipe_files(srcs = ["test.jpg"])

filegroup(
    name = "testdata",
    srcs = ["test.jpg"],
)

new file: mediapipe/model_maker/python/core/hyperparameters.py
@@ -0,0 +1,68 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparameters for training models. Shared across tasks."""

import dataclasses
import tempfile

from typing import Optional


# TODO: Integrate this class into ImageClassifier and other tasks.
@dataclasses.dataclass
class BaseHParams:
  """Hyperparameters used for training models.

  A common set of hyperparameters shared by the training jobs of all model
  maker tasks.

  Attributes:
    learning_rate: The learning rate to use for gradient descent training.
    batch_size: Batch size for training.
    epochs: Number of training iterations over the dataset.
    steps_per_epoch: An optional integer indicating the number of training
      steps per epoch. If not set, the training pipeline calculates the
      default steps per epoch as the training dataset size divided by batch
      size.
    shuffle: True if the dataset is shuffled before training.
    export_dir: The location of the model checkpoint files.
    distribution_strategy: A string specifying which Distribution Strategy to
      use. Accepted values are 'off', 'one_device', 'mirrored',
      'parameter_server', 'multi_worker_mirrored', and 'tpu' -- case
      insensitive. 'off' means not to use Distribution Strategy; 'tpu' means
      to use TPUStrategy using `tpu_address`. See the tf.distribute.Strategy
      documentation for more details:
      https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy.
    num_gpus: How many GPUs to use at each worker with the
      DistributionStrategies API. The default is -1, which means utilize all
      available GPUs.
    tpu: The Cloud TPU to use for training. This should be either the name
      used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470
      url.
  """

  # Parameters for train configuration
  learning_rate: float
  batch_size: int
  epochs: int
  steps_per_epoch: Optional[int] = None

  # Dataset-related parameters
  shuffle: bool = False

  # Parameters for model / checkpoint files
  export_dir: str = tempfile.mkdtemp()

  # Parameters for hardware acceleration
  distribution_strategy: str = 'off'
  num_gpus: int = -1  # default value of -1 means use all available GPUs
  tpu: str = ''
(Some files were not shown because too many files have changed in this diff.)