Merge branch 'google:master' into image-segmenter-python-impl

commit 1748663a5a
@@ -1410,3 +1410,45 @@ cc_library(
    ],
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "bypass_calculator_proto",
    srcs = ["bypass_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
    ],
)

cc_library(
    name = "bypass_calculator",
    srcs = ["bypass_calculator.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":bypass_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/port:status",
    ],
    alwayslink = 1,
)

cc_test(
    name = "bypass_calculator_test",
    srcs = ["bypass_calculator_test.cc"],
    deps = [
        ":bypass_calculator",
        ":pass_through_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:switch_container",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)

mediapipe/calculators/core/bypass_calculator.cc (new file, 161 lines)

@@ -0,0 +1,161 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "mediapipe/calculators/core/bypass_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/collection_item_id.h"
#include "mediapipe/framework/port/status.h"

namespace mediapipe {
namespace api2 {

using mediapipe::BypassCalculatorOptions;

// Defines a "bypass" channel to use in place of a disabled feature subgraph.
// By default, all inputs are discarded and all outputs are ignored.
// Certain input streams can be passed to corresponding output streams
// by specifying them in "pass_input_stream" and "pass_output_stream" options.
// All output streams are updated with timestamp bounds indicating completed
// output.
//
// Note that this calculator is designed for use as a contained_node in a
// SwitchContainer. For this reason, any input and output tags are accepted,
// and stream semantics are specified through BypassCalculatorOptions.
//
// Example config:
//   node {
//     calculator: "BypassCalculator"
//     input_stream: "APPEARANCES:appearances_post_facenet"
//     input_stream: "VIDEO:video_frame"
//     input_stream: "FEATURE_CONFIG:feature_config"
//     input_stream: "ENABLE:gaze_enabled"
//     output_stream: "APPEARANCES:analyzed_appearances"
//     output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
//     node_options: {
//       [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
//         pass_input_stream: "APPEARANCES"
//         pass_output_stream: "APPEARANCES"
//       }
//     }
//   }
//
class BypassCalculator : public Node {
 public:
  static constexpr mediapipe::api2::Input<int>::Optional kNotNeeded{"N_N_"};
  MEDIAPIPE_NODE_CONTRACT(kNotNeeded);
  using IdMap = std::map<CollectionItemId, CollectionItemId>;

  // Returns the map of passthrough input and output stream ids.
  static absl::StatusOr<IdMap> GetPassMap(
      const BypassCalculatorOptions& options, const tool::TagMap& input_map,
      const tool::TagMap& output_map) {
    IdMap result;
    auto& input_streams = options.pass_input_stream();
    auto& output_streams = options.pass_output_stream();
    int size = std::min(input_streams.size(), output_streams.size());
    for (int i = 0; i < size; ++i) {
      std::pair<std::string, int> in_tag, out_tag;
      MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_input_stream(i),
                                             &in_tag.first, &in_tag.second));
      MP_RETURN_IF_ERROR(tool::ParseTagIndex(options.pass_output_stream(i),
                                             &out_tag.first, &out_tag.second));
      auto input_id = input_map.GetId(in_tag.first, in_tag.second);
      auto output_id = output_map.GetId(out_tag.first, out_tag.second);
      result[input_id] = output_id;
    }
    return result;
  }

  // Identifies all specified streams as "Any" packet type.
  // Identifies passthrough streams as "Same" packet type.
  static absl::Status UpdateContract(CalculatorContract* cc) {
    auto options = cc->Options<BypassCalculatorOptions>();
    RET_CHECK_EQ(options.pass_input_stream().size(),
                 options.pass_output_stream().size());
    ASSIGN_OR_RETURN(
        auto pass_streams,
        GetPassMap(options, *cc->Inputs().TagMap(), *cc->Outputs().TagMap()));
    std::set<CollectionItemId> pass_out;
    for (auto entry : pass_streams) {
      pass_out.insert(entry.second);
      cc->Inputs().Get(entry.first).SetAny();
      cc->Outputs().Get(entry.second).SetSameAs(&cc->Inputs().Get(entry.first));
    }
    for (auto id = cc->Inputs().BeginId(); id != cc->Inputs().EndId(); ++id) {
      if (pass_streams.count(id) == 0) {
        cc->Inputs().Get(id).SetAny();
      }
    }
    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
      if (pass_out.count(id) == 0) {
        cc->Outputs().Get(id).SetAny();
      }
    }
    return absl::OkStatus();
  }

  // Saves the map of passthrough input and output stream ids.
  absl::Status Open(CalculatorContext* cc) override {
    auto options = cc->Options<BypassCalculatorOptions>();
    ASSIGN_OR_RETURN(pass_streams_, GetPassMap(options, *cc->Inputs().TagMap(),
                                               *cc->Outputs().TagMap()));
    return absl::OkStatus();
  }

  // Copies packets between passthrough input and output streams.
  // Updates timestamp bounds on all output streams.
  absl::Status Process(CalculatorContext* cc) override {
    std::set<CollectionItemId> pass_out;
    for (auto entry : pass_streams_) {
      pass_out.insert(entry.second);
      auto& packet = cc->Inputs().Get(entry.first).Value();
      if (packet.Timestamp() == cc->InputTimestamp()) {
        cc->Outputs().Get(entry.second).AddPacket(packet);
      }
    }
    Timestamp bound = cc->InputTimestamp().NextAllowedInStream();
    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
      if (pass_out.count(id) == 0) {
        cc->Outputs().Get(id).SetNextTimestampBound(
            std::max(cc->Outputs().Get(id).NextTimestampBound(), bound));
      }
    }
    return absl::OkStatus();
  }

  // Close all output streams.
  absl::Status Close(CalculatorContext* cc) override {
    for (auto id = cc->Outputs().BeginId(); id != cc->Outputs().EndId(); ++id) {
      cc->Outputs().Get(id).Close();
    }
    return absl::OkStatus();
  }

 private:
  IdMap pass_streams_;
};

MEDIAPIPE_REGISTER_NODE(BypassCalculator);

}  // namespace api2
}  // namespace mediapipe
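
The core of GetPassMap above is positional pairing: entry i of
"pass_input_stream" is parsed as "TAG:index" and mapped to entry i of
"pass_output_stream". A standalone sketch of that pairing logic, with a
simplified parser standing in for mediapipe::tool::ParseTagIndex and plain
(tag, index) pairs standing in for CollectionItemId (illustrative only; real
calculators should use the framework helpers shown above):

#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Parses "TAG" or "TAG:index" into a (tag, index) pair.
// Simplified stand-in for mediapipe::tool::ParseTagIndex.
std::pair<std::string, int> ParseTagIndex(const std::string& spec) {
  auto colon = spec.find(':');
  if (colon == std::string::npos) return {spec, 0};
  return {spec.substr(0, colon), std::stoi(spec.substr(colon + 1))};
}

int main() {
  // Entries as they would appear in BypassCalculatorOptions.
  std::vector<std::string> pass_inputs = {"APPEARANCES", "VIDEO:1"};
  std::vector<std::string> pass_outputs = {"APPEARANCES", "VIDEO_OUT:0"};

  // Zip the two lists positionally, as GetPassMap does with stream ids.
  std::map<std::pair<std::string, int>, std::pair<std::string, int>> pass_map;
  size_t size = std::min(pass_inputs.size(), pass_outputs.size());
  for (size_t i = 0; i < size; ++i) {
    pass_map[ParseTagIndex(pass_inputs[i])] = ParseTagIndex(pass_outputs[i]);
  }

  for (const auto& [in, out] : pass_map) {
    std::cout << in.first << ":" << in.second << " -> " << out.first << ":"
              << out.second << "\n";
  }
  return 0;
}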

mediapipe/calculators/core/bypass_calculator.proto (new file, 31 lines)

@@ -0,0 +1,31 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message BypassCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional BypassCalculatorOptions ext = 481259677;
  }

  // Names an input stream or streams to pass through, by "TAG:index".
  repeated string pass_input_stream = 1;

  // Names an output stream or streams to pass through, by "TAG:index".
  repeated string pass_output_stream = 2;
}
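
Because pass_input_stream and pass_output_stream pair up by position, one
node can route several streams at once. A small illustrative config, written
as a C++ raw string in the style of the tests below (the stream names here
are made up for illustration, not taken from the commit):

constexpr char kMultiPassBypassNode[] = R"pb(
  node {
    calculator: "BypassCalculator"
    input_stream: "IMAGE:image_in"
    input_stream: "MASK:mask_in"
    output_stream: "IMAGE:image_out"
    output_stream: "MASK:mask_out"
    node_options: {
      [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
        # Entry i of pass_input_stream pairs with entry i of pass_output_stream.
        pass_input_stream: "IMAGE"
        pass_input_stream: "MASK"
        pass_output_stream: "IMAGE"
        pass_output_stream: "MASK"
      }
    }
  }
)pb";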

mediapipe/calculators/core/bypass_calculator_test.cc (new file, 302 lines)

@@ -0,0 +1,302 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"

namespace mediapipe {

namespace {

// A graph using a BypassCalculator to pass through and ignore
// most of its inputs and outputs.
constexpr char kTestGraphConfig1[] = R"pb(
  type: "AppearancesPassThroughSubgraph"
  input_stream: "APPEARANCES:appearances"
  input_stream: "VIDEO:video_frame"
  input_stream: "FEATURE_CONFIG:feature_config"
  output_stream: "APPEARANCES:passthrough_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:passthrough_federated_gaze_output"

  node {
    calculator: "BypassCalculator"
    input_stream: "PASS:appearances"
    input_stream: "TRUNCATE:0:video_frame"
    input_stream: "TRUNCATE:1:feature_config"
    output_stream: "PASS:passthrough_appearances"
    output_stream: "TRUNCATE:passthrough_federated_gaze_output"
    node_options: {
      [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
        pass_input_stream: "PASS"
        pass_output_stream: "PASS"
      }
    }
  }
)pb";

// A graph using AppearancesPassThroughSubgraph as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig2[] = R"pb(
  input_stream: "VIDEO_FULL_RES:video_frame"
  input_stream: "APPEARANCES:input_appearances"
  input_stream: "FEATURE_CONFIG:feature_config"
  input_stream: "GAZE_ENABLED:gaze_enabled"
  output_stream: "APPEARANCES:analyzed_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"

  node {
    calculator: "SwitchContainer"
    input_stream: "VIDEO:video_frame"
    input_stream: "APPEARANCES:input_appearances"
    input_stream: "FEATURE_CONFIG:feature_config"
    input_stream: "ENABLE:gaze_enabled"
    output_stream: "APPEARANCES:analyzed_appearances"
    output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: { calculator: "AppearancesPassThroughSubgraph" }
      }
    }
  }
)pb";

// A graph using BypassCalculator as a do-nothing channel
// for input frames and appearances.
constexpr char kTestGraphConfig3[] = R"pb(
  input_stream: "VIDEO_FULL_RES:video_frame"
  input_stream: "APPEARANCES:input_appearances"
  input_stream: "FEATURE_CONFIG:feature_config"
  input_stream: "GAZE_ENABLED:gaze_enabled"
  output_stream: "APPEARANCES:analyzed_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"

  node {
    calculator: "SwitchContainer"
    input_stream: "VIDEO:video_frame"
    input_stream: "APPEARANCES:input_appearances"
    input_stream: "FEATURE_CONFIG:feature_config"
    input_stream: "ENABLE:gaze_enabled"
    output_stream: "APPEARANCES:analyzed_appearances"
    output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: {
          calculator: "BypassCalculator"
          node_options: {
            [type.googleapis.com/mediapipe.BypassCalculatorOptions] {
              pass_input_stream: "APPEARANCES"
              pass_output_stream: "APPEARANCES"
            }
          }
        }
      }
    }
  }
)pb";

// A graph using BypassCalculator as a disabled gate
// for input frames and appearances.
constexpr char kTestGraphConfig4[] = R"pb(
  input_stream: "VIDEO_FULL_RES:video_frame"
  input_stream: "APPEARANCES:input_appearances"
  input_stream: "FEATURE_CONFIG:feature_config"
  input_stream: "GAZE_ENABLED:gaze_enabled"
  output_stream: "APPEARANCES:analyzed_appearances"
  output_stream: "FEDERATED_GAZE_OUTPUT:federated_gaze_output"

  node {
    calculator: "SwitchContainer"
    input_stream: "ENABLE:gaze_enabled"
    input_stream: "VIDEO:video_frame"
    input_stream: "APPEARANCES:input_appearances"
    input_stream: "FEATURE_CONFIG:feature_config"
    output_stream: "VIDEO:video_frame_out"
    output_stream: "APPEARANCES:analyzed_appearances"
    output_stream: "FEATURE_CONFIG:feature_config_out"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: { calculator: "BypassCalculator" }
        contained_node: { calculator: "PassThroughCalculator" }
      }
    }
  }
)pb";

// Reports packet timestamp and string contents, or "<empty>".
std::string DebugString(Packet p) {
  return absl::StrCat(p.Timestamp().DebugString(), ":",
                      p.IsEmpty() ? "<empty>" : p.Get<std::string>());
}

// Shows a bypass subgraph that passes through one stream.
TEST(BypassCalculatorTest, SubgraphChannel) {
  CalculatorGraphConfig config_1 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig1);
  CalculatorGraphConfig config_2 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig2);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize({config_1, config_2}, {}));

  std::vector<std::string> analyzed_appearances;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "analyzed_appearances",
      [&](const Packet& p) {
        analyzed_appearances.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  std::vector<std::string> federated_gaze_output;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "federated_gaze_output",
      [&](const Packet& p) {
        federated_gaze_output.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  MP_ASSERT_OK(graph.StartRun({}));

  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
  EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

// Shows a BypassCalculator that passes through one stream.
TEST(BypassCalculatorTest, CalculatorChannel) {
  CalculatorGraphConfig config_3 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig3);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize({config_3}, {}));

  std::vector<std::string> analyzed_appearances;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "analyzed_appearances",
      [&](const Packet& p) {
        analyzed_appearances.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  std::vector<std::string> federated_gaze_output;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "federated_gaze_output",
      [&](const Packet& p) {
        federated_gaze_output.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  MP_ASSERT_OK(graph.StartRun({}));

  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:a1"));
  EXPECT_THAT(federated_gaze_output, testing::ElementsAre("200:<empty>"));

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

// Shows a BypassCalculator that discards all inputs when ENABLED is false.
TEST(BypassCalculatorTest, GatedChannel) {
  CalculatorGraphConfig config_3 =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(kTestGraphConfig4);
  CalculatorGraph graph;
  MP_ASSERT_OK(graph.Initialize({config_3}, {}));

  std::vector<std::string> analyzed_appearances;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "analyzed_appearances",
      [&](const Packet& p) {
        analyzed_appearances.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  std::vector<std::string> video_frame;
  MP_ASSERT_OK(graph.ObserveOutputStream(
      "video_frame_out",
      [&](const Packet& p) {
        video_frame.push_back(DebugString(p));
        return absl::OkStatus();
      },
      true));
  MP_ASSERT_OK(graph.StartRun({}));

  // Close the gate.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "gaze_enabled", MakePacket<bool>(false).At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Send packets at timestamp 200.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f1").At(Timestamp(200))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Only timestamps arrive from the BypassCalculator.
  EXPECT_THAT(analyzed_appearances, testing::ElementsAre("200:<empty>"));
  EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>"));

  // Open the gate.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "gaze_enabled", MakePacket<bool>(true).At(Timestamp(300))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Send packets at timestamp 300.
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "input_appearances", MakePacket<std::string>("a2").At(Timestamp(300))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "video_frame", MakePacket<std::string>("v2").At(Timestamp(300))));
  MP_ASSERT_OK(graph.AddPacketToInputStream(
      "feature_config", MakePacket<std::string>("f2").At(Timestamp(300))));
  MP_ASSERT_OK(graph.WaitUntilIdle());

  // Packets arrive from the PassThroughCalculator.
  EXPECT_THAT(analyzed_appearances,
              testing::ElementsAre("200:<empty>", "300:a2"));
  EXPECT_THAT(video_frame, testing::ElementsAre("200:<empty>", "300:v2"));

  MP_ASSERT_OK(graph.CloseAllInputStreams());
  MP_ASSERT_OK(graph.WaitUntilDone());
}

}  // namespace

}  // namespace mediapipe
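
In kTestGraphConfig4, the gating behavior comes from SwitchContainer's
contained_node ordering: the ENABLE stream selects contained node 0 when
false and contained node 1 when true. A condensed sketch of that selection
(the comments are explanatory; the config lines follow the test above):

constexpr char kGatedContainerSketch[] = R"pb(
  node {
    calculator: "SwitchContainer"
    # ENABLE=false selects contained_node index 0 (BypassCalculator), so
    # inputs are dropped and outputs only carry timestamp bounds.
    # ENABLE=true selects index 1 (PassThroughCalculator), so packets flow.
    input_stream: "ENABLE:gaze_enabled"
    input_stream: "APPEARANCES:input_appearances"
    output_stream: "APPEARANCES:analyzed_appearances"
    options {
      [mediapipe.SwitchContainerOptions.ext] {
        contained_node: { calculator: "BypassCalculator" }
        contained_node: { calculator: "PassThroughCalculator" }
      }
    }
  }
)pb";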
@@ -16,6 +16,9 @@ syntax = "proto2";

package mediapipe;

option java_package = "com.google.mediapipe.calculator.proto";
option java_outer_classname = "RotationModeProto";

// Counterclockwise rotation.
message RotationMode {
  enum Mode {
@@ -253,6 +253,60 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "text_to_tensor_calculator",
    srcs = ["text_to_tensor_calculator.cc"],
    visibility = [
        "//mediapipe/framework:mediapipe_internal",
    ],
    deps = [
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

cc_test(
    name = "text_to_tensor_calculator_test",
    srcs = ["text_to_tensor_calculator_test.cc"],
    deps = [
        ":text_to_tensor_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/tool:options_map",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "universal_sentence_encoder_preprocessor_calculator",
    srcs = ["universal_sentence_encoder_preprocessor_calculator.cc"],
    deps = [
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

mediapipe_proto_library(
    name = "inference_calculator_proto",
    srcs = ["inference_calculator.proto"],
@@ -270,10 +270,10 @@ class GlProcessor : public ImageToTensorConverter {
                      Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
+        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
-          "Only 4-channel texture input formats are supported, passed format: ",
-          static_cast<uint32_t>(input.format())));
+          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    const auto& output_shape = output_tensor.shape();
    MP_RETURN_IF_ERROR(ValidateTensorShape(output_shape));
@@ -281,12 +281,13 @@ class GlProcessor : public ImageToTensorConverter {
    MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext(
        [this, &output_tensor, &input, &roi, &output_shape, range_min,
         range_max, tensor_buffer_offset]() -> absl::Status {
-          constexpr int kRgbaNumChannels = 4;
+          const int input_num_channels = input.channels();
          auto source_texture = gl_helper_.CreateSourceTexture(input);
          tflite::gpu::gl::GlTexture input_texture(
-              GL_TEXTURE_2D, source_texture.name(), GL_RGBA,
+              GL_TEXTURE_2D, source_texture.name(),
+              input_num_channels == 4 ? GL_RGB : GL_RGBA,
              source_texture.width() * source_texture.height() *
-                  kRgbaNumChannels * sizeof(uint8_t),
+                  input_num_channels * sizeof(uint8_t),
              /*layer=*/0,
              /*owned=*/false);
@@ -174,10 +174,10 @@ class GlProcessor : public ImageToTensorConverter {
                      Tensor& output_tensor) override {
    if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
        input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
-        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
+        input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128 &&
+        input.format() != mediapipe::GpuBufferFormat::kRGB24) {
      return InvalidArgumentError(absl::StrCat(
-          "Only 4-channel texture input formats are supported, passed format: ",
-          static_cast<uint32_t>(input.format())));
+          "Unsupported format: ", static_cast<uint32_t>(input.format())));
    }
    // TODO: support tensor_buffer_offset > 0 scenario.
    RET_CHECK_EQ(tensor_buffer_offset, 0)
@@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/sink.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

namespace mediapipe {
namespace {

using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::testing::ElementsAreArray;

constexpr int kMaxSeqLen = 256;
constexpr char kTestModelPath[] =
    "mediapipe/tasks/testdata/text/"
    "test_model_text_classifier_with_regex_tokenizer.tflite";

absl::StatusOr<std::vector<int>> RunRegexPreprocessorCalculator(
    absl::string_view text) {
  auto graph_config =
      ParseTextProtoOrDie<CalculatorGraphConfig>(absl::Substitute(
          R"pb(
            input_stream: "text"
            output_stream: "tensors"
            node {
              calculator: "RegexPreprocessorCalculator"
              input_stream: "TEXT:text"
              input_side_packet: "METADATA_EXTRACTOR:metadata_extractor"
              output_stream: "TENSORS:tensors"
              options {
                [mediapipe.RegexPreprocessorCalculatorOptions.ext] {
                  max_seq_len: $0
                }
              }
            }
          )pb",
          kMaxSeqLen));
  std::vector<Packet> output_packets;
  tool::AddVectorSink("tensors", &graph_config, &output_packets);

  std::string model_buffer = tasks::core::LoadBinaryContent(kTestModelPath);
  ASSIGN_OR_RETURN(std::unique_ptr<ModelMetadataExtractor> metadata_extractor,
                   ModelMetadataExtractor::CreateFromModelBuffer(
                       model_buffer.data(), model_buffer.size()));
  // Run the graph.
  CalculatorGraph graph;
  MP_RETURN_IF_ERROR(graph.Initialize(
      graph_config,
      {{"metadata_extractor",
        MakePacket<ModelMetadataExtractor>(std::move(*metadata_extractor))}}));
  MP_RETURN_IF_ERROR(graph.StartRun({}));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "text", MakePacket<std::string>(text).At(Timestamp(0))));
  MP_RETURN_IF_ERROR(graph.WaitUntilIdle());

  if (output_packets.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "output_packets has size $0, expected 1", output_packets.size()));
  }
  const std::vector<Tensor>& tensor_vec =
      output_packets[0].Get<std::vector<Tensor>>();
  if (tensor_vec.size() != 1) {
    return absl::InvalidArgumentError(absl::Substitute(
        "tensor_vec has size $0, expected $1", tensor_vec.size(), 1));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kInt32) {
    return absl::InvalidArgumentError("Expected tensor element type kInt32");
  }
  auto* buffer = tensor_vec[0].GetCpuReadView().buffer<int>();
  std::vector<int> result(buffer, buffer + kMaxSeqLen);
  MP_RETURN_IF_ERROR(graph.CloseAllPacketSources());
  MP_RETURN_IF_ERROR(graph.WaitUntilDone());
  return result;
}

TEST(RegexPreprocessorCalculatorTest, TextClassifierModel) {
  MP_ASSERT_OK_AND_ASSIGN(
      std::vector<int> processed_tensor_values,
      RunRegexPreprocessorCalculator("This is the best movie I’ve seen in "
                                     "recent years. Strongly recommend it!"));
  static const int expected_result[kMaxSeqLen] = {
      1, 2, 9, 4, 118, 20, 2, 2, 110, 11, 1136, 153, 2, 386, 12};
  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}

TEST(RegexPreprocessorCalculatorTest, LongInput) {
  std::stringstream long_input;
  long_input << "This is the best";
  for (int i = 0; i < kMaxSeqLen; ++i) {
    long_input << " best";
  }
  long_input << "movie I’ve seen in recent years. Strongly recommend it!";
  MP_ASSERT_OK_AND_ASSIGN(std::vector<int> processed_tensor_values,
                          RunRegexPreprocessorCalculator(long_input.str()));
  std::vector<int> expected_result = {1, 2, 9, 4, 118};
  // "best" id
  expected_result.resize(kMaxSeqLen, 118);
  EXPECT_THAT(processed_tensor_values, ElementsAreArray(expected_result));
}

}  // namespace
}  // namespace mediapipe
@@ -67,9 +67,7 @@ absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) {
        "tensor_vec has size $0, expected 1", tensor_vec.size()));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
-    return absl::InvalidArgumentError(absl::Substitute(
-        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
-        Tensor::ElementType::kChar));
+    return absl::InvalidArgumentError("Expected tensor element type kChar");
  }
  const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>();
  return std::string(buffer, text.length());
@@ -88,9 +88,7 @@ RunUniversalSentenceEncoderPreprocessorCalculator(absl::string_view text) {
        kNumInputTensorsForUniversalSentenceEncoder));
  }
  if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) {
-    return absl::InvalidArgumentError(absl::Substitute(
-        "tensor has element type $0, expected $1", tensor_vec[0].element_type(),
-        Tensor::ElementType::kChar));
+    return absl::InvalidArgumentError("Expected tensor element type kChar");
  }
  std::vector<std::string> results;
  for (int i = 0; i < kNumInputTensorsForUniversalSentenceEncoder; ++i) {
@@ -91,9 +91,9 @@ class TfLiteModelCalculator : public CalculatorBase {
            tflite::DefaultErrorReporter());
        model = tflite::FlatBufferModel::BuildFromAllocation(
            std::move(model_allocation), tflite::DefaultErrorReporter());
-#elif
+#else
        return absl::FailedPreconditionError(
-            "Loading by file descriptor is not supported on this platform.")
+            "Loading by file descriptor is not supported on this platform.");
#endif
    }

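The #elif -> #else change here fixes a real preprocessor error: #elif takes a
condition, so a bare #elif is ill-formed, while #else takes none. A minimal
standalone illustration of the corrected pattern (the platform macro and
messages are made up for this sketch, not taken from the calculator):

#include <iostream>

#if defined(__ANDROID__)
// Platforms with mmap-able file descriptors load the model directly.
void LoadModel() { std::cout << "load via file descriptor\n"; }
#else
// A bare "#elif" in place of "#else" here would fail to compile.
void LoadModel() { std::cout << "file descriptors not supported\n"; }
#endif

int main() { LoadModel(); }
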
@@ -378,8 +378,11 @@ cc_library(
        ],
    }),
    deps = [
        ":gl_texture_buffer",
        ":gpu_buffer_format",
        ":gpu_buffer_storage",
        ":image_frame_view",
        "//mediapipe/framework/formats:image_frame",
        "@com_google_absl//absl/strings:str_format",
    ],
)
@@ -78,7 +78,7 @@ absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) {
        16,
        0};

-    pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs];
+    pixel_format_ = [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_2_1];
  }
  if (!pixel_format_) {
    // On several Forge machines, the default config fails. For now let's do
@@ -65,6 +65,7 @@ class GlTextureView {
  friend class GpuBuffer;
  friend class GlTextureBuffer;
  friend class GpuBufferStorageCvPixelBuffer;
  friend class GpuBufferStorageAhwb;
  GlTextureView(GlContext* context, GLenum target, GLuint name, int width,
                int height, std::shared_ptr<GpuBuffer> gpu_buffer, int plane,
                DetachFn detach, DoneWritingFn done_writing)
@@ -18,6 +18,7 @@
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/gpu/gpu_buffer_storage_ahwb.h"
#include "mediapipe/gpu/gpu_buffer_storage_image_frame.h"
#include "mediapipe/gpu/gpu_test_base.h"
#include "stb_image.h"
@@ -23,13 +23,17 @@ from mediapipe.model_maker.python.core.data import dataset as ds
class ClassificationDataset(ds.Dataset):
  """DataLoader for classification models."""

-  def __init__(self, dataset: tf.data.Dataset, size: int, index_to_label: Any):
+  def __init__(self, dataset: tf.data.Dataset, size: int, index_by_label: Any):
    super().__init__(dataset, size)
-    self.index_to_label = index_to_label
+    self._index_by_label = index_by_label

  @property
  def num_classes(self: ds._DatasetT) -> int:
-    return len(self.index_to_label)
+    return len(self._index_by_label)
+
+  @property
+  def index_by_label(self: ds._DatasetT) -> Any:
+    return self._index_by_label

  def split(self: ds._DatasetT,
            fraction: float) -> Tuple[ds._DatasetT, ds._DatasetT]:

@@ -44,4 +48,4 @@ class ClassificationDataset(ds.Dataset):
    Returns:
      The two split sub-datasets.
    """
-    return self._split(fraction, self.index_to_label)
+    return self._split(fraction, self._index_by_label)
@@ -12,45 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Any, Tuple, TypeVar
+
# Dependency imports

import tensorflow as tf

from mediapipe.model_maker.python.core.data import classification_dataset

+_DatasetT = TypeVar(
+    '_DatasetT', bound='ClassificationDatasetTest.MagicClassificationDataset')
+
-class ClassificationDataLoaderTest(tf.test.TestCase):
+class ClassificationDatasetTest(tf.test.TestCase):

  def test_split(self):

-    class MagicClassificationDataLoader(
+    class MagicClassificationDataset(
        classification_dataset.ClassificationDataset):
+      """A mock classification dataset class for testing purposes.

-      def __init__(self, dataset, size, index_to_label, value):
-        super(MagicClassificationDataLoader,
-              self).__init__(dataset, size, index_to_label)
+      Attributes:
+        value: A value variable stored by the mock dataset class for testing.
+      """
+
+      def __init__(self, dataset: tf.data.Dataset, size: int,
+                   index_by_label: Any, value: Any):
+        super().__init__(
+            dataset=dataset, size=size, index_by_label=index_by_label)
        self.value = value

-      def split(self, fraction):
-        return self._split(fraction, self.index_to_label, self.value)
+      def split(self, fraction: float) -> Tuple[_DatasetT, _DatasetT]:
+        return self._split(fraction, self.index_by_label, self.value)

    # Some dummy inputs.
    magic_value = 42
    num_classes = 2
-    index_to_label = (False, True)
+    index_by_label = (False, True)

    # Create data loader from sample data.
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
-    data = MagicClassificationDataLoader(ds, len(ds), index_to_label,
-                                         magic_value)
+    data = MagicClassificationDataset(
+        dataset=ds,
+        size=len(ds),
+        index_by_label=index_by_label,
+        value=magic_value)

    # Train/Test data split.
    fraction = .25
-    train_data, test_data = data.split(fraction)
+    train_data, test_data = data.split(fraction=fraction)

    # `split` should return instances of child DataLoader.
-    self.assertIsInstance(train_data, MagicClassificationDataLoader)
-    self.assertIsInstance(test_data, MagicClassificationDataLoader)
+    self.assertIsInstance(train_data, MagicClassificationDataset)
+    self.assertIsInstance(test_data, MagicClassificationDataset)

    # Make sure number of entries are right.
    self.assertEqual(len(train_data.gen_tf_dataset()), len(train_data))
@@ -59,7 +73,7 @@ class ClassificationDataLoaderTest(tf.test.TestCase):

    # Make sure attributes propagated correctly.
    self.assertEqual(train_data.num_classes, num_classes)
-    self.assertEqual(test_data.index_to_label, index_to_label)
+    self.assertEqual(test_data.index_by_label, index_by_label)
    self.assertEqual(train_data.value, magic_value)
    self.assertEqual(test_data.value, magic_value)
@@ -29,22 +29,22 @@ from mediapipe.model_maker.python.core.tasks import custom_model
class Classifier(custom_model.CustomModel):
  """An abstract base class that represents a TensorFlow classifier."""

-  def __init__(self, model_spec: Any, index_to_label: List[str], shuffle: bool,
+  def __init__(self, model_spec: Any, index_by_label: List[str], shuffle: bool,
               full_train: bool):
    """Initializes a classifier with its specifications.

    Args:
      model_spec: Specification for the model.
-      index_to_label: A list that map from index to label class name.
+      index_by_label: A list that maps from index to label class name.
      shuffle: Whether the dataset should be shuffled.
      full_train: If true, train the model end-to-end including the backbone
        and the classification layers on top. Otherwise, only train the top
        classification layers.
    """
    super(Classifier, self).__init__(model_spec, shuffle)
-    self._index_to_label = index_to_label
+    self._index_by_label = index_by_label
    self._full_train = full_train
-    self._num_classes = len(index_to_label)
+    self._num_classes = len(index_by_label)

  def evaluate(self, data: dataset.Dataset, batch_size: int = 32) -> Any:
    """Evaluates the classifier with the provided evaluation dataset.
@@ -74,4 +74,4 @@ class Classifier(custom_model.CustomModel):
    label_filepath = os.path.join(export_dir, label_filename)
    tf.compat.v1.logging.info('Saving labels in %s', label_filepath)
    with tf.io.gfile.GFile(label_filepath, 'w') as f:
-      f.write('\n'.join(self._index_to_label))
+      f.write('\n'.join(self._index_by_label))
@@ -36,10 +36,10 @@ class ClassifierTest(tf.test.TestCase):

  def setUp(self):
    super(ClassifierTest, self).setUp()
-    index_to_label = ['cat', 'dog']
+    index_by_label = ['cat', 'dog']
    self.model = MockClassifier(
        model_spec=None,
-        index_to_label=index_to_label,
+        index_by_label=index_by_label,
        shuffle=False,
        full_train=False)
    self.model.model = test_util.build_model(input_shape=[4], num_classes=2)
@@ -21,8 +21,6 @@ import abc
import os
from typing import Any, Callable, Optional

-# Dependency imports
-
import tensorflow as tf

from mediapipe.model_maker.python.core.data import dataset
@@ -77,9 +75,9 @@ class CustomModel(abc.ABC):
    tflite_filepath = os.path.join(export_dir, tflite_filename)
    # TODO: Populate metadata to the exported TFLite model.
    model_util.export_tflite(
-        self._model,
-        tflite_filepath,
-        quantization_config,
+        model=self._model,
+        tflite_filepath=tflite_filepath,
+        quantization_config=quantization_config,
        preprocess=preprocess)
    tf.compat.v1.logging.info(
        'TensorFlow Lite model exported successfully: %s' % tflite_filepath)
@@ -40,8 +40,8 @@ class CustomModelTest(tf.test.TestCase):

  def setUp(self):
    super(CustomModelTest, self).setUp()
-    self.model = MockCustomModel(model_spec=None, shuffle=False)
-    self.model._model = test_util.build_model(input_shape=[4], num_classes=2)
+    self._model = MockCustomModel(model_spec=None, shuffle=False)
+    self._model._model = test_util.build_model(input_shape=[4], num_classes=2)

  def _check_nonempty_file(self, filepath):
    self.assertTrue(os.path.isfile(filepath))
@@ -49,7 +49,7 @@ class CustomModelTest(tf.test.TestCase):

  def test_export_tflite(self):
    export_path = os.path.join(self.get_temp_dir(), 'export/')
-    self.model.export_tflite(export_dir=export_path)
+    self._model.export_tflite(export_dir=export_path)
    self._check_nonempty_file(os.path.join(export_path, 'model.tflite'))


if __name__ == '__main__':
@@ -31,20 +31,6 @@ py_library(
    ],
)

-py_library(
-    name = "image_preprocessing",
-    srcs = ["image_preprocessing.py"],
-    srcs_version = "PY3",
-)
-
-py_test(
-    name = "image_preprocessing_test",
-    srcs = ["image_preprocessing_test.py"],
-    python_version = "PY3",
-    srcs_version = "PY3",
-    deps = [":image_preprocessing"],
-)
-
py_library(
    name = "model_util",
    srcs = ["model_util.py"],
@@ -56,7 +56,7 @@ class FocalLoss(tf.keras.losses.Loss):
      class_weight: A weight to apply to the loss, one for each class. The
        weight is applied for each input where the ground truth label matches.
    """
-    super(tf.keras.losses.Loss, self).__init__()
+    super().__init__()
    # Used for clipping min/max values of probability values in y_pred to avoid
    # NaNs and Infs in computation.
    self._epsilon = 1e-7
@@ -104,8 +104,8 @@ def export_tflite(
    quantization_config: Configuration for post-training quantization.
    supported_ops: A list of supported ops in the converted TFLite file.
    preprocess: A callable to preprocess the representative dataset for
-        quantization. The callable takes three arguments in order: feature,
-        label, and is_training.
+      quantization. The callable takes three arguments in order: feature, label,
+      and is_training.
  """
  if tflite_filepath is None:
    raise ValueError(
@@ -100,7 +100,8 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
    model = test_util.build_model(input_shape=[input_dim], num_classes=2)
    tflite_file = os.path.join(self.get_temp_dir(), 'model.tflite')
    model_util.export_tflite(model, tflite_file)
-    self._test_tflite(model, tflite_file, input_dim)
+    test_util.test_tflite(
+        keras_model=model, tflite_file=tflite_file, size=[1, input_dim])

  @parameterized.named_parameters(
      dict(
@@ -121,27 +122,20 @@ class ModelUtilTest(tf.test.TestCase, parameterized.TestCase):
    input_dim = 16
    num_classes = 2
    max_input_value = 5
-    model = test_util.build_model([input_dim], num_classes)
+    model = test_util.build_model(
+        input_shape=[input_dim], num_classes=num_classes)
    tflite_file = os.path.join(self.get_temp_dir(), 'model_quantized.tflite')

-    model_util.export_tflite(model, tflite_file, config)
-    self._test_tflite(
-        model, tflite_file, input_dim, max_input_value, atol=1e-00)
-    self.assertNear(os.path.getsize(tflite_file), model_size, 300)
-
-  def _test_tflite(self,
-                   keras_model: tf.keras.Model,
-                   tflite_model_file: str,
-                   input_dim: int,
-                   max_input_value: int = 1000,
-                   atol: float = 1e-04):
-    random_input = test_util.create_random_sample(
-        size=[1, input_dim], high=max_input_value)
-    random_input = tf.convert_to_tensor(random_input)
-
+    model_util.export_tflite(
+        model=model, tflite_filepath=tflite_file, quantization_config=config)
    self.assertTrue(
-        test_util.is_same_output(
-            tflite_model_file, keras_model, random_input, atol=atol))
+        test_util.test_tflite(
+            keras_model=model,
+            tflite_file=tflite_file,
+            size=[1, input_dim],
+            high=max_input_value,
+            atol=1e-00))
+    self.assertNear(os.path.getsize(tflite_file), model_size, 300)


if __name__ == '__main__':
@@ -92,3 +92,32 @@ def is_same_output(tflite_file: str,
  keras_output = keras_model.predict_on_batch(input_tensors)

  return np.allclose(lite_output, keras_output, atol=atol)


def test_tflite(keras_model: tf.keras.Model,
                tflite_file: str,
                size: Union[int, List[int]],
                high: float = 1,
                atol: float = 1e-04) -> bool:
  """Verifies if the output of TFLite model and TF Keras model are identical.

  Args:
    keras_model: Input TensorFlow Keras model.
    tflite_file: Input TFLite model file.
    size: Size of the input tensor.
    high: Higher boundary of the values in input tensors.
    atol: Absolute tolerance of the difference between the outputs of Keras
      model and TFLite model.

  Returns:
    True if the output of TFLite model and TF Keras model are identical.
    Otherwise, False.
  """
  random_input = create_random_sample(size=size, high=high)
  random_input = tf.convert_to_tensor(random_input)

  return is_same_output(
      tflite_file=tflite_file,
      keras_model=keras_model,
      input_tensors=random_input,
      atol=atol)
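
test_tflite above reduces to an np.allclose check between two output vectors.
For reference, the same absolute-tolerance comparison expressed in C++ (a
sketch of the semantics only, with rtol fixed at 0, which is stricter than
np.allclose's default relative term):

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Returns true if |a[i] - b[i]| <= atol for every element.
bool AllClose(const std::vector<float>& a, const std::vector<float>& b,
              float atol = 1e-4f) {
  if (a.size() != b.size()) return false;
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs(a[i] - b[i]) > atol) return false;
  }
  return true;
}

int main() {
  std::vector<float> keras_output = {0.10f, 0.90f};
  std::vector<float> lite_output = {0.10005f, 0.89995f};
  std::cout << (AllClose(keras_output, lite_output) ? "close" : "diverged")
            << "\n";
  return 0;
}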

mediapipe/model_maker/python/vision/core/BUILD (new file, 33 lines)

@@ -0,0 +1,33 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Placeholder for internal Python strict library compatibility macro.
# Placeholder for internal Python strict test compatibility macro.

licenses(["notice"])

package(
    default_visibility = ["//mediapipe:__subpackages__"],
)

py_library(
    name = "image_preprocessing",
    srcs = ["image_preprocessing.py"],
)

py_test(
    name = "image_preprocessing_test",
    srcs = ["image_preprocessing_test.py"],
    deps = [":image_preprocessing"],
)

mediapipe/model_maker/python/vision/core/__init__.py (new file, 13 lines)

@@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -13,11 +13,7 @@
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function

-# Dependency imports
import tensorflow as tf

IMAGE_SIZE = 224
@@ -12,15 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# Dependency imports
import numpy as np
import tensorflow as tf

-from mediapipe.model_maker.python.core.utils import image_preprocessing
+from mediapipe.model_maker.python.vision.core import image_preprocessing


def _get_preprocessed_image(preprocessor, is_training=False):
@@ -78,9 +78,9 @@ py_library(
        ":train_image_classifier_lib",
        "//mediapipe/model_maker/python/core/data:classification_dataset",
        "//mediapipe/model_maker/python/core/tasks:classifier",
-        "//mediapipe/model_maker/python/core/utils:image_preprocessing",
        "//mediapipe/model_maker/python/core/utils:model_util",
        "//mediapipe/model_maker/python/core/utils:quantization",
+        "//mediapipe/model_maker/python/vision/core:image_preprocessing",
    ],
)

@@ -16,7 +16,7 @@
import os
import random

-from typing import List, Optional, Tuple
+from typing import List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds

@@ -24,12 +24,12 @@ from mediapipe.model_maker.python.core.data import classification_dataset


def _load_image(path: str) -> tf.Tensor:
-  """Loads image."""
+  """Loads a jpeg/png image and returns an image tensor."""
  image_raw = tf.io.read_file(path)
  image_tensor = tf.cond(
-      tf.image.is_jpeg(image_raw),
-      lambda: tf.image.decode_jpeg(image_raw, channels=3),
-      lambda: tf.image.decode_png(image_raw, channels=3))
+      tf.io.is_jpeg(image_raw),
+      lambda: tf.io.decode_jpeg(image_raw, channels=3),
+      lambda: tf.io.decode_png(image_raw, channels=3))
  return image_tensor

@@ -60,11 +60,10 @@ class Dataset(classification_dataset.ClassificationDataset):

    Args:
      dirname: Name of the directory containing the data files.
-      shuffle: boolean, if shuffle, random shuffle data.
+      shuffle: boolean, if true, random shuffle data.

    Returns:
      Dataset containing images and labels and other related info.

    Raises:
      ValueError: if the input data directory is empty.
    """
@@ -85,55 +84,26 @@ class Dataset(classification_dataset.ClassificationDataset):
        name for name in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, name)))
    all_label_size = len(label_names)
-    label_to_index = dict(
+    index_by_label = dict(
        (name, index) for index, name in enumerate(label_names))
    all_image_labels = [
-        label_to_index[os.path.basename(os.path.dirname(path))]
+        index_by_label[os.path.basename(os.path.dirname(path))]
        for path in all_image_paths
    ]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)

-    autotune = tf.data.AUTOTUNE
-    image_ds = path_ds.map(_load_image, num_parallel_calls=autotune)
+    image_ds = path_ds.map(_load_image, num_parallel_calls=tf.data.AUTOTUNE)

-    # Loads label.
+    # Load label
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64))

-    # Creates a dataset of (image, label) pairs.
+    # Create a dataset of (image, label) pairs
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))

    tf.compat.v1.logging.info(
        'Load image with size: %d, num_label: %d, labels: %s.', all_image_size,
        all_label_size, ', '.join(label_names))
-    return Dataset(image_label_ds, all_image_size, label_names)
-
-  @classmethod
-  def load_tf_dataset(
-      cls, name: str
-  ) -> Tuple[Optional[classification_dataset.ClassificationDataset],
-             Optional[classification_dataset.ClassificationDataset],
-             Optional[classification_dataset.ClassificationDataset]]:
-    """Loads data from tensorflow_datasets.
-
-    Args:
-      name: the registered name of the tfds.core.DatasetBuilder. Refer to the
-        documentation of tfds.load for more details.
-
-    Returns:
-      A tuple of Datasets for the train/validation/test.
-
-    Raises:
-      ValueError: if the input tf dataset does not have train/validation/test
-        labels.
-    """
-    data, info = tfds.load(name, with_info=True)
-    if 'label' not in info.features:
-      raise ValueError('info.features need to contain \'label\' key.')
-    label_names = info.features['label'].names
-
-    train_data = _create_data('train', data, info, label_names)
-    validation_data = _create_data('validation', data, info, label_names)
-    test_data = _create_data('test', data, info, label_names)
-    return train_data, validation_data, test_data
+    return Dataset(
+        dataset=image_label_ds, size=all_image_size, index_by_label=label_names)
@@ -49,27 +49,27 @@ class DatasetTest(tf.test.TestCase):

  def test_split(self):
    ds = tf.data.Dataset.from_tensor_slices([[0, 1], [1, 1], [0, 0], [1, 0]])
    data = dataset.Dataset(ds, 4, ['pos', 'neg'])
    train_data, test_data = data.split(0.5)
    data = dataset.Dataset(dataset=ds, size=4, index_by_label=['pos', 'neg'])
    train_data, test_data = data.split(fraction=0.5)

    self.assertLen(train_data, 2)
    for i, elem in enumerate(train_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 1])).all())
    self.assertEqual(train_data.num_classes, 2)
    self.assertEqual(train_data.index_to_label, ['pos', 'neg'])
    self.assertEqual(train_data.index_by_label, ['pos', 'neg'])

    self.assertLen(test_data, 2)
    for i, elem in enumerate(test_data._dataset):
      self.assertTrue((elem.numpy() == np.array([i, 0])).all())
    self.assertEqual(test_data.num_classes, 2)
    self.assertEqual(test_data.index_to_label, ['pos', 'neg'])
    self.assertEqual(test_data.index_by_label, ['pos', 'neg'])

  def test_from_folder(self):
    data = dataset.Dataset.from_folder(self.image_path)
    data = dataset.Dataset.from_folder(dirname=self.image_path)

    self.assertLen(data, 2)
    self.assertEqual(data.num_classes, 2)
    self.assertEqual(data.index_to_label, ['daisy', 'tulips'])
    self.assertEqual(data.index_by_label, ['daisy', 'tulips'])
    for image, label in data.gen_tf_dataset():
      self.assertTrue(label.numpy() == 1 or label.numpy() == 0)
      if label.numpy() == 0:

@@ -88,19 +88,19 @@ class DatasetTest(tf.test.TestCase):

    self.assertIsInstance(train_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(train_data, 1034)
    self.assertEqual(train_data.num_classes, 3)
    self.assertEqual(train_data.index_to_label,
    self.assertEqual(train_data.index_by_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(validation_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(validation_data, 133)
    self.assertEqual(validation_data.num_classes, 3)
    self.assertEqual(validation_data.index_to_label,
    self.assertEqual(validation_data.index_by_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

    self.assertIsInstance(test_data.gen_tf_dataset(), tf.data.Dataset)
    self.assertLen(test_data, 128)
    self.assertEqual(test_data.num_classes, 3)
    self.assertEqual(test_data.index_to_label,
    self.assertEqual(test_data.index_by_label,
                     ['angular_leaf_spot', 'bean_rust', 'healthy'])

@@ -20,9 +20,9 @@ import tensorflow_hub as hub

from mediapipe.model_maker.python.core.data import classification_dataset as classification_ds
from mediapipe.model_maker.python.core.tasks import classifier
from mediapipe.model_maker.python.core.utils import image_preprocessing
from mediapipe.model_maker.python.core.utils import model_util
from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision.core import image_preprocessing
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters as hp
from mediapipe.model_maker.python.vision.image_classifier import model_spec as ms
from mediapipe.model_maker.python.vision.image_classifier import train_image_classifier_lib

@@ -31,18 +31,18 @@ from mediapipe.model_maker.python.vision.image_classifier import train_image_cla
class ImageClassifier(classifier.Classifier):
  """ImageClassifier for building image classification model."""

  def __init__(self, model_spec: ms.ModelSpec, index_to_label: List[Any],
  def __init__(self, model_spec: ms.ModelSpec, index_by_label: List[Any],
               hparams: hp.HParams):
    """Initializes ImageClassifier class.

    Args:
      model_spec: Specification for the model.
      index_to_label: A list that maps from index to label class name.
      index_by_label: A list that maps from index to label class name.
      hparams: The hyperparameters for training image classifier.
    """
    super(ImageClassifier, self).__init__(
    super().__init__(
        model_spec=model_spec,
        index_to_label=index_to_label,
        index_by_label=index_by_label,
        shuffle=hparams.shuffle,
        full_train=hparams.do_fine_tuning)
    self._hparams = hparams

@@ -81,7 +81,7 @@ class ImageClassifier(classifier.Classifier):
    spec = ms.SupportedModels.get(model_spec)
    image_classifier = cls(
        model_spec=spec,
        index_to_label=train_data.index_to_label,
        index_by_label=train_data.index_by_label,
        hparams=hparams)

    image_classifier._create_model()

@@ -60,31 +60,16 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
          model_spec=image_classifier.SupportedModels.MOBILENET_V2,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='resnet_50',
          model_spec=image_classifier.SupportedModels.RESNET_50,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite0',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE0,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite1',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE1,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite2',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE2,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite3',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE3,
          hparams=image_classifier.HParams(
              train_epochs=1, batch_size=1, shuffle=True)),
      dict(
          testcase_name='efficientnet_lite4',
          model_spec=image_classifier.SupportedModels.EFFICIENTNET_LITE4,

@@ -48,34 +48,17 @@ mobilenet_v2_spec = functools.partial(
    uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
    name='mobilenet_v2')

resnet_50_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
    name='resnet_50')

efficientnet_lite0_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
    name='efficientnet_lite0')

efficientnet_lite1_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
    input_image_shape=[240, 240],
    name='efficientnet_lite1')

efficientnet_lite2_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
    input_image_shape=[260, 260],
    name='efficientnet_lite2')

efficientnet_lite3_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
    input_image_shape=[280, 280],
    name='efficientnet_lite3')

efficientnet_lite4_spec = functools.partial(
    ModelSpec,
    uri='https://tfhub.dev/tensorflow/efficientnet/lite4/feature-vector/2',

@@ -88,11 +71,8 @@ efficientnet_lite4_spec = functools.partial(
class SupportedModels(enum.Enum):
  """Image classifier model supported by model maker."""
  MOBILENET_V2 = mobilenet_v2_spec
  RESNET_50 = resnet_50_spec
  EFFICIENTNET_LITE0 = efficientnet_lite0_spec
  EFFICIENTNET_LITE1 = efficientnet_lite1_spec
  EFFICIENTNET_LITE2 = efficientnet_lite2_spec
  EFFICIENTNET_LITE3 = efficientnet_lite3_spec
  EFFICIENTNET_LITE4 = efficientnet_lite4_spec

  @classmethod

@@ -30,36 +30,18 @@ class ModelSpecTest(tf.test.TestCase, parameterized.TestCase):
          expected_uri='https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4',
          expected_name='mobilenet_v2',
          expected_input_image_shape=[224, 224]),
      dict(
          testcase_name='resnet_50_spec_test',
          model_spec=ms.resnet_50_spec,
          expected_uri='https://tfhub.dev/google/imagenet/resnet_v2_50/feature_vector/4',
          expected_name='resnet_50',
          expected_input_image_shape=[224, 224]),
      dict(
          testcase_name='efficientnet_lite0_spec_test',
          model_spec=ms.efficientnet_lite0_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite0/feature-vector/2',
          expected_name='efficientnet_lite0',
          expected_input_image_shape=[224, 224]),
      dict(
          testcase_name='efficientnet_lite1_spec_test',
          model_spec=ms.efficientnet_lite1_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite1/feature-vector/2',
          expected_name='efficientnet_lite1',
          expected_input_image_shape=[240, 240]),
      dict(
          testcase_name='efficientnet_lite2_spec_test',
          model_spec=ms.efficientnet_lite2_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite2/feature-vector/2',
          expected_name='efficientnet_lite2',
          expected_input_image_shape=[260, 260]),
      dict(
          testcase_name='efficientnet_lite3_spec_test',
          model_spec=ms.efficientnet_lite3_spec,
          expected_uri='https://tfhub.dev/tensorflow/efficientnet/lite3/feature-vector/2',
          expected_name='efficientnet_lite3',
          expected_input_image_shape=[280, 280]),
      dict(
          testcase_name='efficientnet_lite4_spec_test',
          model_spec=ms.efficientnet_lite4_spec,

@@ -92,3 +92,29 @@ cc_library(
    ],
    alwayslink = 1,
)

# TODO: Investigate rewriting the build rule to only link
# the Bert Preprocessor if it's needed.
cc_library(
    name = "text_preprocessing_graph",
    srcs = ["text_preprocessing_graph.cc"],
    hdrs = ["text_preprocessing_graph.h"],
    deps = [
        "//mediapipe/calculators/tensor:bert_preprocessor_calculator",
        "//mediapipe/calculators/tensor:bert_preprocessor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:regex_preprocessor_calculator",
        "//mediapipe/calculators/tensor:regex_preprocessor_calculator_cc_proto",
        "//mediapipe/calculators/tensor:text_to_tensor_calculator",
        "//mediapipe/framework:subgraph",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/metadata:metadata_extractor",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,
)

@@ -17,8 +17,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])

cc_library(
    name = "landmarks_detection",
    hdrs = ["landmarks_detection.h"],
    name = "rect",
    hdrs = ["rect.h"],
)

cc_library(

@@ -13,26 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_

#include <vector>

// Structs holding landmarks related data structure for hand landmarker, pose
// detector, face mesher, etc.
namespace mediapipe::tasks::components::containers {

// x and y are in [0,1] range with origin in top left in input image space.
// If model provides z, z is in the same scale as x. Origin is in the center
// of the face.
struct Landmark {
  float x;
  float y;
  float z;
};

// [0, 1] range in input image space
struct Bound {
// Defines a rectangle, used e.g. as part of detection results or as input
// region-of-interest.
//
// The coordinates are normalized wrt the image dimensions, i.e. generally in
// [0,1] but they may exceed these bounds if describing a region overlapping the
// image. The origin is on the top-left corner of the image.
struct Rect {
  float left;
  float top;
  float right;

@@ -40,4 +32,4 @@ struct Bound {
};

}  // namespace mediapipe::tasks::components::containers
#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_LANDMARKS_DETECTION_H_
#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_RECT_H_

@@ -73,6 +73,7 @@ cc_library(
    srcs = ["model_task_graph.cc"],
    hdrs = ["model_task_graph.h"],
    deps = [
        ":model_asset_bundle_resources",
        ":model_resources",
        ":model_resources_cache",
        ":model_resources_calculator",

@@ -163,6 +164,7 @@ cc_library_with_tflite(
        "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
    ],
    deps = [
        ":model_asset_bundle_resources",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:packet",
        "//mediapipe/tasks/cc:common",

@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "tensorflow/lite/core/api/op_resolver.h"

@@ -39,12 +40,16 @@ ModelResourcesCache::ModelResourcesCache(
  graph_op_resolver_packet_ =
      api2::PacketAdopting<tflite::OpResolver>(std::move(graph_op_resolver));
}
};
}

bool ModelResourcesCache::Exists(const std::string& tag) const {
  return model_resources_collection_.contains(tag);
}

bool ModelResourcesCache::ModelAssetBundleExists(const std::string& tag) const {
  return model_asset_bundle_resources_collection_.contains(tag);
}

absl::Status ModelResourcesCache::AddModelResources(
    std::unique_ptr<ModelResources> model_resources) {
  if (model_resources == nullptr) {

@@ -94,6 +99,62 @@ absl::StatusOr<const ModelResources*> ModelResourcesCache::GetModelResources(
  return model_resources_collection_.at(tag).get();
}

absl::Status ModelResourcesCache::AddModelAssetBundleResources(
    std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources) {
  if (model_asset_bundle_resources == nullptr) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "ModelAssetBundleResources object is null.",
        MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
  }
  const std::string& tag = model_asset_bundle_resources->GetTag();
  if (tag.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "ModelAssetBundleResources must have a non-empty tag.",
        MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
  }
  if (ModelAssetBundleExists(tag)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::Substitute(
            "ModelAssetBundleResources with tag \"$0\" already exists.", tag),
        MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
  }
  model_asset_bundle_resources_collection_.emplace(
      tag, std::move(model_asset_bundle_resources));
  return absl::OkStatus();
}

absl::Status ModelResourcesCache::AddModelAssetBundleResourcesCollection(
    std::vector<std::unique_ptr<ModelAssetBundleResources>>&
        model_asset_bundle_resources_collection) {
  for (auto& model_bundle_resources : model_asset_bundle_resources_collection) {
    MP_RETURN_IF_ERROR(
        AddModelAssetBundleResources(std::move(model_bundle_resources)));
  }
  return absl::OkStatus();
}

absl::StatusOr<const ModelAssetBundleResources*>
ModelResourcesCache::GetModelAssetBundleResources(
    const std::string& tag) const {
  if (tag.empty()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "ModelAssetBundleResources must be retrieved with a non-empty tag.",
        MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
  }
  if (!ModelAssetBundleExists(tag)) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::Substitute(
            "ModelAssetBundleResources with tag \"$0\" does not exist.", tag),
        MediaPipeTasksStatus::kRunnerModelResourcesCacheServiceError);
  }
  return model_asset_bundle_resources_collection_.at(tag).get();
}

absl::StatusOr<api2::Packet<tflite::OpResolver>>
ModelResourcesCache::GetGraphOpResolverPacket() const {
  if (graph_op_resolver_packet_.IsEmpty()) {

@@ -26,6 +26,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "tensorflow/lite/core/api/op_resolver.h"

@@ -46,6 +47,10 @@ class ModelResourcesCache {
  // Returns whether the tag exists in the model resources cache.
  bool Exists(const std::string& tag) const;

  // Returns whether the tag of the model asset bundle exists in the model
  // resources cache.
  bool ModelAssetBundleExists(const std::string& tag) const;

  // Adds a ModelResources object into the cache.
  // The tag of the ModelResources must be unique; the ownership of the
  // ModelResources will be transferred into the cache.

@@ -62,6 +67,23 @@ class ModelResourcesCache {
  absl::StatusOr<const ModelResources*> GetModelResources(
      const std::string& tag) const;

  // Adds a ModelAssetBundleResources object into the cache.
  // The tag of the ModelAssetBundleResources must be unique; the ownership of
  // the ModelAssetBundleResources will be transferred into the cache.
  absl::Status AddModelAssetBundleResources(
      std::unique_ptr<ModelAssetBundleResources> model_asset_bundle_resources);

  // Adds a collection of the ModelAssetBundleResources objects into the cache.
  // The tag of each ModelAssetBundleResources must be unique; the ownership
  // of every ModelAssetBundleResources will be transferred into the cache.
  absl::Status AddModelAssetBundleResourcesCollection(
      std::vector<std::unique_ptr<ModelAssetBundleResources>>&
          model_asset_bundle_resources_collection);

  // Retrieves a const ModelAssetBundleResources pointer by the unique tag.
  absl::StatusOr<const ModelAssetBundleResources*> GetModelAssetBundleResources(
      const std::string& tag) const;

  // Retrieves the graph op resolver packet.
  absl::StatusOr<api2::Packet<tflite::OpResolver>> GetGraphOpResolverPacket()
      const;

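For orientation, here is a minimal sketch of how the new bundle APIs compose. It is illustrative only: the helper name StoreAndFetch is hypothetical and not part of this change, and `cache` and `bundle` are assumed to be constructed elsewhere.

// Hypothetical helper, for illustration only: stores a model asset bundle in
// the cache and reads it back by tag.
absl::Status StoreAndFetch(
    ModelResourcesCache& cache,
    std::unique_ptr<ModelAssetBundleResources> bundle) {
  const std::string tag = bundle->GetTag();  // Must be unique and non-empty.
  // The cache takes ownership of the bundle on Add.
  MP_RETURN_IF_ERROR(cache.AddModelAssetBundleResources(std::move(bundle)));
  ASSIGN_OR_RETURN(const ModelAssetBundleResources* fetched,
                   cache.GetModelAssetBundleResources(tag));
  (void)fetched;  // Non-owning pointer; valid while the cache is alive.
  return absl::OkStatus();
}
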
@@ -73,6 +95,11 @@ class ModelResourcesCache {
  // A collection of ModelResources objects for the models in the graph.
  absl::flat_hash_map<std::string, std::unique_ptr<ModelResources>>
      model_resources_collection_;

  // A collection of ModelAssetBundleResources objects for the model bundles in
  // the graph.
  absl::flat_hash_map<std::string, std::unique_ptr<ModelAssetBundleResources>>
      model_asset_bundle_resources_collection_;
};

// Global service for mediapipe task model resources cache.

@@ -32,6 +32,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"

@@ -70,6 +71,17 @@ std::string CreateModelResourcesTag(const CalculatorGraphConfig::Node& node) {
                         node_type);
}

std::string CreateModelAssetBundleResourcesTag(
    const CalculatorGraphConfig::Node& node) {
  std::vector<std::string> names = absl::StrSplit(node.name(), "__");
  std::string node_type = node.calculator();
  std::replace(node_type.begin(), node_type.end(), '.', '_');
  absl::AsciiStrToLower(&node_type);
  return absl::StrFormat("%s_%s_model_asset_bundle_resources",
                         names.back().empty() ? "unnamed" : names.back(),
                         node_type);
}

}  // namespace

// Defines the mediapipe task inference unit as a MediaPipe subgraph that

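To make the tag scheme above concrete, here is a hand-worked example; the node name and calculator type are hypothetical, not values from this change.

// Illustrative example of CreateModelAssetBundleResourcesTag:
//   node.name()       == "container__MyTaskGraph"
//   node.calculator() == "mediapipe.tasks.MyTaskGraph"
// yields the tag
//   "MyTaskGraph_mediapipe_tasks_mytaskgraph_model_asset_bundle_resources"
// since only the calculator type is lowercased with dots replaced by
// underscores, while the last "__"-separated segment of the node name is
// kept verbatim.
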
@@ -168,6 +180,38 @@ absl::StatusOr<const ModelResources*> ModelTaskGraph::CreateModelResources(
  return model_resources_cache_service.GetObject().GetModelResources(tag);
}

absl::StatusOr<const ModelAssetBundleResources*>
ModelTaskGraph::CreateModelAssetBundleResources(
    SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file) {
  auto model_resources_cache_service = sc->Service(kModelResourcesCacheService);
  bool has_file_pointer_meta = external_file->has_file_pointer_meta();
  // If the external file is set by file pointer, there is no need to add the
  // model asset bundle resources into the model resources service since the
  // memory is not owned by this model asset bundle resources.
  if (!model_resources_cache_service.IsAvailable() || has_file_pointer_meta) {
    ASSIGN_OR_RETURN(
        local_model_asset_bundle_resources_,
        ModelAssetBundleResources::Create("", std::move(external_file)));
    if (!has_file_pointer_meta) {
      LOG(WARNING)
          << "A local ModelResources object is created. Please consider using "
             "ModelResourcesCacheService to cache the created ModelResources "
             "object in the CalculatorGraph.";
    }
    return local_model_asset_bundle_resources_.get();
  }
  const std::string tag =
      CreateModelAssetBundleResourcesTag(sc->OriginalNode());
  ASSIGN_OR_RETURN(
      auto model_bundle_resources,
      ModelAssetBundleResources::Create(tag, std::move(external_file)));
  MP_RETURN_IF_ERROR(
      model_resources_cache_service.GetObject().AddModelAssetBundleResources(
          std::move(model_bundle_resources)));
  return model_resources_cache_service.GetObject().GetModelAssetBundleResources(
      tag);
}

GenericNode& ModelTaskGraph::AddInference(
    const ModelResources& model_resources,
    const proto::Acceleration& acceleration, Graph& graph) const {

@@ -27,6 +27,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/subgraph.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"

@@ -78,6 +79,35 @@ class ModelTaskGraph : public Subgraph {
  absl::StatusOr<const ModelResources*> CreateModelResources(
      SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);

  // If the model resources graph service is available, creates a model asset
  // bundle resources object from the subgraph context, and caches the created
  // model asset bundle resources into the model resources graph service on
  // success. Otherwise, creates a local model asset bundle resources object
  // that can only be used in the graph construction stage. The returned model
  // resources pointer will provide graph authors with the access to extracted
  // model files.
  template <typename Options>
  absl::StatusOr<const ModelAssetBundleResources*>
  CreateModelAssetBundleResources(SubgraphContext* sc) {
    auto external_file = std::make_unique<proto::ExternalFile>();
    external_file->Swap(sc->MutableOptions<Options>()
                            ->mutable_base_options()
                            ->mutable_model_asset());
    return CreateModelAssetBundleResources(sc, std::move(external_file));
  }

  // If the model resources graph service is available, creates a model asset
  // bundle resources object from the subgraph context, and caches the created
  // model asset bundle resources into the model resources graph service on
  // success. Otherwise, creates a local model asset bundle resources object
  // that can only be used in the graph construction stage. Note that the
  // external file contents will be moved into the model asset bundle resources
  // object on creation. The returned model asset bundle resources pointer will
  // provide graph authors with the access to extracted model files.
  absl::StatusOr<const ModelAssetBundleResources*>
  CreateModelAssetBundleResources(
      SubgraphContext* sc, std::unique_ptr<proto::ExternalFile> external_file);

  // Inserts a mediapipe task inference subgraph into the provided
  // GraphBuilder. The returned node provides the following interfaces to the
  // rest of the graph:

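As a rough sketch of the intended call pattern, a derived task graph might use the templated helper inside its GetConfig() override. `MyTaskGraph` and `MyTaskGraphOptions` are hypothetical (the options proto is assumed to carry a base_options.model_asset field), and the body is abbreviated; this is not code from this change.

// Illustrative sketch only, under the assumptions stated above.
class MyTaskGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    // Moves base_options.model_asset out of the node options and resolves it
    // through the ModelResourcesCacheService when that service is available.
    ASSIGN_OR_RETURN(
        const ModelAssetBundleResources* bundle_resources,
        CreateModelAssetBundleResources<MyTaskGraphOptions>(sc));
    api2::builder::Graph graph;
    // ... use bundle_resources to extract the bundled model files and build
    // the actual subgraph here ...
    return graph.GetConfig();
  }
};
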
@@ -95,6 +125,9 @@ class ModelTaskGraph : public Subgraph {

 private:
  std::unique_ptr<ModelResources> local_model_resources_;

  std::unique_ptr<ModelAssetBundleResources>
      local_model_asset_bundle_resources_;
};

}  // namespace core

@@ -15,6 +15,8 @@ limitations under the License.

#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"

#include <string>

#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"

@@ -162,12 +164,16 @@ absl::Status ExtractFilesfromZipFile(
  return absl::OkStatus();
}

void SetExternalFile(const std::string_view& file_content,
                     core::proto::ExternalFile* model_file) {
  auto pointer = reinterpret_cast<uint64_t>(file_content.data());

  model_file->mutable_file_pointer_meta()->set_pointer(pointer);
  model_file->mutable_file_pointer_meta()->set_length(file_content.length());
void SetExternalFile(const absl::string_view& file_content,
                     core::proto::ExternalFile* model_file, bool is_copy) {
  if (is_copy) {
    std::string str_content{file_content};
    model_file->set_file_content(str_content);
  } else {
    auto pointer = reinterpret_cast<uint64_t>(file_content.data());
    model_file->mutable_file_pointer_meta()->set_pointer(pointer);
    model_file->mutable_file_pointer_meta()->set_length(file_content.length());
  }
}

}  // namespace metadata

@@ -35,10 +35,13 @@ absl::Status ExtractFilesfromZipFile(
    const char* buffer_data, const size_t buffer_size,
    absl::flat_hash_map<std::string, absl::string_view>* files);

// Sets file_pointer_meta in ExternalFile, i.e. the pointer pointing to the
// location of a file in memory, from file_content.
void SetExternalFile(const std::string_view& file_content,
                     core::proto::ExternalFile* model_file);
// Sets the ExternalFile object from file_content in memory. By default,
// `is_copy=false`, which means `file_pointer_meta` in ExternalFile is set to
// the pointer pointing to the location of the file in memory. Otherwise, if
// `is_copy=true`, the memory is copied into `file_content` in ExternalFile.
void SetExternalFile(const absl::string_view& file_content,
                     core::proto::ExternalFile* model_file,
                     bool is_copy = false);

}  // namespace metadata
}  // namespace tasks

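A small usage sketch of the extended SetExternalFile; this is illustrative only, and `ReadFileSomehow` plus the model path are assumptions, not part of this commit.

// Illustrative only. In pointer mode the buffer must outlive the ExternalFile,
// since only its address and length are recorded.
std::string model_buffer = ReadFileSomehow("/path/to/model.tflite");  // hypothetical helper

mediapipe::tasks::core::proto::ExternalFile file_no_copy;
mediapipe::tasks::metadata::SetExternalFile(model_buffer, &file_no_copy);

mediapipe::tasks::core::proto::ExternalFile file_copied;
mediapipe::tasks::metadata::SetExternalFile(model_buffer, &file_copied,
                                            /*is_copy=*/true);
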
84
mediapipe/tasks/cc/text/text_classifier/BUILD
Normal file
@@ -0,0 +1,84 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

cc_library(
    name = "text_classifier_graph",
    srcs = ["text_classifier_graph.cc"],
    deps = [
        "//mediapipe/calculators/tensor:inference_calculator",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/framework/api2:port",
        "//mediapipe/tasks/cc/components:text_preprocessing_graph",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/components/proto:text_preprocessing_graph_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_resources",
        "//mediapipe/tasks/cc/core:model_resources_calculator",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core/proto:model_resources_calculator_cc_proto",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
    ],
    alwayslink = 1,
)

cc_library(
    name = "text_classifier",
    srcs = ["text_classifier.cc"],
    hdrs = ["text_classifier.h"],
    deps = [
        ":text_classifier_graph",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:builder",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
        "//mediapipe/tasks/cc/components/processors:classifier_options",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:base_options",
        "//mediapipe/tasks/cc/core:base_task_api",
        "//mediapipe/tasks/cc/core:task_api_factory",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_cc_proto",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
    ],
)

cc_library(
    name = "text_classifier_test_utils",
    srcs = ["text_classifier_test_utils.cc"],
    hdrs = ["text_classifier_test_utils.h"],
    visibility = ["//visibility:private"],
    deps = [
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@org_tensorflow//tensorflow/lite:mutable_op_resolver",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite:type_to_tflitetype",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
)
30
mediapipe/tasks/cc/text/text_classifier/proto/BUILD
Normal file
@@ -0,0 +1,30 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

mediapipe_proto_library(
    name = "text_classifier_graph_options_proto",
    srcs = ["text_classifier_graph_options.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
    ],
)
@@ -0,0 +1,38 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto2";

package mediapipe.tasks.text.text_classifier.proto;

import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";

option java_package = "com.google.mediapipe.tasks.text.textclassifier.proto";
option java_outer_classname = "TextClassifierGraphOptionsProto";

message TextClassifierGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional TextClassifierGraphOptions ext = 462704549;
  }
  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for configuring the classifier behavior, such as score threshold,
  // number of results, etc.
  optional components.processors.proto.ClassifierOptions classifier_options = 2;
}
104
mediapipe/tasks/cc/text/text_classifier/text_classifier.cc
Normal file
@@ -0,0 +1,104 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

#include <map>
#include <memory>
#include <string>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/task_api_factory.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {

namespace {

using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

constexpr char kTextStreamName[] = "text_in";
constexpr char kTextTag[] = "TEXT";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kSubgraphTypeName[] =
    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";

// Creates a MediaPipe graph config that only contains a single subgraph node of
// type "TextClassifierGraph".
CalculatorGraphConfig CreateGraphConfig(
    std::unique_ptr<proto::TextClassifierGraphOptions> options) {
  api2::builder::Graph graph;
  auto& subgraph = graph.AddNode(kSubgraphTypeName);
  subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
  graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
  subgraph.Out(kClassificationResultTag)
          .SetName(kClassificationResultStreamName) >>
      graph.Out(kClassificationResultTag);
  return graph.GetConfig();
}

// Converts the user-facing TextClassifierOptions struct to the internal
// TextClassifierGraphOptions proto.
std::unique_ptr<proto::TextClassifierGraphOptions>
ConvertTextClassifierOptionsToProto(TextClassifierOptions* options) {
  auto options_proto = std::make_unique<proto::TextClassifierGraphOptions>();
  auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
      tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
  options_proto->mutable_base_options()->Swap(base_options_proto.get());
  auto classifier_options_proto =
      std::make_unique<tasks::components::processors::proto::ClassifierOptions>(
          components::processors::ConvertClassifierOptionsToProto(
              &(options->classifier_options)));
  options_proto->mutable_classifier_options()->Swap(
      classifier_options_proto.get());
  return options_proto;
}

}  // namespace

absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
    std::unique_ptr<TextClassifierOptions> options) {
  auto options_proto = ConvertTextClassifierOptionsToProto(options.get());
  return core::TaskApiFactory::Create<TextClassifier,
                                      proto::TextClassifierGraphOptions>(
      CreateGraphConfig(std::move(options_proto)),
      std::move(options->base_options.op_resolver));
}

absl::StatusOr<ClassificationResult> TextClassifier::Classify(
    absl::string_view text) {
  ASSIGN_OR_RETURN(
      auto output_packets,
      runner_->Process(
          {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
  return output_packets[kClassificationResultStreamName]
      .Get<ClassificationResult>();
}

}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe

96
mediapipe/tasks/cc/text/text_classifier/text_classifier.h
Normal file
@@ -0,0 +1,96 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_

#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/base_task_api.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {

// The options for configuring a MediaPipe text classifier task.
struct TextClassifierOptions {
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // file with metadata, accelerator options, op resolver, etc.
  tasks::core::BaseOptions base_options;

  // Options for configuring the classifier behavior, such as score threshold,
  // number of results, etc.
  components::processors::ClassifierOptions classifier_options;
};

// Performs classification on text.
//
// This API expects a TFLite model with (optional) TFLite Model Metadata that
// contains the mandatory (described below) input tensors, output tensor,
// and the optional (but recommended) label items as AssociatedFiles with type
// TENSOR_AXIS_LABELS per output classification tensor. Metadata is required for
// models with int32 input tensors because it contains the input process unit
// for the model's Tokenizer. No metadata is required for models with string
// input tensors.
//
// Input tensors:
//   (kTfLiteInt32)
//   - 3 input tensors of size `[batch_size x bert_max_seq_len]` representing
//     the input ids, segment ids, and mask ids
//   - or 1 input tensor of size `[batch_size x max_seq_len]` representing the
//     input ids
//   or (kTfLiteString)
//   - 1 input tensor that is shapeless or has shape [1] containing the input
//     string
// At least one output tensor with:
//   (kTfLiteFloat32/kBool)
//   - `[1 x N]` array with `N` representing the number of categories.
//   - optional (but recommended) label items as AssociatedFiles with type
//     TENSOR_AXIS_LABELS, containing one label per line. The first such
//     AssociatedFile (if any) is used to fill the `category_name` field of the
//     results. The `display_name` field is filled from the AssociatedFile (if
//     any) whose locale matches the `display_names_locale` field of the
//     `TextClassifierOptions` used at creation time ("en" by default, i.e.
//     English). If none of these are available, only the `index` field of the
//     results will be filled.
class TextClassifier : core::BaseTaskApi {
 public:
  using BaseTaskApi::BaseTaskApi;

  // Creates a TextClassifier from the provided `options`.
  static absl::StatusOr<std::unique_ptr<TextClassifier>> Create(
      std::unique_ptr<TextClassifierOptions> options);

  // Performs classification on the input `text`.
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
      absl::string_view text);

  // Shuts down the TextClassifier when all the work is done.
  absl::Status Close() { return runner_->Close(); }
};

}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_H_

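Putting the pieces together, a minimal usage sketch of this API; the model path is a placeholder, and the snippet is assumed to live in a function returning absl::Status. The test file below shows the exact patterns exercised in this change.

// Illustrative sketch only.
auto options = std::make_unique<TextClassifierOptions>();
options->base_options.model_asset_path = "/path/to/model.tflite";  // placeholder
ASSIGN_OR_RETURN(std::unique_ptr<TextClassifier> classifier,
                 TextClassifier::Create(std::move(options)));
ASSIGN_OR_RETURN(auto result, classifier->Classify("some input text"));
MP_RETURN_IF_ERROR(classifier->Close());
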
162
mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc
Normal file
@@ -0,0 +1,162 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <type_traits>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/text_preprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/text_preprocessing_graph.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/proto/text_classifier_graph_options.pb.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {

namespace {

using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;

constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kTextTag[] = "TEXT";
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
constexpr char kTensorsTag[] = "TENSORS";

}  // namespace

// A "TextClassifierGraph" performs Natural Language classification (including
// BERT-based text classification).
// - Accepts input text and outputs classification results on CPU.
//
// Inputs:
//   TEXT - std::string
//     Input text to perform classification on.
//
// Outputs:
//   CLASSIFICATION_RESULT - ClassificationResult
//     The aggregated classification result object that has 3 dimensions:
//     (classification head, classification timestamp, classification category).
//
// Example:
// node {
//   calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
//   input_stream: "TEXT:text_in"
//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
//   options {
//     [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
//     {
//       base_options {
//         model_asset {
//           file_name: "/path/to/model.tflite"
//         }
//       }
//     }
//   }
// }
class TextClassifierGraph : public core::ModelTaskGraph {
 public:
  absl::StatusOr<CalculatorGraphConfig> GetConfig(
      SubgraphContext* sc) override {
    ASSIGN_OR_RETURN(
        const ModelResources* model_resources,
        CreateModelResources<proto::TextClassifierGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        Source<ClassificationResult> classification_result_out,
        BuildTextClassifierTask(
            sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
            graph[Input<std::string>(kTextTag)], graph));
    classification_result_out >>
        graph[Output<ClassificationResult>(kClassificationResultTag)];
    return graph.GetConfig();
  }

 private:
  // Adds a mediapipe TextClassifier task graph into the provided
  // builder::Graph instance. The TextClassifier task takes an input
  // text (std::string) and returns one classification result per output head
  // specified by the model.
  //
  // task_options: the mediapipe tasks TextClassifierGraphOptions proto.
  // model_resources: the ModelResources object initialized from a
  //   TextClassifier model file with model metadata.
  // text_in: (std::string) stream to run text classification on.
  // graph: the mediapipe builder::Graph instance to be updated.
  absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
      const proto::TextClassifierGraphOptions& task_options,
      const ModelResources& model_resources, Source<std::string> text_in,
      Graph& graph) {
    // Adds preprocessing calculators and connects them to the text input
    // stream.
    auto& preprocessing =
        graph.AddNode("mediapipe.tasks.components.TextPreprocessingSubgraph");
    MP_RETURN_IF_ERROR(components::ConfigureTextPreprocessingSubgraph(
        model_resources,
        preprocessing.GetOptions<
            tasks::components::proto::TextPreprocessingGraphOptions>()));
    text_in >> preprocessing.In(kTextTag);

    // Adds both InferenceCalculator and ModelResourcesCalculator.
    auto& inference = AddInference(
        model_resources, task_options.base_options().acceleration(), graph);
    // The metadata extractor side-output comes from the
    // ModelResourcesCalculator.
    inference.SideOut(kMetadataExtractorTag) >>
        preprocessing.SideIn(kMetadataExtractorTag);
    preprocessing.Out(kTensorsTag) >> inference.In(kTensorsTag);

    // Adds postprocessing calculators and connects them to the graph output.
    auto& postprocessing = graph.AddNode(
        "mediapipe.tasks.components.processors."
        "ClassificationPostprocessingGraph");
    MP_RETURN_IF_ERROR(
        components::processors::ConfigureClassificationPostprocessingGraph(
            model_resources, task_options.classifier_options(),
            &postprocessing
                 .GetOptions<components::processors::proto::
                                 ClassificationPostprocessingGraphOptions>()));
    inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

    // Outputs the aggregated classification result as the subgraph output
    // stream.
    return postprocessing[Output<ClassificationResult>(
        kClassificationResultTag)];
  }
};

REGISTER_MEDIAPIPE_GRAPH(
    ::mediapipe::tasks::text::text_classifier::TextClassifierGraph);

}  // namespace text_classifier
}  // namespace text
}  // namespace tasks
}  // namespace mediapipe

238
mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc
Normal file
@@ -0,0 +1,238 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier.h"

#include <cmath>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace text_classifier {
namespace {

using ::mediapipe::EqualsProto;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::kMediaPipeTasksPayload;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;
using ::testing::proto::Approximately;
using ::testing::proto::IgnoringRepeatedFieldOrdering;
using ::testing::proto::Partially;

constexpr float kEpsilon = 0.001;
constexpr int kMaxSeqLen = 128;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
constexpr char kInvalidModelPath[] = "i/do/not/exist.tflite";
constexpr char kTestRegexModelPath[] =
    "test_model_text_classifier_with_regex_tokenizer.tflite";
constexpr char kStringToBoolModelPath[] =
    "test_model_text_classifier_bool_output.tflite";

std::string GetFullPath(absl::string_view file_name) {
  return JoinPath("./", kTestDataDirectory, file_name);
}

class TextClassifierTest : public tflite_shims::testing::Test {};

TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}

TEST_F(TextClassifierTest, CreateFailsWithMissingBaseOptions) {
  auto options = std::make_unique<TextClassifierOptions>();
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(options));

  EXPECT_EQ(classifier.status().code(), absl::StatusCode::kInvalidArgument);
  EXPECT_THAT(
      classifier.status().message(),
      HasSubstr("ExternalFile must specify at least one of 'file_content', "
                "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'."));
  EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}

TEST_F(TextClassifierTest, CreateFailsWithMissingModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kInvalidModelPath);
  StatusOr<std::unique_ptr<TextClassifier>> classifier =
      TextClassifier::Create(std::move(options));

  EXPECT_EQ(classifier.status().code(), absl::StatusCode::kNotFound);
  EXPECT_THAT(classifier.status().message(),
              HasSubstr("Unable to open file at"));
  EXPECT_THAT(classifier.status().GetPayload(kMediaPipeTasksPayload),
              Optional(absl::Cord(absl::StrCat(
                  MediaPipeTasksStatus::kRunnerInitializationError))));
}

TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
}

TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto options = std::make_unique<TextClassifierOptions>();
  options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
                          TextClassifier::Create(std::move(options)));
  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult negative_result,
      classifier->Classify("unflinchingly bleak and desperate"));
  ASSERT_THAT(negative_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.956 }
                        categories { category_name: "positive" score: 0.044 }
                      }
                    }
                  )pb"),
                  kEpsilon))));

  MP_ASSERT_OK_AND_ASSIGN(
      ClassificationResult positive_result,
      classifier->Classify("it's a charming and often affecting journey"));
  ASSERT_THAT(positive_result,
              Partially(IgnoringRepeatedFieldOrdering(Approximately(
                  EqualsProto(R"pb(
                    classifications {
                      entries {
                        categories { category_name: "negative" score: 0.0 }
                        categories { category_name: "positive" score: 1.0 }
                      }
                    }
                  )pb"),
                  kEpsilon))));
MP_ASSERT_OK(classifier->Close());
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult negative_result,
|
||||
classifier->Classify("What a waste of my time."));
|
||||
ASSERT_THAT(negative_result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "Negative" score: 0.813 }
|
||||
categories { category_name: "Positive" score: 0.187 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
ClassificationResult positive_result,
|
||||
classifier->Classify("This is the best movie I’ve seen in recent years. "
|
||||
"Strongly recommend it!"));
|
||||
ASSERT_THAT(positive_result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "Negative" score: 0.487 }
|
||||
categories { category_name: "Positive" score: 0.513 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
MP_ASSERT_OK(classifier->Close());
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath);
|
||||
options->base_options.op_resolver = CreateCustomResolver();
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
|
||||
classifier->Classify("hello"));
|
||||
ASSERT_THAT(result, Partially(IgnoringRepeatedFieldOrdering(EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { index: 1 score: 1 }
|
||||
categories { index: 0 score: 1 }
|
||||
categories { index: 2 score: 0 }
|
||||
}
|
||||
}
|
||||
)pb"))));
|
||||
}
|
||||
|
||||
TEST_F(TextClassifierTest, BertLongPositive) {
|
||||
std::stringstream ss_for_positive_review;
|
||||
ss_for_positive_review
|
||||
<< "it's a charming and often affecting journey and this is a long";
|
||||
for (int i = 0; i < kMaxSeqLen; ++i) {
|
||||
ss_for_positive_review << " long";
|
||||
}
|
||||
ss_for_positive_review << " movie review";
|
||||
auto options = std::make_unique<TextClassifierOptions>();
|
||||
options->base_options.model_asset_path = GetFullPath(kTestBertModelPath);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
|
||||
TextClassifier::Create(std::move(options)));
|
||||
MP_ASSERT_OK_AND_ASSIGN(ClassificationResult result,
|
||||
classifier->Classify(ss_for_positive_review.str()));
|
||||
ASSERT_THAT(result,
|
||||
Partially(IgnoringRepeatedFieldOrdering(Approximately(
|
||||
EqualsProto(R"pb(
|
||||
classifications {
|
||||
entries {
|
||||
categories { category_name: "negative" score: 0.014 }
|
||||
categories { category_name: "positive" score: 0.986 }
|
||||
}
|
||||
}
|
||||
)pb"),
|
||||
kEpsilon))));
|
||||
MP_ASSERT_OK(classifier->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace text_classifier
|
||||
} // namespace text
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,131 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"

#include <cstring>
#include <memory>
#include <type_traits>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/string_util.h"

namespace mediapipe {
namespace tasks {
namespace text {
namespace {

using ::mediapipe::tasks::CreateStatusWithPayload;
using ::tflite::GetInput;
using ::tflite::GetOutput;
using ::tflite::GetString;
using ::tflite::StringRef;

constexpr absl::string_view kInputStr = "hello";
constexpr bool kBooleanData[] = {true, true, false};
constexpr size_t kBooleanDataSize = std::size(kBooleanData);

// Checks the type of the given tensor and returns its data buffer cast to T;
// fails if the tensor has no raw data or its type is not T.
template <typename T>
absl::StatusOr<T*> AssertAndReturnTypedTensor(const TfLiteTensor* tensor) {
  if (!tensor->data.raw) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("Tensor (%s) has no raw data.", tensor->name));
  }

  // Checks whether the data type of the tensor is T and, if so, returns the
  // pointer cast to T; otherwise returns an error status.
  // See type_to_tflitetype.h for a mapping from plain C++ type to TfLiteType.
  if (tensor->type == tflite::typeToTfLiteType<T>()) {
    return reinterpret_cast<T*>(tensor->data.raw);
  }
  return CreateStatusWithPayload(
      absl::StatusCode::kInternal,
      absl::StrFormat("Type mismatch for tensor %s. Required %d, got %d.",
                      tensor->name, tflite::typeToTfLiteType<T>(),
                      tensor->type));
}

// Populates a tensor with an array of data; fails if the data type doesn't
// match the tensor type or they don't have the same number of elements.
template <typename T, typename = std::enable_if_t<
                          std::negation_v<std::is_same<T, std::string>>>>
absl::Status PopulateTensor(const T* data, int num_elements,
                            TfLiteTensor* tensor) {
  ASSIGN_OR_RETURN(T * v, AssertAndReturnTypedTensor<T>(tensor));
  size_t bytes = num_elements * sizeof(T);
  if (tensor->bytes != bytes) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInternal,
        absl::StrFormat("tensor->bytes (%d) != bytes (%d)", tensor->bytes,
                        bytes));
  }
  std::memcpy(v, data, bytes);
  return absl::OkStatus();
}

TfLiteStatus PrepareStringToBool(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  TfLiteIntArray* dims = TfLiteIntArrayCreate(1);
  dims->data[0] = kBooleanDataSize;
  return context->ResizeTensor(context, output, dims);
}

TfLiteStatus InvokeStringToBool(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input_tensor = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input_tensor != nullptr);
  StringRef input_str_ref = GetString(input_tensor, 0);
  std::string input_str(input_str_ref.str, input_str_ref.len);
  if (input_str != kInputStr) {
    return kTfLiteError;
  }
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE(context, PopulateTensor(kBooleanData, 3, output).ok());
  return kTfLiteOk;
}

// This custom op takes a string tensor as input and outputs a bool tensor
// with value {true, true, false}; it is used to mimic a real text
// classification model, which classifies a string into scores for different
// categories.
TfLiteRegistration* RegisterStringToBool() {
  // Dummy implementation of the custom op: takes a string as input and
  // outputs bool[].
  static TfLiteRegistration r = {/* init= */ nullptr, /* free= */ nullptr,
                                 /* prepare= */ PrepareStringToBool,
                                 /* invoke= */ InvokeStringToBool};
  return &r;
}
}  // namespace

std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver() {
  tflite::MutableOpResolver resolver;
  resolver.AddCustom("CUSTOM_OP_STRING_TO_BOOLS", RegisterStringToBool());
  return std::make_unique<tflite::MutableOpResolver>(resolver);
}
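
// Usage sketch (mirroring the TextClassifierWithStringToBool test above; the
// model path below is the test asset from that test, not a shipped model):
//
//   auto options = std::make_unique<TextClassifierOptions>();
//   options->base_options.model_asset_path =
//       "./mediapipe/tasks/testdata/text/"
//       "test_model_text_classifier_bool_output.tflite";
//   options->base_options.op_resolver = CreateCustomResolver();
//   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier,
//                           TextClassifier::Create(std::move(options)));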

}  // namespace text
}  // namespace tasks
}  // namespace mediapipe
@ -0,0 +1,35 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_

#include <memory>

#include "tensorflow/lite/mutable_op_resolver.h"

namespace mediapipe {
namespace tasks {
namespace text {

// Creates a custom MutableOpResolver that provides custom op implementations
// to mimic classification behavior.
std::unique_ptr<tflite::MutableOpResolver> CreateCustomResolver();

}  // namespace text
}  // namespace tasks
}  // namespace mediapipe

#endif  // MEDIAPIPE_TASKS_CC_TEXT_TEXT_CLASSIFIER_TEXT_CLASSIFIER_TEST_UTILS_H_
@ -56,6 +56,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",

@ -91,6 +92,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
        "//mediapipe/tasks/cc/core:model_task_graph",

@ -123,6 +125,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components:image_preprocessing",
        "//mediapipe/tasks/cc/components/containers:gesture_recognition_result",
@ -69,6 +69,7 @@ cc_library(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:ret_check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",

@ -86,6 +87,7 @@ cc_test(
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:matrix",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "@com_google_absl//absl/strings",
@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <limits>
#include <memory>
#include <string>

@ -26,6 +27,7 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/tasks/cc/vision/gesture_recognizer/calculators/landmarks_to_matrix_calculator.pb.h"

@ -38,6 +40,7 @@ namespace {
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX";
constexpr int kFeaturesPerLandmark = 3;

@ -62,6 +65,25 @@ absl::StatusOr<LandmarkListT> NormalizeLandmarkAspectRatio(
  return normalized_landmarks;
}

template <class LandmarkListT>
absl::StatusOr<LandmarkListT> RotateLandmarks(const LandmarkListT& landmarks,
                                              float rotation) {
  float cos = std::cos(rotation);
  // Negate because Y-axis points down and not up.
  float sin = std::sin(-rotation);
  LandmarkListT rotated_landmarks;
  for (int i = 0; i < landmarks.landmark_size(); ++i) {
    const auto& old_landmark = landmarks.landmark(i);
    float x = old_landmark.x() - 0.5;
    float y = old_landmark.y() - 0.5;
    auto* new_landmark = rotated_landmarks.add_landmark();
    new_landmark->set_x(x * cos - y * sin + 0.5);
    new_landmark->set_y(y * cos + x * sin + 0.5);
    new_landmark->set_z(old_landmark.z());
  }
  return rotated_landmarks;
}
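
// A worked example of the math above (a sketch, not part of the calculator):
// for rotation = M_PI / 2, cos = cos(rotation) = 0 and sin = sin(-rotation)
// = -1, so a landmark at (1.0, 0.5) maps to
//   x' = (1.0 - 0.5) * 0 - (0.5 - 0.5) * (-1) + 0.5 = 0.5
//   y' = (0.5 - 0.5) * 0 + (1.0 - 0.5) * (-1) + 0.5 = 0.0
// i.e. a point at the middle of the right edge moves to the middle of the top
// edge, matching the 90° anti-clockwise rotation described in the API docs
// for image coordinates whose Y-axis points down.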

template <class LandmarkListT>
absl::StatusOr<LandmarkListT> NormalizeObject(const LandmarkListT& landmarks,
                                              int origin_offset) {

@ -134,6 +156,13 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) {
        NormalizeLandmarkAspectRatio(landmarks, width, height));
  }

  if (cc->Inputs().HasTag(kNormRectTag)) {
    RET_CHECK(!cc->Inputs().Tag(kNormRectTag).IsEmpty());
    const auto rotation =
        cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>().rotation();
    ASSIGN_OR_RETURN(landmarks, RotateLandmarks(landmarks, rotation));
  }

  const auto& options = cc->Options<LandmarksToMatrixCalculatorOptions>();
  if (options.object_normalization()) {
    ASSIGN_OR_RETURN(

@ -163,6 +192,8 @@ absl::Status ProcessLandmarks(LandmarkListT landmarks, CalculatorContext* cc) {
//   WORLD_LANDMARKS - World 3d landmarks of one object. Use *either*
//                     LANDMARKS or WORLD_LANDMARKS.
//   IMAGE_SIZE - (width, height) of the image
//   NORM_RECT - Optional NormalizedRect object whose 'rotation' field is used
//               to rotate the landmarks.
// Output:
//   LANDMARKS_MATRIX - Matrix for the landmarks.
//

@ -185,6 +216,7 @@ class LandmarksToMatrixCalculator : public CalculatorBase {
    cc->Inputs().Tag(kLandmarksTag).Set<NormalizedLandmarkList>().Optional();
    cc->Inputs().Tag(kWorldLandmarksTag).Set<LandmarkList>().Optional();
    cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>().Optional();
    cc->Inputs().Tag(kNormRectTag).Set<NormalizedRect>().Optional();
    cc->Outputs().Tag(kLandmarksMatrixTag).Set<Matrix>();
    return absl::OkStatus();
  }
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <string>
#include <utility>

@ -23,6 +24,7 @@ limitations under the License.
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"

@ -35,6 +37,7 @@ constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX";
constexpr char kNormRectTag[] = "NORM_RECT";

template <class LandmarkListT>
LandmarkListT BuildPseudoLandmarks(int num_landmarks, int offset = 0) {

@ -54,6 +57,7 @@ struct Landmarks2dToMatrixCalculatorTestCase {
  int object_normalization_origin_offset = -1;
  float expected_cell_0_2;
  float expected_cell_1_5;
  float rotation;
};

using Landmarks2dToMatrixCalculatorTest =

@ -68,6 +72,7 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) {
        calculator: "LandmarksToMatrixCalculator"
        input_stream: "LANDMARKS:landmarks"
        input_stream: "IMAGE_SIZE:image_size"
        input_stream: "NORM_RECT:norm_rect"
        output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
        options {
          [mediapipe.LandmarksToMatrixCalculatorOptions.ext] {

@ -91,6 +96,11 @@ TEST_P(Landmarks2dToMatrixCalculatorTest, OutputsCorrectResult) {
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(Adopt(image_size.release()).At(Timestamp(0)));
  auto norm_rect = std::make_unique<NormalizedRect>();
  norm_rect->set_rotation(test_case.rotation);
  runner.MutableInputs()
      ->Tag(kNormRectTag)
      .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";

@ -109,12 +119,20 @@ INSTANTIATE_TEST_CASE_P(
         .base_offset = 0,
         .object_normalization_origin_offset = 0,
         .expected_cell_0_2 = 0.1f,
         .expected_cell_1_5 = 0.1875f},
         .expected_cell_1_5 = 0.1875f,
         .rotation = 0},
        {.test_name = "TestWithOffset21",
         .base_offset = 21,
         .object_normalization_origin_offset = 0,
         .expected_cell_0_2 = 0.1f,
         .expected_cell_1_5 = 0.1875f}}),
         .expected_cell_1_5 = 0.1875f,
         .rotation = 0},
        {.test_name = "TestWithRotation",
         .base_offset = 0,
         .object_normalization_origin_offset = 0,
         .expected_cell_0_2 = 0.075f,
         .expected_cell_1_5 = -0.25f,
         .rotation = M_PI / 2.0}}),
    [](const testing::TestParamInfo<
        Landmarks2dToMatrixCalculatorTest::ParamType>& info) {
      return info.param.test_name;

@ -126,6 +144,7 @@ struct LandmarksWorld3dToMatrixCalculatorTestCase {
  int object_normalization_origin_offset = -1;
  float expected_cell_0_2;
  float expected_cell_1_5;
  float rotation;
};

using LandmarksWorld3dToMatrixCalculatorTest =

@ -140,6 +159,7 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) {
        calculator: "LandmarksToMatrixCalculator"
        input_stream: "WORLD_LANDMARKS:landmarks"
        input_stream: "IMAGE_SIZE:image_size"
        input_stream: "NORM_RECT:norm_rect"
        output_stream: "LANDMARKS_MATRIX:landmarks_matrix"
        options {
          [mediapipe.LandmarksToMatrixCalculatorOptions.ext] {

@ -162,6 +182,11 @@ TEST_P(LandmarksWorld3dToMatrixCalculatorTest, OutputsCorrectResult) {
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(Adopt(image_size.release()).At(Timestamp(0)));
  auto norm_rect = std::make_unique<NormalizedRect>();
  norm_rect->set_rotation(test_case.rotation);
  runner.MutableInputs()
      ->Tag(kNormRectTag)
      .packets.push_back(Adopt(norm_rect.release()).At(Timestamp(0)));

  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";

@ -180,17 +205,26 @@ INSTANTIATE_TEST_CASE_P(
         .base_offset = 0,
         .object_normalization_origin_offset = 0,
         .expected_cell_0_2 = 0.1f,
         .expected_cell_1_5 = 0.25},
         .expected_cell_1_5 = 0.25,
         .rotation = 0},
        {.test_name = "TestWithOffset21",
         .base_offset = 21,
         .object_normalization_origin_offset = 0,
         .expected_cell_0_2 = 0.1f,
         .expected_cell_1_5 = 0.25},
         .expected_cell_1_5 = 0.25,
         .rotation = 0},
        {.test_name = "NoObjectNormalization",
         .base_offset = 0,
         .object_normalization_origin_offset = -1,
         .expected_cell_0_2 = 0.021f,
         .expected_cell_1_5 = 0.052f}}),
         .expected_cell_1_5 = 0.052f,
         .rotation = 0},
        {.test_name = "TestWithRotation",
         .base_offset = 0,
         .object_normalization_origin_offset = 0,
         .expected_cell_0_2 = 0.1f,
         .expected_cell_1_5 = -0.25f,
         .rotation = M_PI / 2.0}}),
    [](const testing::TestParamInfo<
        LandmarksWorld3dToMatrixCalculatorTest::ParamType>& info) {
      return info.param.test_name;
@ -17,6 +17,7 @@ limitations under the License.

#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"

@ -27,6 +28,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"

@ -62,6 +64,8 @@ constexpr char kHandGestureSubgraphTypeName[] =
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageInStreamName[] = "image_in";
constexpr char kImageOutStreamName[] = "image_out";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectStreamName[] = "norm_rect_in";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandGesturesStreamName[] = "hand_gestures";
constexpr char kHandednessTag[] = "HANDEDNESS";

@ -72,6 +76,31 @@ constexpr char kHandWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandWorldLandmarksStreamName[] = "world_landmarks";
constexpr int kMicroSecondsPerMilliSecond = 1000;

// Returns a NormalizedRect filling the whole image. If input is present, its
// rotation is set in the returned NormalizedRect and a check is performed to
// make sure no region-of-interest was provided. Otherwise, rotation is set to
// 0.
absl::StatusOr<NormalizedRect> FillNormalizedRect(
    std::optional<NormalizedRect> normalized_rect) {
  NormalizedRect result;
  if (normalized_rect.has_value()) {
    result = *normalized_rect;
  }
  bool has_coordinates = result.has_x_center() || result.has_y_center() ||
                         result.has_width() || result.has_height();
  if (has_coordinates) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GestureRecognizer does not support region-of-interest.",
        MediaPipeTasksStatus::kInvalidArgumentError);
  }
  result.set_x_center(0.5);
  result.set_y_center(0.5);
  result.set_width(1);
  result.set_height(1);
  return result;
}

// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.GestureRecognizerGraph". If the task is running
// in the live stream mode, a "FlowLimiterCalculator" will be added to limit the

@ -83,6 +112,7 @@ CalculatorGraphConfig CreateGraphConfig(
  auto& subgraph = graph.AddNode(kHandGestureSubgraphTypeName);
  subgraph.GetOptions<GestureRecognizerGraphOptionsProto>().Swap(options.get());
  graph.In(kImageTag).SetName(kImageInStreamName);
  graph.In(kNormRectTag).SetName(kNormRectStreamName);
  subgraph.Out(kHandGesturesTag).SetName(kHandGesturesStreamName) >>
      graph.Out(kHandGesturesTag);
  subgraph.Out(kHandednessTag).SetName(kHandednessStreamName) >>

@ -93,10 +123,11 @@ CalculatorGraphConfig CreateGraphConfig(
      graph.Out(kHandWorldLandmarksTag);
  subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag);
  if (enable_flow_limiting) {
    return tasks::core::AddFlowLimiterCalculator(graph, subgraph, {kImageTag},
                                                 kHandGesturesTag);
    return tasks::core::AddFlowLimiterCalculator(
        graph, subgraph, {kImageTag, kNormRectTag}, kHandGesturesTag);
  }
  graph.In(kImageTag) >> subgraph.In(kImageTag);
  graph.In(kNormRectTag) >> subgraph.In(kNormRectTag);
  return graph.GetConfig();
}

@ -216,16 +247,22 @@ absl::StatusOr<std::unique_ptr<GestureRecognizer>> GestureRecognizer::Create(
}

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::Recognize(
    mediapipe::Image image) {
    mediapipe::Image image,
    std::optional<mediapipe::NormalizedRect> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        "GPU input images are currently not supported.",
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(auto output_packets,
                   ProcessImageData({{kImageInStreamName,
                                      MakePacket<Image>(std::move(image))}}));
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   FillNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessImageData(
          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))}}));
  if (output_packets[kHandGesturesStreamName].IsEmpty()) {
    return {{{}, {}, {}, {}}};
  }

@ -245,18 +282,24 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
}

absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
    mediapipe::Image image, int64 timestamp_ms) {
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<mediapipe::NormalizedRect> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   FillNormalizedRect(image_processing_options));
  ASSIGN_OR_RETURN(
      auto output_packets,
      ProcessVideoData(
          {{kImageInStreamName,
            MakePacket<Image>(std::move(image))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
           {kNormRectStreamName,
            MakePacket<NormalizedRect>(std::move(norm_rect))
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
  if (output_packets[kHandGesturesStreamName].IsEmpty()) {
    return {{{}, {}, {}, {}}};

@ -276,17 +319,23 @@ absl::StatusOr<GestureRecognitionResult> GestureRecognizer::RecognizeForVideo(
  };
}

absl::Status GestureRecognizer::RecognizeAsync(mediapipe::Image image,
                                               int64 timestamp_ms) {
absl::Status GestureRecognizer::RecognizeAsync(
    mediapipe::Image image, int64 timestamp_ms,
    std::optional<mediapipe::NormalizedRect> image_processing_options) {
  if (image.UsesGpu()) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("GPU input images are currently not supported."),
        MediaPipeTasksStatus::kRunnerUnexpectedInputError);
  }
  ASSIGN_OR_RETURN(NormalizedRect norm_rect,
                   FillNormalizedRect(image_processing_options));
  return SendLiveStreamData(
      {{kImageInStreamName,
        MakePacket<Image>(std::move(image))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))},
       {kNormRectStreamName,
        MakePacket<NormalizedRect>(std::move(norm_rect))
            .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
@ -17,11 +17,13 @@ limitations under the License.
#define MEDIAPIPE_TASKS_CC_VISION_GESTURE_RECOGNIZRER_GESTURE_RECOGNIZER_H_

#include <memory>
#include <optional>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/gesture_recognition_result.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"

@ -57,7 +59,7 @@ struct GestureRecognizerOptions {
  int num_hands = 1;

  // The minimum confidence score for the hand detection to be considered
  // successfully.
  // successful.
  float min_hand_detection_confidence = 0.5;

  // The minimum confidence score of hand presence score in the hand landmark

@ -65,11 +67,11 @@ struct GestureRecognizerOptions {
  float min_hand_presence_confidence = 0.5;

  // The minimum confidence score for the hand tracking to be considered
  // successfully.
  // successful.
  float min_tracking_confidence = 0.5;

  // The minimum confidence score for the gestures to be considered
  // successfully. If < 0, the gesture confidence thresholds in the model
  // successful. If < 0, the gesture confidence thresholds in the model
  // metadata are used.
  // TODO Note this option is subject to change, after scoring
  // merging calculator is implemented.

@ -93,6 +95,13 @@ struct GestureRecognizerOptions {
// Inputs:
//   Image
//     - The image that gesture recognition runs on.
//   std::optional<NormalizedRect>
//     - If provided, can be used to specify the rotation to apply to the image
//       before performing gesture recognition, by setting its 'rotation' field
//       in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note
//       that specifying a region-of-interest using the 'x_center', 'y_center',
//       'width' and 'height' fields is NOT supported and will result in an
//       invalid argument error being returned.
// Outputs:
//   GestureRecognitionResult
//     - The hand gesture recognition results.

@ -122,12 +131,23 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  //
  // image - mediapipe::Image
  //   Image to perform hand gesture recognition on.
  // imageProcessingOptions - std::optional<NormalizedRect>
  //   If provided, can be used to specify the rotation to apply to the image
  //   before performing classification, by setting its 'rotation' field in
  //   radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). Note that
  //   specifying a region-of-interest using the 'x_center', 'y_center',
  //   'width' and 'height' fields is NOT supported and will result in an
  //   invalid argument error being returned.
  //
  // The image can be of any size with format RGB or RGBA.
  // TODO: Describes how the input image will be preprocessed
  // after the yuv support is implemented.
  // TODO: use an ImageProcessingOptions struct instead of
  // NormalizedRect.
  absl::StatusOr<components::containers::GestureRecognitionResult> Recognize(
      Image image);
      Image image,
      std::optional<mediapipe::NormalizedRect> image_processing_options =
          std::nullopt);

  // Performs gesture recognition on the provided video frame.
  // Only use this method when the GestureRecognizer is created with the video

@ -137,7 +157,9 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  // provide the video frame's timestamp (in milliseconds). The input timestamps
  // must be monotonically increasing.
  absl::StatusOr<components::containers::GestureRecognitionResult>
  RecognizeForVideo(Image image, int64 timestamp_ms);
  RecognizeForVideo(Image image, int64 timestamp_ms,
                    std::optional<mediapipe::NormalizedRect>
                        image_processing_options = std::nullopt);

  // Sends live image data to perform gesture recognition, and the results will
  // be available via the "result_callback" provided in the

@ -157,7 +179,9 @@ class GestureRecognizer : tasks::vision::core::BaseVisionTaskApi {
  //   longer be valid when the callback returns. To access the image data
  //   outside of the callback, callers need to make a copy of the image.
  // - The input timestamp in milliseconds.
  absl::Status RecognizeAsync(Image image, int64 timestamp_ms);
  absl::Status RecognizeAsync(Image image, int64 timestamp_ms,
                              std::optional<mediapipe::NormalizedRect>
                                  image_processing_options = std::nullopt);

  // Shuts down the GestureRecognizer when all the work is done.
  absl::Status Close() { return runner_->Close(); }
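
  // A minimal usage sketch (illustrative only; the options fields follow the
  // pattern of the other MediaPipe tasks, and "gesture_recognizer.task" is a
  // placeholder model path):
  //
  //   auto options = std::make_unique<GestureRecognizerOptions>();
  //   options->base_options.model_asset_path = "gesture_recognizer.task";
  //   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GestureRecognizer> recognizer,
  //                           GestureRecognizer::Create(std::move(options)));
  //   NormalizedRect image_processing_options;
  //   image_processing_options.set_rotation(M_PI / 2);  // 90° anti-clockwise.
  //   MP_ASSERT_OK_AND_ASSIGN(
  //       auto result, recognizer->Recognize(image, image_processing_options));
  //   MP_ASSERT_OK(recognizer->Close());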
@ -24,6 +24,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"

@ -53,6 +54,7 @@ using ::mediapipe::tasks::vision::hand_landmarker::proto::
    HandLandmarkerGraphOptions;

constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandednessTag[] = "HANDEDNESS";

@ -76,6 +78,9 @@ struct GestureRecognizerOutputs {
// Inputs:
//   IMAGE - Image
//     Image to perform hand gesture recognition on.
//   NORM_RECT - NormalizedRect
//     Describes image rotation and region of image to perform landmarks
//     detection on.
//
// Outputs:
//   HAND_GESTURES - std::vector<ClassificationList>

@ -93,13 +98,15 @@ struct GestureRecognizerOutputs {
//   IMAGE - mediapipe::Image
//     The image that gesture recognizer runs on and has the pixel data stored
//     on the target storage (CPU vs GPU).
//
// All returned coordinates are in the unrotated and uncropped input image
// coordinates system.
//
// Example:
// node {
//   calculator:
//     "mediapipe.tasks.vision.gesture_recognizer.GestureRecognizerGraph"
//   input_stream: "IMAGE:image_in"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   output_stream: "LANDMARKS:hand_landmarks"
//   output_stream: "WORLD_LANDMARKS:world_hand_landmarks"

@ -132,7 +139,8 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
    ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
                     BuildGestureRecognizerGraph(
                         *sc->MutableOptions<GestureRecognizerGraphOptions>(),
                         graph[Input<Image>(kImageTag)], graph));
                         graph[Input<Image>(kImageTag)],
                         graph[Input<NormalizedRect>(kNormRectTag)], graph));
    hand_gesture_recognition_output.gesture >>
        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
    hand_gesture_recognition_output.handedness >>

@ -148,7 +156,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
 private:
  absl::StatusOr<GestureRecognizerOutputs> BuildGestureRecognizerGraph(
      GestureRecognizerGraphOptions& graph_options, Source<Image> image_in,
      Graph& graph) {
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
    auto& image_property = graph.AddNode("ImagePropertiesCalculator");
    image_in >> image_property.In("IMAGE");
    auto image_size = image_property.Out("SIZE");

@ -162,6 +170,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
        graph_options.mutable_hand_landmarker_graph_options());

    image_in >> hand_landmarker_graph.In(kImageTag);
    norm_rect_in >> hand_landmarker_graph.In(kNormRectTag);
    auto hand_landmarks =
        hand_landmarker_graph[Output<std::vector<NormalizedLandmarkList>>(
            kLandmarksTag)];

@ -187,6 +196,7 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
    hand_world_landmarks >> hand_gesture_subgraph.In(kWorldLandmarksTag);
    handedness >> hand_gesture_subgraph.In(kHandednessTag);
    image_size >> hand_gesture_subgraph.In(kImageSizeTag);
    norm_rect_in >> hand_gesture_subgraph.In(kNormRectTag);
    hand_landmarks_id >> hand_gesture_subgraph.In(kHandTrackingIdsTag);
    auto hand_gestures =
        hand_gesture_subgraph[Output<std::vector<ClassificationList>>(
@ -25,6 +25,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"

@ -57,6 +58,7 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kLandmarksMatrixTag[] = "LANDMARKS_MATRIX";

@ -92,6 +94,9 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
//     Detected hand landmarks in world coordinates.
//   IMAGE_SIZE - std::pair<int, int>
//     The size of the image from which the landmarks were detected.
//   NORM_RECT - NormalizedRect
//     NormalizedRect whose 'rotation' field is used to rotate the
//     landmarks before processing them.
//
// Outputs:
//   HAND_GESTURES - ClassificationList

@ -106,6 +111,7 @@ Source<std::vector<Tensor>> ConvertMatrixToTensor(Source<Matrix> matrix,
//   input_stream: "LANDMARKS:landmarks"
//   input_stream: "WORLD_LANDMARKS:world_landmarks"
//   input_stream: "IMAGE_SIZE:image_size"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   options {
//     [mediapipe.tasks.vision.gesture_recognizer.proto.HandGestureRecognizerGraphOptions.ext]

@ -133,7 +139,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
            graph[Input<ClassificationList>(kHandednessTag)],
            graph[Input<NormalizedLandmarkList>(kLandmarksTag)],
            graph[Input<LandmarkList>(kWorldLandmarksTag)],
            graph[Input<std::pair<int, int>>(kImageSizeTag)], graph));
            graph[Input<std::pair<int, int>>(kImageSizeTag)],
            graph[Input<NormalizedRect>(kNormRectTag)], graph));
    hand_gestures >> graph[Output<ClassificationList>(kHandGesturesTag)];
    return graph.GetConfig();
  }

@ -145,7 +152,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
      Source<ClassificationList> handedness,
      Source<NormalizedLandmarkList> hand_landmarks,
      Source<LandmarkList> hand_world_landmarks,
      Source<std::pair<int, int>> image_size, Graph& graph) {
      Source<std::pair<int, int>> image_size, Source<NormalizedRect> norm_rect,
      Graph& graph) {
    // Converts the ClassificationList to a matrix.
    auto& handedness_to_matrix = graph.AddNode("HandednessToMatrixCalculator");
    handedness >> handedness_to_matrix.In(kHandednessTag);

@ -166,6 +174,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
        landmarks_options;
    hand_landmarks >> hand_landmarks_to_matrix.In(kLandmarksTag);
    image_size >> hand_landmarks_to_matrix.In(kImageSizeTag);
    norm_rect >> hand_landmarks_to_matrix.In(kNormRectTag);
    auto hand_landmarks_matrix =
        hand_landmarks_to_matrix[Output<Matrix>(kLandmarksMatrixTag)];

@ -181,6 +190,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    hand_world_landmarks >>
        hand_world_landmarks_to_matrix.In(kWorldLandmarksTag);
    image_size >> hand_world_landmarks_to_matrix.In(kImageSizeTag);
    norm_rect >> hand_world_landmarks_to_matrix.In(kNormRectTag);
    auto hand_world_landmarks_matrix =
        hand_world_landmarks_to_matrix[Output<Matrix>(kLandmarksMatrixTag)];

@ -239,6 +249,9 @@ REGISTER_MEDIAPIPE_GRAPH(
//     A vector of hand landmarks in world coordinates.
//   IMAGE_SIZE - std::pair<int, int>
//     The size of the image from which the landmarks were detected.
//   NORM_RECT - NormalizedRect
//     NormalizedRect whose 'rotation' field is used to rotate the
//     landmarks before processing them.
//   HAND_TRACKING_IDS - std::vector<int>
//     A vector of the tracking ids of the hands. The tracking id is the vector
//     index corresponding to the same hand if the graph runs multiple times.

@ -257,6 +270,7 @@ REGISTER_MEDIAPIPE_GRAPH(
//   input_stream: "LANDMARKS:landmarks"
//   input_stream: "WORLD_LANDMARKS:world_landmarks"
//   input_stream: "IMAGE_SIZE:image_size"
//   input_stream: "NORM_RECT:norm_rect"
//   input_stream: "HAND_TRACKING_IDS:hand_tracking_ids"
//   output_stream: "HAND_GESTURES:hand_gestures"
//   options {

@ -283,6 +297,7 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
            graph[Input<std::vector<NormalizedLandmarkList>>(kLandmarksTag)],
            graph[Input<std::vector<LandmarkList>>(kWorldLandmarksTag)],
            graph[Input<std::pair<int, int>>(kImageSizeTag)],
            graph[Input<NormalizedRect>(kNormRectTag)],
            graph[Input<std::vector<int>>(kHandTrackingIdsTag)], graph));
    multi_hand_gestures >>
        graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];

@ -296,18 +311,20 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
      Source<std::vector<ClassificationList>> multi_handedness,
      Source<std::vector<NormalizedLandmarkList>> multi_hand_landmarks,
      Source<std::vector<LandmarkList>> multi_hand_world_landmarks,
      Source<std::pair<int, int>> image_size,
      Source<std::pair<int, int>> image_size, Source<NormalizedRect> norm_rect,
      Source<std::vector<int>> multi_hand_tracking_ids, Graph& graph) {
    auto& begin_loop_int = graph.AddNode("BeginLoopIntCalculator");
    image_size >> begin_loop_int.In(kCloneTag)[0];
    multi_handedness >> begin_loop_int.In(kCloneTag)[1];
    multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[2];
    multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[3];
    norm_rect >> begin_loop_int.In(kCloneTag)[1];
    multi_handedness >> begin_loop_int.In(kCloneTag)[2];
    multi_hand_landmarks >> begin_loop_int.In(kCloneTag)[3];
    multi_hand_world_landmarks >> begin_loop_int.In(kCloneTag)[4];
    multi_hand_tracking_ids >> begin_loop_int.In(kIterableTag);
    auto image_size_clone = begin_loop_int.Out(kCloneTag)[0];
    auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[1];
    auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[2];
    auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[3];
    auto norm_rect_clone = begin_loop_int.Out(kCloneTag)[1];
    auto multi_handedness_clone = begin_loop_int.Out(kCloneTag)[2];
    auto multi_hand_landmarks_clone = begin_loop_int.Out(kCloneTag)[3];
    auto multi_hand_world_landmarks_clone = begin_loop_int.Out(kCloneTag)[4];
    auto hand_tracking_id = begin_loop_int.Out(kItemTag);
    auto batch_end = begin_loop_int.Out(kBatchEndTag);

@ -341,6 +358,7 @@ class MultipleHandGestureRecognizerGraph : public core::ModelTaskGraph {
    hand_world_landmarks >>
        hand_gesture_recognizer_graph.In(kWorldLandmarksTag);
    image_size_clone >> hand_gesture_recognizer_graph.In(kImageSizeTag);
    norm_rect_clone >> hand_gesture_recognizer_graph.In(kNormRectTag);
    auto hand_gestures = hand_gesture_recognizer_graph.Out(kHandGesturesTag);

    auto& end_loop_classification_lists =
@ -32,7 +32,7 @@ cc_library(
        "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator",
        "//mediapipe/calculators/util:detection_label_id_to_text_calculator_cc_proto",
        "//mediapipe/calculators/util:detection_letterbox_removal_calculator",
        "//mediapipe/calculators/util:detection_projection_calculator",
        "//mediapipe/calculators/util:detections_to_rects_calculator",
        "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto",
        "//mediapipe/calculators/util:non_max_suppression_calculator",
@ -58,6 +58,7 @@ using ::mediapipe::tasks::vision::hand_detector::proto::
    HandDetectorGraphOptions;

constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kHandRectsTag[] = "HAND_RECTS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";

@ -148,6 +149,9 @@ void ConfigureRectTransformationCalculator(
// Inputs:
//   IMAGE - Image
//     Image to perform detection on.
//   NORM_RECT - NormalizedRect
//     Describes image rotation and region of image to perform detection
//     on.
//
// Outputs:
//   PALM_DETECTIONS - std::vector<Detection>

@ -159,11 +163,14 @@ void ConfigureRectTransformationCalculator(
//   IMAGE - Image
//     The input image that the hand detector runs on and has the pixel data
//     stored on the target storage (CPU vs GPU).
// All returned coordinates are in the unrotated and uncropped input image
// coordinates system.
//
// Example:
// node {
//   calculator: "mediapipe.tasks.vision.hand_detector.HandDetectorGraph"
//   input_stream: "IMAGE:image"
//   input_stream: "NORM_RECT:norm_rect"
//   output_stream: "PALM_DETECTIONS:palm_detections"
//   output_stream: "HAND_RECTS:hand_rects_from_palm_detections"
//   output_stream: "PALM_RECTS:palm_rects"

@ -189,11 +196,11 @@ class HandDetectorGraph : public core::ModelTaskGraph {
    ASSIGN_OR_RETURN(const auto* model_resources,
                     CreateModelResources<HandDetectorGraphOptions>(sc));
    Graph graph;
    ASSIGN_OR_RETURN(
        auto hand_detection_outs,
        BuildHandDetectionSubgraph(sc->Options<HandDetectorGraphOptions>(),
                                   *model_resources,
                                   graph[Input<Image>(kImageTag)], graph));
    ASSIGN_OR_RETURN(auto hand_detection_outs,
                     BuildHandDetectionSubgraph(
                         sc->Options<HandDetectorGraphOptions>(),
                         *model_resources, graph[Input<Image>(kImageTag)],
                         graph[Input<NormalizedRect>(kNormRectTag)], graph));
    hand_detection_outs.palm_detections >>
        graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
    hand_detection_outs.hand_rects >>

@ -216,7 +223,7 @@ class HandDetectorGraph : public core::ModelTaskGraph {
  absl::StatusOr<HandDetectionOuts> BuildHandDetectionSubgraph(
      const HandDetectorGraphOptions& subgraph_options,
      const core::ModelResources& model_resources, Source<Image> image_in,
      Graph& graph) {
      Source<NormalizedRect> norm_rect_in, Graph& graph) {
    // Add image preprocessing subgraph. The model expects aspect ratio
    // unchanged.
    auto& preprocessing =

@ -233,8 +240,9 @@ class HandDetectorGraph : public core::ModelTaskGraph {
            &preprocessing
                 .GetOptions<tasks::components::ImagePreprocessingOptions>()));
    image_in >> preprocessing.In("IMAGE");
    norm_rect_in >> preprocessing.In("NORM_RECT");
    auto preprocessed_tensors = preprocessing.Out("TENSORS");
    auto letterbox_padding = preprocessing.Out("LETTERBOX_PADDING");
    auto matrix = preprocessing.Out("MATRIX");
    auto image_size = preprocessing.Out("IMAGE_SIZE");

    // Adds SSD palm detection model.

@ -278,17 +286,12 @@ class HandDetectorGraph : public core::ModelTaskGraph {
    nms_detections >> detection_label_id_to_text.In("");
    auto detections_with_text = detection_label_id_to_text.Out("");

    // Adjusts detection locations (already normalized to [0.f, 1.f]) on the
    // letterboxed image (after image transformation with the FIT scale mode) to
    // the corresponding locations on the same image with the letterbox removed
    // (the input image to the graph before image transformation).
    auto& detection_letterbox_removal =
        graph.AddNode("DetectionLetterboxRemovalCalculator");
    detections_with_text >> detection_letterbox_removal.In("DETECTIONS");
    letterbox_padding >> detection_letterbox_removal.In("LETTERBOX_PADDING");
    // Projects detections back into the input image coordinates system.
    auto& detection_projection = graph.AddNode("DetectionProjectionCalculator");
    detections_with_text >> detection_projection.In("DETECTIONS");
    matrix >> detection_projection.In("PROJECTION_MATRIX");
    auto palm_detections =
        detection_letterbox_removal[Output<std::vector<Detection>>(
            "DETECTIONS")];
        detection_projection[Output<std::vector<Detection>>("DETECTIONS")];

    // Converts each palm detection into a rectangle (normalized by image size)
    // that encloses the palm and is rotated such that the line connecting
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <iostream>
#include <memory>
#include <string>

@@ -75,13 +76,18 @@ using ::testing::proto::Partially;
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
constexpr char kTestRightHandsImage[] = "right_hands.jpg";
constexpr char kTestRightHandsRotatedImage[] = "right_hands_rotated.jpg";
constexpr char kTestModelResourcesTag[] = "test_model_resources";

constexpr char kOneHandResultFile[] = "hand_detector_result_one_hand.pbtxt";
constexpr char kOneHandRotatedResultFile[] =
    "hand_detector_result_one_hand_rotated.pbtxt";
constexpr char kTwoHandsResultFile[] = "hand_detector_result_two_hands.pbtxt";

constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmDetectionsName[] = "palm_detections";
constexpr char kHandRectsTag[] = "HAND_RECTS";

@@ -117,6 +123,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
      hand_detection.In(kImageTag);
  graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
      hand_detection.In(kNormRectTag);

  hand_detection.Out(kPalmDetectionsTag).SetName(kPalmDetectionsName) >>
      graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];

@@ -142,6 +150,9 @@ struct TestParams {
  std::string hand_detection_model_name;
  // The filename of test image.
  std::string test_image_name;
  // The rotation to apply to the test image before processing, in radians
  // counter-clockwise.
  float rotation;
  // The number of maximum detected hands.
  int num_hands;
  // The expected hand detector result.

@@ -154,14 +165,22 @@ TEST_P(HandDetectionTest, DetectTwoHands) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                GetParam().test_image_name)));
  NormalizedRect input_norm_rect;
  input_norm_rect.set_rotation(GetParam().rotation);
  input_norm_rect.set_x_center(0.5);
  input_norm_rect.set_y_center(0.5);
  input_norm_rect.set_width(1.0);
  input_norm_rect.set_height(1.0);
  MP_ASSERT_OK_AND_ASSIGN(
      auto model_resources,
      CreateModelResourcesForModel(GetParam().hand_detection_model_name));
  MP_ASSERT_OK_AND_ASSIGN(
      auto task_runner, CreateTaskRunner(*model_resources, kPalmDetectionModel,
                                         GetParam().num_hands));
  auto output_packets =
      task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
  auto output_packets = task_runner->Process(
      {{kImageName, MakePacket<Image>(std::move(image))},
       {kNormRectName,
        MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
  MP_ASSERT_OK(output_packets);
  const std::vector<Detection>& palm_detections =
      (*output_packets)[kPalmDetectionsName].Get<std::vector<Detection>>();

@@ -188,15 +207,24 @@ INSTANTIATE_TEST_SUITE_P(
    Values(TestParams{.test_name = "DetectOneHand",
                      .hand_detection_model_name = kPalmDetectionModel,
                      .test_image_name = kTestRightHandsImage,
                      .rotation = 0,
                      .num_hands = 1,
                      .expected_result =
                          GetExpectedHandDetectorResult(kOneHandResultFile)},
           TestParams{.test_name = "DetectTwoHands",
                      .hand_detection_model_name = kPalmDetectionModel,
                      .test_image_name = kTestRightHandsImage,
                      .rotation = 0,
                      .num_hands = 2,
                      .expected_result =
                          GetExpectedHandDetectorResult(kTwoHandsResultFile)}),
                          GetExpectedHandDetectorResult(kTwoHandsResultFile)},
           TestParams{.test_name = "DetectOneHandWithRotation",
                      .hand_detection_model_name = kPalmDetectionModel,
                      .test_image_name = kTestRightHandsRotatedImage,
                      .rotation = M_PI / 2.0f,
                      .num_hands = 1,
                      .expected_result = GetExpectedHandDetectorResult(
                          kOneHandRotatedResultFile)}),
    [](const TestParamInfo<HandDetectionTest::ParamType>& info) {
      return info.param.test_name;
    });
@@ -91,10 +91,14 @@ cc_library(
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:status",
        "//mediapipe/tasks/cc:common",
        "//mediapipe/tasks/cc/components/utils:gate",
        "//mediapipe/tasks/cc/core:model_asset_bundle_resources",
        "//mediapipe/tasks/cc/core:model_resources_cache",
        "//mediapipe/tasks/cc/core:model_task_graph",
        "//mediapipe/tasks/cc/core:utils",
        "//mediapipe/tasks/cc/metadata/utils:zip_utils",
        "//mediapipe/tasks/cc/vision/hand_detector:hand_detector_graph",
        "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/hand_landmarker/calculators:hand_association_calculator",
@@ -57,7 +57,7 @@ cc_library(
        "//mediapipe/framework/formats:classification_cc_proto",
        "//mediapipe/framework/formats:landmark_cc_proto",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/tasks/cc/components/containers:landmarks_detection",
        "//mediapipe/tasks/cc/components/containers:rect",
        "//mediapipe/tasks/cc/vision/utils:landmarks_duplicates_finder",
        "//mediapipe/tasks/cc/vision/utils:landmarks_utils",
        "@com_google_absl//absl/algorithm:container",
@@ -34,7 +34,7 @@ limitations under the License.
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_duplicates_finder.h"
#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

@@ -44,7 +44,7 @@ namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::Bound;
using ::mediapipe::tasks::components::containers::Rect;
using ::mediapipe::tasks::vision::utils::CalculateIOU;
using ::mediapipe::tasks::vision::utils::DuplicatesFinder;

@@ -126,7 +126,7 @@ absl::StatusOr<float> HandBaselineDistance(
  return distance;
}

Bound CalculateBound(const NormalizedLandmarkList& list) {
Rect CalculateBound(const NormalizedLandmarkList& list) {
  constexpr float kMinInitialValue = std::numeric_limits<float>::max();
  constexpr float kMaxInitialValue = std::numeric_limits<float>::lowest();

@@ -172,7 +172,7 @@ class HandDuplicatesFinder : public DuplicatesFinder {
    const int num = multi_landmarks.size();
    std::vector<float> baseline_distances;
    baseline_distances.reserve(num);
    std::vector<Bound> bounds;
    std::vector<Rect> bounds;
    bounds.reserve(num);
    for (const NormalizedLandmarkList& list : multi_landmarks) {
      ASSIGN_OR_RETURN(const float baseline_distance,
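The CalculateBound rename above only swaps the container type from Bound to Rect; the min/max fold it performs is cut off by the hunk. A minimal sketch of that kind of bound computation, assuming Rect's left/top/right/bottom fields from components/containers/rect.h; the helper name BoundingRectOf is illustrative, not the file's actual code.

#include <algorithm>
#include <limits>

#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"

// Illustrative only: fold a NormalizedLandmarkList into one normalized
// bounding Rect, starting from the sentinel initial values shown above.
mediapipe::tasks::components::containers::Rect BoundingRectOf(
    const mediapipe::NormalizedLandmarkList& list) {
  mediapipe::tasks::components::containers::Rect rect;
  rect.left = std::numeric_limits<float>::max();
  rect.top = std::numeric_limits<float>::max();
  rect.right = std::numeric_limits<float>::lowest();
  rect.bottom = std::numeric_limits<float>::lowest();
  for (const auto& landmark : list.landmark()) {
    rect.left = std::min(rect.left, landmark.x());
    rect.top = std::min(rect.top, landmark.y());
    rect.right = std::max(rect.right, landmark.x());
    rect.bottom = std::max(rect.bottom, landmark.y());
  }
  return rect;
}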
@@ -29,10 +29,14 @@ limitations under the License.
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/tasks/cc/core/model_asset_bundle_resources.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_graph_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarker_graph_options.pb.h"

@@ -50,6 +54,8 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::DisallowIf;
using ::mediapipe::tasks::core::ModelAssetBundleResources;
using ::mediapipe::tasks::metadata::SetExternalFile;
using ::mediapipe::tasks::vision::hand_detector::proto::
    HandDetectorGraphOptions;
using ::mediapipe::tasks::vision::hand_landmarker::proto::

@@ -58,6 +64,7 @@ using ::mediapipe::tasks::vision::hand_landmarker::proto::
    HandLandmarksDetectorGraphOptions;

constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";
constexpr char kHandRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";

@@ -65,6 +72,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPreviousLoopbackCalculatorName[] = "PreviousLoopbackCalculator";
constexpr char kHandDetectorTFLiteName[] = "hand_detector.tflite";
constexpr char kHandLandmarksDetectorTFLiteName[] =
    "hand_landmarks_detector.tflite";

struct HandLandmarkerOutputs {
  Source<std::vector<NormalizedLandmarkList>> landmark_lists;

@@ -76,6 +86,27 @@ struct HandLandmarkerOutputs {
  Source<Image> image;
};

// Sets the base options in the sub tasks.
absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
                                   HandLandmarkerGraphOptions* options,
                                   bool is_copy) {
  ASSIGN_OR_RETURN(const auto hand_detector_file,
                   resources.GetModelFile(kHandDetectorTFLiteName));
  SetExternalFile(hand_detector_file,
                  options->mutable_hand_detector_graph_options()
                      ->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file,
                   resources.GetModelFile(kHandLandmarksDetectorTFLiteName));
  SetExternalFile(hand_landmarks_detector_file,
                  options->mutable_hand_landmarks_detector_graph_options()
                      ->mutable_base_options()
                      ->mutable_model_asset(),
                  is_copy);
  return absl::OkStatus();
}

}  // namespace
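SetSubTaskBaseOptions above is what lets a single .task asset bundle drive both subgraphs. A short caller-side sketch mirroring the test later in this commit (the file name and hand count are illustrative values, not defaults):

// Sketch only: point the top-level base options at the bundle;
// SetSubTaskBaseOptions then extracts hand_detector.tflite and
// hand_landmarks_detector.tflite from it for the two subgraphs.
HandLandmarkerGraphOptions options;
options.mutable_base_options()->mutable_model_asset()->set_file_name(
    "hand_landmark.task");
options.mutable_hand_detector_graph_options()->set_num_hands(2);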
// A "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph" performs hand
|
||||
|
@ -92,6 +123,9 @@ struct HandLandmarkerOutputs {
|
|||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform hand landmarks detection on.
|
||||
// NORM_RECT - NormalizedRect
|
||||
// Describes image rotation and region of image to perform landmarks
|
||||
// detection on.
|
||||
//
|
||||
// Outputs:
|
||||
// LANDMARKS: - std::vector<NormalizedLandmarkList>
|
||||
|
@ -110,11 +144,14 @@ struct HandLandmarkerOutputs {
|
|||
// IMAGE - Image
|
||||
// The input image that the hand landmarker runs on and has the pixel data
|
||||
// stored on the target storage (CPU vs GPU).
|
||||
// All returned coordinates are in the unrotated and uncropped input image
|
||||
// coordinates system.
|
||||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph"
|
||||
// input_stream: "IMAGE:image_in"
|
||||
// input_stream: "NORM_RECT:norm_rect"
|
||||
// output_stream: "LANDMARKS:hand_landmarks"
|
||||
// output_stream: "WORLD_LANDMARKS:world_hand_landmarks"
|
||||
// output_stream: "HAND_RECT_NEXT_FRAME:hand_rect_next_frame"
|
||||
|
@ -154,10 +191,25 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
|||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_landmarker_outputs,
|
||||
BuildHandLandmarkerGraph(sc->Options<HandLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)], graph));
|
||||
if (sc->Options<HandLandmarkerGraphOptions>()
|
||||
.base_options()
|
||||
.has_model_asset()) {
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_asset_bundle_resources,
|
||||
CreateModelAssetBundleResources<HandLandmarkerGraphOptions>(sc));
|
||||
// Copies the file content instead of passing the pointer of file in
|
||||
// memory if the subgraph model resource service is not available.
|
||||
MP_RETURN_IF_ERROR(SetSubTaskBaseOptions(
|
||||
*model_asset_bundle_resources,
|
||||
sc->MutableOptions<HandLandmarkerGraphOptions>(),
|
||||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||
.IsAvailable()));
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto hand_landmarker_outputs,
|
||||
BuildHandLandmarkerGraph(
|
||||
sc->Options<HandLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kNormRectTag)], graph));
|
||||
hand_landmarker_outputs.landmark_lists >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
||||
hand_landmarker_outputs.world_landmark_lists >>
|
||||
|
@ -196,7 +248,7 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
|||
// graph: the mediapipe graph instance to be updated.
|
||||
absl::StatusOr<HandLandmarkerOutputs> BuildHandLandmarkerGraph(
|
||||
const HandLandmarkerGraphOptions& tasks_options, Source<Image> image_in,
|
||||
Graph& graph) {
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
const int max_num_hands =
|
||||
tasks_options.hand_detector_graph_options().num_hands();
|
||||
|
||||
|
@ -214,12 +266,15 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
|||
|
||||
auto image_for_hand_detector =
|
||||
DisallowIf(image_in, has_enough_hands, graph);
|
||||
auto norm_rect_in_for_hand_detector =
|
||||
DisallowIf(norm_rect_in, has_enough_hands, graph);
|
||||
|
||||
auto& hand_detector =
|
||||
graph.AddNode("mediapipe.tasks.vision.hand_detector.HandDetectorGraph");
|
||||
hand_detector.GetOptions<HandDetectorGraphOptions>().CopyFrom(
|
||||
tasks_options.hand_detector_graph_options());
|
||||
image_for_hand_detector >> hand_detector.In("IMAGE");
|
||||
norm_rect_in_for_hand_detector >> hand_detector.In("NORM_RECT");
|
||||
auto hand_rects_from_hand_detector = hand_detector.Out("HAND_RECTS");
|
||||
|
||||
auto& hand_association = graph.AddNode("HandAssociationCalculator");
|
||||
|
|
|
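DisallowIf (from components/utils/gate.h) drops packets while its condition holds, so the hand detector above only runs when tracking has not already found enough hands. Note that the image and its companion rect are gated on the same condition, keeping the detector's two input streams in sync. A hedged sketch of that pairing as a reusable helper; GateTogether is hypothetical, not a MediaPipe API:

#include <utility>

#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/tasks/cc/components/utils/gate.h"

namespace {
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::utils::DisallowIf;

// Hypothetical helper: gate two companion streams on one condition so
// downstream nodes see them open and close in lockstep.
template <typename A, typename B>
std::pair<Source<A>, Source<B>> GateTogether(Source<A> a, Source<B> b,
                                             Source<bool> condition,
                                             Graph& graph) {
  return {DisallowIf(a, condition, graph), DisallowIf(b, condition, graph)};
}
}  // namespace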
@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"

@@ -65,12 +67,14 @@ using ::testing::proto::Approximately;
using ::testing::proto::Partially;

constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kPalmDetectionModel[] = "palm_detection_full.tflite";
constexpr char kHandLandmarkerFullModel[] = "hand_landmark_full.tflite";
constexpr char kHandLandmarkerModelBundle[] = "hand_landmark.task";
constexpr char kLeftHandsImage[] = "left_hands.jpg";
constexpr char kLeftHandsRotatedImage[] = "left_hands_rotated.jpg";

constexpr char kImageTag[] = "IMAGE";
constexpr char kImageName[] = "image_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kLandmarksTag[] = "LANDMARKS";
constexpr char kLandmarksName[] = "landmarks";
constexpr char kWorldLandmarksTag[] = "WORLD_LANDMARKS";

@@ -85,6 +89,11 @@ constexpr char kExpectedLeftUpHandLandmarksFilename[] =
    "expected_left_up_hand_landmarks.prototxt";
constexpr char kExpectedLeftDownHandLandmarksFilename[] =
    "expected_left_down_hand_landmarks.prototxt";
// Same but for the rotated image.
constexpr char kExpectedLeftUpHandRotatedLandmarksFilename[] =
    "expected_left_up_hand_rotated_landmarks.prototxt";
constexpr char kExpectedLeftDownHandRotatedLandmarksFilename[] =
    "expected_left_down_hand_rotated_landmarks.prototxt";

constexpr float kFullModelFractionDiff = 0.03;  // percentage
constexpr float kAbsMargin = 0.03;

@@ -105,21 +114,15 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() {
      "mediapipe.tasks.vision.hand_landmarker.HandLandmarkerGraph");
  auto& options =
      hand_landmarker_graph.GetOptions<HandLandmarkerGraphOptions>();
  options.mutable_hand_detector_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(JoinPath("./", kTestDataDirectory, kPalmDetectionModel));
  options.mutable_hand_detector_graph_options()->mutable_base_options();
  options.mutable_base_options()->mutable_model_asset()->set_file_name(
      JoinPath("./", kTestDataDirectory, kHandLandmarkerModelBundle));
  options.mutable_hand_detector_graph_options()->set_num_hands(kMaxNumHands);
  options.mutable_hand_landmarks_detector_graph_options()
      ->mutable_base_options()
      ->mutable_model_asset()
      ->set_file_name(
          JoinPath("./", kTestDataDirectory, kHandLandmarkerFullModel));
  options.set_min_tracking_confidence(kMinTrackingConfidence);

  graph[Input<Image>(kImageTag)].SetName(kImageName) >>
      hand_landmarker_graph.In(kImageTag);
  graph[Input<NormalizedRect>(kNormRectTag)].SetName(kNormRectName) >>
      hand_landmarker_graph.In(kNormRectTag);
  hand_landmarker_graph.Out(kLandmarksTag).SetName(kLandmarksName) >>
      graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
  hand_landmarker_graph.Out(kWorldLandmarksTag).SetName(kWorldLandmarksName) >>

@@ -139,9 +142,16 @@ TEST_F(HandLandmarkerTest, Succeeds) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image,
      DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kLeftHandsImage)));
  NormalizedRect input_norm_rect;
  input_norm_rect.set_x_center(0.5);
  input_norm_rect.set_y_center(0.5);
  input_norm_rect.set_width(1.0);
  input_norm_rect.set_height(1.0);
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
  auto output_packets =
      task_runner->Process({{kImageName, MakePacket<Image>(std::move(image))}});
  auto output_packets = task_runner->Process(
      {{kImageName, MakePacket<Image>(std::move(image))},
       {kNormRectName,
        MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
  const auto& landmarks = (*output_packets)[kLandmarksName]
                              .Get<std::vector<NormalizedLandmarkList>>();
  ASSERT_EQ(landmarks.size(), kMaxNumHands);

@@ -159,6 +169,38 @@ TEST_F(HandLandmarkerTest, Succeeds) {
                  /*fraction=*/kFullModelFractionDiff));
}

TEST_F(HandLandmarkerTest, SucceedsWithRotation) {
  MP_ASSERT_OK_AND_ASSIGN(
      Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory,
                                                kLeftHandsRotatedImage)));
  NormalizedRect input_norm_rect;
  input_norm_rect.set_x_center(0.5);
  input_norm_rect.set_y_center(0.5);
  input_norm_rect.set_width(1.0);
  input_norm_rect.set_height(1.0);
  input_norm_rect.set_rotation(M_PI / 2.0);
  MP_ASSERT_OK_AND_ASSIGN(auto task_runner, CreateTaskRunner());
  auto output_packets = task_runner->Process(
      {{kImageName, MakePacket<Image>(std::move(image))},
       {kNormRectName,
        MakePacket<NormalizedRect>(std::move(input_norm_rect))}});
  const auto& landmarks = (*output_packets)[kLandmarksName]
                              .Get<std::vector<NormalizedLandmarkList>>();
  ASSERT_EQ(landmarks.size(), kMaxNumHands);
  std::vector<NormalizedLandmarkList> expected_landmarks = {
      GetExpectedLandmarkList(kExpectedLeftUpHandRotatedLandmarksFilename),
      GetExpectedLandmarkList(kExpectedLeftDownHandRotatedLandmarksFilename)};

  EXPECT_THAT(landmarks[0],
              Approximately(Partially(EqualsProto(expected_landmarks[0])),
                            /*margin=*/kAbsMargin,
                            /*fraction=*/kFullModelFractionDiff));
  EXPECT_THAT(landmarks[1],
              Approximately(Partially(EqualsProto(expected_landmarks[1])),
                            /*margin=*/kAbsMargin,
                            /*fraction=*/kFullModelFractionDiff));
}

}  // namespace

}  // namespace hand_landmarker
@@ -29,8 +29,8 @@ message HandLandmarkerGraphOptions {
  extend mediapipe.CalculatorOptions {
    optional HandLandmarkerGraphOptions ext = 462713202;
  }
  // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
  // model file with metadata, accelerator options, etc.
  // Base options for configuring MediaPipe Tasks, such as specifying the model
  // asset bundle file with metadata, accelerator options, etc.
  optional core.proto.BaseOptions base_options = 1;

  // Options for hand detector graph.
@@ -94,7 +94,7 @@ cc_library(
    name = "landmarks_utils",
    srcs = ["landmarks_utils.cc"],
    hdrs = ["landmarks_utils.h"],
    deps = ["//mediapipe/tasks/cc/components/containers:landmarks_detection"],
    deps = ["//mediapipe/tasks/cc/components/containers:rect"],
)

cc_test(

@@ -103,6 +103,6 @@ cc_test(
    deps = [
        ":landmarks_utils",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/components/containers:landmarks_detection",
        "//mediapipe/tasks/cc/components/containers:rect",
    ],
)
@@ -18,15 +18,17 @@ limitations under the License.
#include <algorithm>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace mediapipe::tasks::vision::utils {

using ::mediapipe::tasks::components::containers::Bound;
using ::mediapipe::tasks::components::containers::Rect;

float CalculateArea(const Bound& bound) {
  return (bound.right - bound.left) * (bound.bottom - bound.top);
float CalculateArea(const Rect& rect) {
  return (rect.right - rect.left) * (rect.bottom - rect.top);
}

float CalculateIntersectionArea(const Bound& a, const Bound& b) {
float CalculateIntersectionArea(const Rect& a, const Rect& b) {
  const float intersection_left = std::max<float>(a.left, b.left);
  const float intersection_top = std::max<float>(a.top, b.top);
  const float intersection_right = std::min<float>(a.right, b.right);

@@ -36,7 +38,7 @@ float CalculateIntersectionArea(const Rect& a, const Rect& b) {
         std::max<float>(intersection_right - intersection_left, 0.0);
}

float CalculateIOU(const Bound& a, const Bound& b) {
float CalculateIOU(const Rect& a, const Rect& b) {
  const float area_a = CalculateArea(a);
  const float area_b = CalculateArea(b);
  if (area_a <= 0 || area_b <= 0) return 0.0;
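The hunk above cuts off before CalculateIOU's return statement. For completeness, the standard intersection-over-union computation consistent with the helpers shown, IOU(a, b) = intersection_area / (area_a + area_b - intersection_area); this is a sketch of the usual formula, not necessarily the file's exact remaining lines:

float CalculateIOU(const Rect& a, const Rect& b) {
  const float area_a = CalculateArea(a);
  const float area_b = CalculateArea(b);
  if (area_a <= 0 || area_b <= 0) return 0.0;
  const float intersection_area = CalculateIntersectionArea(a, b);
  // Union area = |a| + |b| - |a intersect b|.
  return intersection_area / (area_a + area_b - intersection_area);
}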
@@ -22,20 +22,20 @@ limitations under the License.
#include <limits>
#include <vector>

#include "mediapipe/tasks/cc/components/containers/landmarks_detection.h"
#include "mediapipe/tasks/cc/components/containers/rect.h"

namespace mediapipe::tasks::vision::utils {

// Calculates intersection over union for two bounds.
float CalculateIOU(const components::containers::Bound& a,
                   const components::containers::Bound& b);
float CalculateIOU(const components::containers::Rect& a,
                   const components::containers::Rect& b);

// Calculates area for face bound
float CalculateArea(const components::containers::Bound& bound);
float CalculateArea(const components::containers::Rect& rect);

// Calculates intersection area of two face bounds
float CalculateIntersectionArea(const components::containers::Bound& a,
                                const components::containers::Bound& b);
float CalculateIntersectionArea(const components::containers::Rect& a,
                                const components::containers::Rect& b);
}  // namespace mediapipe::tasks::vision::utils

#endif  // MEDIAPIPE_TASKS_CC_VISION_UTILS_LANDMARKS_UTILS_H_
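A short worked example of the renamed API, assuming Rect holds normalized {left, top, right, bottom} coordinates as the CalculateArea implementation above indicates:

#include <iostream>

#include "mediapipe/tasks/cc/vision/utils/landmarks_utils.h"

int main() {
  using mediapipe::tasks::components::containers::Rect;
  using mediapipe::tasks::vision::utils::CalculateIOU;

  Rect a;  // covers [0, 0.5] x [0, 0.5]
  a.left = 0.0f;
  a.top = 0.0f;
  a.right = 0.5f;
  a.bottom = 0.5f;
  Rect b;  // covers [0.25, 0.75] x [0.25, 0.75]
  b.left = 0.25f;
  b.top = 0.25f;
  b.right = 0.75f;
  b.bottom = 0.75f;

  // Intersection is a 0.25 x 0.25 square (area 0.0625); each box has area
  // 0.25, so IOU = 0.0625 / (0.25 + 0.25 - 0.0625), about 0.143.
  std::cout << CalculateIOU(a, b) << std::endl;
  return 0;
}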
@@ -30,13 +30,13 @@ public abstract class Landmark {
    return new AutoValue_Landmark(x, y, z, normalized);
  }

  // The x coordniates of the landmark.
  // The x coordinates of the landmark.
  public abstract float x();

  // The y coordniates of the landmark.
  // The y coordinates of the landmark.
  public abstract float y();

  // The z coordniates of the landmark.
  // The z coordinates of the landmark.
  public abstract float z();

  // Whether this landmark is normalized with respect to the image size.
@@ -117,7 +117,7 @@ public class OutputHandler<OutputT extends TaskResult, InputT> {
      if (errorListener != null) {
        errorListener.onError(e);
      } else {
        Log.e(TAG, "Error occurs when getting MediaPipe vision task result. " + e);
        Log.e(TAG, "Error occurs when getting MediaPipe task result. " + e);
      }
    } finally {
      for (Packet packet : packets) {
mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD (new file, 63 lines)

@@ -0,0 +1,63 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(default_visibility = ["//mediapipe/tasks:internal"])

# The native library of all MediaPipe text tasks.
cc_binary(
    name = "libmediapipe_tasks_text_jni.so",
    linkshared = 1,
    linkstatic = 1,
    deps = [
        "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
        "//mediapipe/tasks/cc/text/text_classifier:text_classifier_graph",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
    ],
)

cc_library(
    name = "libmediapipe_tasks_text_jni_lib",
    srcs = [":libmediapipe_tasks_text_jni.so"],
    alwayslink = 1,
)

android_library(
    name = "textclassifier",
    srcs = [
        "textclassifier/TextClassificationResult.java",
        "textclassifier/TextClassifier.java",
    ],
    javacopts = [
        "-Xep:AndroidJdkLibsChecker:OFF",
    ],
    manifest = "textclassifier/AndroidManifest.xml",
    deps = [
        "//mediapipe/framework:calculator_options_java_proto_lite",
        "//mediapipe/java/com/google/mediapipe/framework:android_framework",
        "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite",
        "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite",
        "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
        "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classification_entry",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:classifications",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
        "//mediapipe/tasks/java/com/google/mediapipe/tasks/text:libmediapipe_tasks_text_jni_lib",
        "//third_party:autovalue",
        "@maven//:com_google_guava_guava",
    ],
)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.text.textclassifier">

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

</manifest>
@@ -0,0 +1,103 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.text.textclassifier;

import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.proto.CategoryProto;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.ClassificationEntry;
import com.google.mediapipe.tasks.components.containers.Classifications;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Represents the classification results generated by {@link TextClassifier}. */
@AutoValue
public abstract class TextClassificationResult implements TaskResult {

  /**
   * Creates an {@link TextClassificationResult} instance from a {@link
   * ClassificationsProto.ClassificationResult} protobuf message.
   *
   * @param classificationResult a {@link ClassificationsProto.ClassificationResult} protobuf
   *     message.
   * @param timestampMs a timestamp for this result.
   */
  // TODO: consolidate output formats across platforms.
  static TextClassificationResult create(
      ClassificationsProto.ClassificationResult classificationResult, long timestampMs) {
    List<Classifications> classifications = new ArrayList<>();
    for (ClassificationsProto.Classifications classificationsProto :
        classificationResult.getClassificationsList()) {
      classifications.add(classificationsFromProto(classificationsProto));
    }
    return new AutoValue_TextClassificationResult(
        timestampMs, Collections.unmodifiableList(classifications));
  }

  @Override
  public abstract long timestampMs();

  /** Contains one set of results per classifier head. */
  @SuppressWarnings("AutoValueImmutableFields")
  public abstract List<Classifications> classifications();

  /**
   * Converts a {@link CategoryProto.Category} protobuf message to a {@link Category} object.
   *
   * @param category the {@link CategoryProto.Category} protobuf message to convert.
   */
  static Category categoryFromProto(CategoryProto.Category category) {
    return Category.create(
        category.getScore(),
        category.getIndex(),
        category.getCategoryName(),
        category.getDisplayName());
  }

  /**
   * Converts a {@link ClassificationsProto.ClassificationEntry} protobuf message to a {@link
   * ClassificationEntry} object.
   *
   * @param entry the {@link ClassificationsProto.ClassificationEntry} protobuf message to convert.
   */
  static ClassificationEntry classificationEntryFromProto(
      ClassificationsProto.ClassificationEntry entry) {
    List<Category> categories = new ArrayList<>();
    for (CategoryProto.Category category : entry.getCategoriesList()) {
      categories.add(categoryFromProto(category));
    }
    return ClassificationEntry.create(categories, entry.getTimestampMs());
  }

  /**
   * Converts a {@link ClassificationsProto.Classifications} protobuf message to a {@link
   * Classifications} object.
   *
   * @param classifications the {@link ClassificationsProto.Classifications} protobuf message to
   *     convert.
   */
  static Classifications classificationsFromProto(
      ClassificationsProto.Classifications classifications) {
    List<ClassificationEntry> entries = new ArrayList<>();
    for (ClassificationsProto.ClassificationEntry entry : classifications.getEntriesList()) {
      entries.add(classificationEntryFromProto(entry));
    }
    return Classifications.create(
        entries, classifications.getHeadIndex(), classifications.getHeadName());
  }
}
@@ -0,0 +1,253 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.text.textclassifier;

import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.ProtoUtil;
import com.google.mediapipe.tasks.components.containers.proto.ClassificationsProto;
import com.google.mediapipe.tasks.components.processors.ClassifierOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.text.textclassifier.proto.TextClassifierGraphOptionsProto;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Performs classification on text.
 *
 * <p>This API expects a TFLite model with (optional) <a
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a> that contains
 * the mandatory (described below) input tensors, output tensor, and the optional (but recommended)
 * label items as AssociatedFiles with type TENSOR_AXIS_LABELS per output classification tensor.
 *
 * <p>Metadata is required for models with int32 input tensors because it contains the input process
 * unit for the model's Tokenizer. No metadata is required for models with string input tensors.
 *
 * <ul>
 *   <li>Input tensors
 *       <ul>
 *         <li>Three input tensors ({@code kTfLiteInt32}) of shape {@code [batch_size x
 *             bert_max_seq_len]} representing the input ids, mask ids, and segment ids. This input
 *             signature requires a Bert Tokenizer process unit in the model metadata.
 *         <li>Or one input tensor ({@code kTfLiteInt32}) of shape {@code [batch_size x
 *             max_seq_len]} representing the input ids. This input signature requires a Regex
 *             Tokenizer process unit in the model metadata.
 *         <li>Or one input tensor ({@code kTfLiteString}) that is shapeless or has shape {@code
 *             [1]} containing the input string.
 *       </ul>
 *   <li>At least one output tensor ({@code kTfLiteFloat32}/{@code kBool}) with:
 *       <ul>
 *         <li>{@code N} classes and shape {@code [1 x N]}
 *         <li>optional (but recommended) label map(s) as AssociatedFile-s with type
 *             TENSOR_AXIS_LABELS, containing one label per line. The first such AssociatedFile (if
 *             any) is used to fill the {@code class_name} field of the results. The {@code
 *             display_name} field is filled from the AssociatedFile (if any) whose locale matches
 *             the {@code display_names_locale} field of the {@code TextClassifierOptions} used at
 *             creation time ("en" by default, i.e. English). If none of these are available, only
 *             the {@code index} field of the results will be filled.
 *       </ul>
 * </ul>
 */
public final class TextClassifier implements AutoCloseable {
  private static final String TAG = TextClassifier.class.getSimpleName();
  private static final String TEXT_IN_STREAM_NAME = "text_in";

  @SuppressWarnings("ConstantCaseForConstants")
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(Arrays.asList("TEXT:" + TEXT_IN_STREAM_NAME));

  @SuppressWarnings("ConstantCaseForConstants")
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList("CLASSIFICATION_RESULT:classification_result_out"));

  private static final int CLASSIFICATION_RESULT_OUT_STREAM_INDEX = 0;
  private static final String TASK_GRAPH_NAME =
      "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
  private final TaskRunner runner;

  static {
    System.loadLibrary("mediapipe_tasks_text_jni");
    ProtoUtil.registerTypeName(
        ClassificationsProto.ClassificationResult.class,
        "mediapipe.tasks.components.containers.proto.ClassificationResult");
  }

  /**
   * Creates a {@link TextClassifier} instance from a model file and the default {@link
   * TextClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelPath path to the text model with metadata in the assets.
   * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
   */
  public static TextClassifier createFromFile(Context context, String modelPath) {
    BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
    return createFromOptions(
        context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build());
  }

  /**
   * Creates a {@link TextClassifier} instance from a model file and the default {@link
   * TextClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param modelFile the text model {@link File} instance.
   * @throws IOException if an I/O error occurs when opening the tflite model file.
   * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
   */
  public static TextClassifier createFromFile(Context context, File modelFile) throws IOException {
    try (ParcelFileDescriptor descriptor =
        ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
      BaseOptions baseOptions =
          BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
      return createFromOptions(
          context, TextClassifierOptions.builder().setBaseOptions(baseOptions).build());
    }
  }

  /**
   * Creates a {@link TextClassifier} instance from {@link TextClassifierOptions}.
   *
   * @param context an Android {@link Context}.
   * @param options a {@link TextClassifierOptions} instance.
   * @throws MediaPipeException if there is an error during {@link TextClassifier} creation.
   */
  public static TextClassifier createFromOptions(Context context, TextClassifierOptions options) {
    OutputHandler<TextClassificationResult, Void> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<TextClassificationResult, Void>() {
          @Override
          public TextClassificationResult convertToTaskResult(List<Packet> packets) {
            try {
              return TextClassificationResult.create(
                  PacketGetter.getProto(
                      packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
                      ClassificationsProto.ClassificationResult.getDefaultInstance()),
                  packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
            } catch (IOException e) {
              throw new MediaPipeException(
                  MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
            }
          }

          @Override
          public Void convertToTaskInput(List<Packet> packets) {
            return null;
          }
        });
    TaskRunner runner =
        TaskRunner.create(
            context,
            TaskInfo.<TextClassifierOptions>builder()
                .setTaskGraphName(TASK_GRAPH_NAME)
                .setInputStreams(INPUT_STREAMS)
                .setOutputStreams(OUTPUT_STREAMS)
                .setTaskOptions(options)
                .setEnableFlowLimiting(false)
                .build(),
            handler);
    return new TextClassifier(runner);
  }

  /**
   * Constructor to initialize a {@link TextClassifier} from a {@link TaskRunner}.
   *
   * @param runner a {@link TaskRunner}.
   */
  private TextClassifier(TaskRunner runner) {
    this.runner = runner;
  }

  /**
   * Performs classification on the input text.
   *
   * @param inputText a {@link String} for processing.
   */
  public TextClassificationResult classify(String inputText) {
    Map<String, Packet> inputPackets = new HashMap<>();
    inputPackets.put(TEXT_IN_STREAM_NAME, runner.getPacketCreator().createString(inputText));
    return (TextClassificationResult) runner.process(inputPackets);
  }

  /** Closes and cleans up the {@link TextClassifier}. */
  @Override
  public void close() {
    runner.close();
  }

  /** Options for setting up a {@link TextClassifier}. */
  @AutoValue
  public abstract static class TextClassifierOptions extends TaskOptions {

    /** Builder for {@link TextClassifierOptions}. */
    @AutoValue.Builder
    public abstract static class Builder {
      /** Sets the base options for the text classifier task. */
      public abstract Builder setBaseOptions(BaseOptions value);

      /**
       * Sets the optional {@link ClassifierOptions} controlling classification behavior, such as
       * score threshold, number of results, etc.
       */
      public abstract Builder setClassifierOptions(ClassifierOptions classifierOptions);

      public abstract TextClassifierOptions build();
    }

    abstract BaseOptions baseOptions();

    abstract Optional<ClassifierOptions> classifierOptions();

    public static Builder builder() {
      return new AutoValue_TextClassifier_TextClassifierOptions.Builder();
    }

    /** Converts a {@link TextClassifierOptions} to a {@link CalculatorOptions} protobuf message. */
    @Override
    public CalculatorOptions convertToCalculatorOptionsProto() {
      BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
          BaseOptionsProto.BaseOptions.newBuilder();
      baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
      TextClassifierGraphOptionsProto.TextClassifierGraphOptions.Builder taskOptionsBuilder =
          TextClassifierGraphOptionsProto.TextClassifierGraphOptions.newBuilder()
              .setBaseOptions(baseOptionsBuilder);
      if (classifierOptions().isPresent()) {
        taskOptionsBuilder.setClassifierOptions(classifierOptions().get().convertToProto());
      }
      return CalculatorOptions.newBuilder()
          .setExtension(
              TextClassifierGraphOptionsProto.TextClassifierGraphOptions.ext,
              taskOptionsBuilder.build())
          .build();
    }
  }
}
@@ -15,6 +15,7 @@
package com.google.mediapipe.tasks.vision.gesturerecognizer;

import android.content.Context;
import android.graphics.RectF;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList;

@@ -71,8 +72,10 @@ import java.util.Optional;
public final class GestureRecognizer extends BaseVisionTaskApi {
  private static final String TAG = GestureRecognizer.class.getSimpleName();
  private static final String IMAGE_IN_STREAM_NAME = "image_in";
  private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
  private static final List<String> INPUT_STREAMS =
      Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
      Collections.unmodifiableList(
          Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
  private static final List<String> OUTPUT_STREAMS =
      Collections.unmodifiableList(
          Arrays.asList(

@@ -205,7 +208,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   * @param runningMode a mediapipe vision task {@link RunningMode}.
   */
  private GestureRecognizer(TaskRunner taskRunner, RunningMode runningMode) {
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME);
    super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
  }

  /**

@@ -223,7 +226,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   * @throws MediaPipeException if there is an internal error.
   */
  public GestureRecognitionResult recognize(Image inputImage) {
    return (GestureRecognitionResult) processImageData(inputImage);
    // TODO: add proper support for rotations.
    return (GestureRecognitionResult) processImageData(inputImage, buildFullImageRectF());
  }

  /**

@@ -244,7 +248,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   * @throws MediaPipeException if there is an internal error.
   */
  public GestureRecognitionResult recognizeForVideo(Image inputImage, long inputTimestampMs) {
    return (GestureRecognitionResult) processVideoData(inputImage, inputTimestampMs);
    // TODO: add proper support for rotations.
    return (GestureRecognitionResult)
        processVideoData(inputImage, buildFullImageRectF(), inputTimestampMs);
  }

  /**

@@ -266,7 +272,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
   * @throws MediaPipeException if there is an internal error.
   */
  public void recognizeAsync(Image inputImage, long inputTimestampMs) {
    sendLiveStreamData(inputImage, inputTimestampMs);
    // TODO: add proper support for rotations.
    sendLiveStreamData(inputImage, buildFullImageRectF(), inputTimestampMs);
  }

  /** Options for setting up an {@link GestureRecognizer}. */

@@ -303,18 +310,18 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
    /** Sets the maximum number of hands can be detected by the GestureRecognizer. */
    public abstract Builder setNumHands(Integer value);

    /** Sets minimum confidence score for the hand detection to be considered successfully */
    /** Sets minimum confidence score for the hand detection to be considered successful */
    public abstract Builder setMinHandDetectionConfidence(Float value);

    /** Sets minimum confidence score of hand presence score in the hand landmark detection. */
    public abstract Builder setMinHandPresenceConfidence(Float value);

    /** Sets the minimum confidence score for the hand tracking to be considered successfully. */
    /** Sets the minimum confidence score for the hand tracking to be considered successful. */
    public abstract Builder setMinTrackingConfidence(Float value);

    /**
     * Sets the minimum confidence score for the gestures to be considered successfully. If < 0,
     * the gesture confidence threshold=0.5 for the model is used.
     * Sets the minimum confidence score for the gestures to be considered successful. If < 0, the
     * gesture confidence threshold=0.5 for the model is used.
     *
     * <p>TODO Note this option is subject to change, after scoring merging
     * calculator is implemented.

@@ -433,8 +440,7 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
      HandLandmarkerGraphOptionsProto.HandLandmarkerGraphOptions.newBuilder()
          .setBaseOptions(
              BaseOptionsProto.BaseOptions.newBuilder()
                  .setUseStreamMode(runningMode() != RunningMode.IMAGE)
                  .mergeFrom(convertBaseOptionsToProto(baseOptionsHandLandmarker())));
                  .setUseStreamMode(runningMode() != RunningMode.IMAGE));
      minTrackingConfidence()
          .ifPresent(handLandmarkerGraphOptionsBuilder::setMinTrackingConfidence);
      handLandmarkerGraphOptionsBuilder

@@ -465,4 +471,9 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
          .build();
    }
  }

  /** Creates a RectF covering the full image. */
  private static RectF buildFullImageRectF() {
    return new RectF(0, 0, 1, 1);
  }
}
@@ -39,7 +39,6 @@ import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imageclassifier.proto.ImageClassifierGraphOptionsProto;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;

@@ -176,7 +175,7 @@ public final class ImageClassifier extends BaseVisionTaskApi {
                    packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX),
                    ClassificationsProto.ClassificationResult.getDefaultInstance()),
                packets.get(CLASSIFICATION_RESULT_OUT_STREAM_INDEX).getTimestamp());
          } catch (InvalidProtocolBufferException e) {
          } catch (IOException e) {
            throw new MediaPipeException(
                MediaPipeException.StatusCode.INTERNAL.ordinal(), e.getMessage());
          }
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.text.textclassifiertest"
    android:versionCode="1"
    android:versionName="1.0" >

  <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
  <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

  <application
      android:label="textclassifiertest"
      android:name="android.support.multidex.MultiDexApplication"
      android:taskAffinity="">
    <uses-library android:name="android.test.runner" />
  </application>

  <instrumentation
      android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
      android:targetPackage="com.google.mediapipe.tasks.text.textclassifiertest" />

</manifest>
@@ -0,0 +1,19 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# TODO: Enable this in OSS
@ -0,0 +1,154 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.mediapipe.tasks.text.textclassifier;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;

import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.text.textclassifier.TextClassifier.TextClassifierOptions;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;

/** Test for {@link TextClassifier}. */
@RunWith(AndroidJUnit4.class)
public class TextClassifierTest {
  private static final String BERT_MODEL_FILE = "bert_text_classifier.tflite";
  private static final String REGEX_MODEL_FILE =
      "test_model_text_classifier_with_regex_tokenizer.tflite";
  private static final String STRING_TO_BOOL_MODEL_FILE =
      "test_model_text_classifier_bool_output.tflite";
  private static final String NEGATIVE_TEXT = "unflinchingly bleak and desperate";
  private static final String POSITIVE_TEXT = "it's a charming and often affecting journey";

  @Test
  public void create_failsWithMissingModel() throws Exception {
    String nonExistentFile = "/path/to/non/existent/file";
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () ->
                TextClassifier.createFromFile(
                    ApplicationProvider.getApplicationContext(), nonExistentFile));
    assertThat(exception).hasMessageThat().contains(nonExistentFile);
  }

  @Test
  public void create_failsWithMissingOpResolver() throws Exception {
    TextClassifierOptions options =
        TextClassifierOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath(STRING_TO_BOOL_MODEL_FILE).build())
            .build();
    MediaPipeException exception =
        assertThrows(
            MediaPipeException.class,
            () ->
                TextClassifier.createFromOptions(
                    ApplicationProvider.getApplicationContext(), options));
    // TODO: Make MediaPipe InferenceCalculator report the detailed
    // interpreter errors (e.g., "Encountered unresolved custom op").
    assertThat(exception)
        .hasMessageThat()
        .contains("interpreter_builder(&interpreter) == kTfLiteOk");
  }

  @Test
  public void classify_succeedsWithBert() throws Exception {
    TextClassifier textClassifier =
        TextClassifier.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
    TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
    assertHasOneHead(negativeResults);
    assertCategoriesAre(
        negativeResults,
        Arrays.asList(
            Category.create(0.95630914f, 0, "negative", ""),
            Category.create(0.04369091f, 1, "positive", "")));

    TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
    assertHasOneHead(positiveResults);
    assertCategoriesAre(
        positiveResults,
        Arrays.asList(
            Category.create(0.99997187f, 1, "positive", ""),
            Category.create(2.8132641E-5f, 0, "negative", "")));
  }

  @Test
  public void classify_succeedsWithFileObject() throws Exception {
    TextClassifier textClassifier =
        TextClassifier.createFromFile(
            ApplicationProvider.getApplicationContext(),
            TestUtils.loadFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE));
    TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
    assertHasOneHead(negativeResults);
    assertCategoriesAre(
        negativeResults,
        Arrays.asList(
            Category.create(0.95630914f, 0, "negative", ""),
            Category.create(0.04369091f, 1, "positive", "")));

    TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
    assertHasOneHead(positiveResults);
    assertCategoriesAre(
        positiveResults,
        Arrays.asList(
            Category.create(0.99997187f, 1, "positive", ""),
            Category.create(2.8132641E-5f, 0, "negative", "")));
  }

  @Test
  public void classify_succeedsWithRegex() throws Exception {
    TextClassifier textClassifier =
        TextClassifier.createFromFile(
            ApplicationProvider.getApplicationContext(), REGEX_MODEL_FILE);
    TextClassificationResult negativeResults = textClassifier.classify(NEGATIVE_TEXT);
    assertHasOneHead(negativeResults);
    assertCategoriesAre(
        negativeResults,
        Arrays.asList(
            Category.create(0.6647746f, 0, "Negative", ""),
            Category.create(0.33522537f, 1, "Positive", "")));

    TextClassificationResult positiveResults = textClassifier.classify(POSITIVE_TEXT);
    assertHasOneHead(positiveResults);
    assertCategoriesAre(
        positiveResults,
        Arrays.asList(
            Category.create(0.5120041f, 0, "Negative", ""),
            Category.create(0.48799595f, 1, "Positive", "")));
  }

  private static void assertHasOneHead(TextClassificationResult results) {
    assertThat(results.classifications()).hasSize(1);
    assertThat(results.classifications().get(0).headIndex()).isEqualTo(0);
    assertThat(results.classifications().get(0).headName()).isEqualTo("probability");
    assertThat(results.classifications().get(0).entries()).hasSize(1);
  }

  private static void assertCategoriesAre(
      TextClassificationResult results, List<Category> categories) {
    assertThat(results.classifications().get(0).entries().get(0).categories())
        .isEqualTo(categories);
  }
}
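For context, the API surface these tests exercise reduces to a create-then-classify flow. Below is a minimal usage sketch assuming the same Java API the tests call (TextClassifier, TextClassifierOptions, BaseOptions, TextClassificationResult); the close() call and the printed Category form are assumptions, not taken from this diff.

import android.content.Context;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.text.textclassifier.TextClassificationResult;
import com.google.mediapipe.tasks.text.textclassifier.TextClassifier;
import com.google.mediapipe.tasks.text.textclassifier.TextClassifier.TextClassifierOptions;

final class TextClassifierExample {
  /** Creates a classifier, runs one synchronous classification, and releases it. */
  static void classifyOnce(Context context, String text) {
    TextClassifierOptions options =
        TextClassifierOptions.builder()
            .setBaseOptions(
                BaseOptions.builder().setModelAssetPath("bert_text_classifier.tflite").build())
            .build();
    TextClassifier classifier = TextClassifier.createFromOptions(context, options);
    TextClassificationResult result = classifier.classify(text);
    // Same result shape the tests assert on: one head, one entry, N categories.
    for (Category category :
        result.classifications().get(0).entries().get(0).categories()) {
      System.out.println(category);
    }
    classifier.close(); // assumed: tasks expose close() to free native resources
  }
}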
@@ -23,5 +23,9 @@ py_library(
    testonly = 1,
    srcs = ["test_utils.py"],
    srcs_version = "PY3",
    visibility = [
        "//mediapipe/model_maker/python/vision/gesture_recognizer:__pkg__",
        "//mediapipe/tasks:internal",
    ],
    deps = ["//mediapipe/python:_framework_bindings"],
)
15
mediapipe/tasks/testdata/vision/BUILD
vendored
@@ -35,9 +35,11 @@ mediapipe_files(srcs = [
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
    "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
    "deeplabv3.tflite",
    "hand_landmark.task",
    "hand_landmark_full.tflite",
    "hand_landmark_lite.tflite",
    "left_hands.jpg",
    "left_hands_rotated.jpg",
    "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",
    "mobilenet_v1_0.25_224_1_default_1.tflite",
    "mobilenet_v1_0.25_224_1_metadata_1.tflite",

@@ -51,7 +53,9 @@ mediapipe_files(srcs = [
    "multi_objects_rotated.jpg",
    "palm_detection_full.tflite",
    "pointing_up.jpg",
    "pointing_up_rotated.jpg",
    "right_hands.jpg",
    "right_hands_rotated.jpg",
    "segmentation_golden_rotation0.png",
    "segmentation_input_rotation0.jpg",
    "selfie_segm_128_128_3.tflite",

@@ -64,7 +68,9 @@ mediapipe_files(srcs = [
exports_files(
    srcs = [
        "expected_left_down_hand_landmarks.prototxt",
        "expected_left_down_hand_rotated_landmarks.prototxt",
        "expected_left_up_hand_landmarks.prototxt",
        "expected_left_up_hand_rotated_landmarks.prototxt",
        "expected_right_down_hand_landmarks.prototxt",
        "expected_right_up_hand_landmarks.prototxt",
    ],

@@ -84,11 +90,14 @@ filegroup(
        "hand_landmark_full.tflite",
        "hand_landmark_lite.tflite",
        "left_hands.jpg",
        "left_hands_rotated.jpg",
        "mozart_square.jpg",
        "multi_objects.jpg",
        "multi_objects_rotated.jpg",
        "pointing_up.jpg",
        "pointing_up_rotated.jpg",
        "right_hands.jpg",
        "right_hands_rotated.jpg",
        "segmentation_golden_rotation0.png",
        "segmentation_input_rotation0.jpg",
        "selfie_segm_128_128_3_expected_mask.jpg",

@@ -109,6 +118,7 @@ filegroup(
        "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite",
        "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_with_dummy_score_calibration.tflite",
        "deeplabv3.tflite",
        "hand_landmark.task",
        "hand_landmark_full.tflite",
        "hand_landmark_lite.tflite",
        "mobilenet_v1_0.25_192_quantized_1_default_1.tflite",

@@ -129,12 +139,17 @@ filegroup(
    name = "test_protos",
    srcs = [
        "expected_left_down_hand_landmarks.prototxt",
        "expected_left_down_hand_rotated_landmarks.prototxt",
        "expected_left_up_hand_landmarks.prototxt",
        "expected_left_up_hand_rotated_landmarks.prototxt",
        "expected_right_down_hand_landmarks.prototxt",
        "expected_right_up_hand_landmarks.prototxt",
        "hand_detector_result_one_hand.pbtxt",
        "hand_detector_result_one_hand_rotated.pbtxt",
        "hand_detector_result_two_hands.pbtxt",
        "pointing_up_landmarks.pbtxt",
        "pointing_up_rotated_landmarks.pbtxt",
        "thumb_up_landmarks.pbtxt",
        "thumb_up_rotated_landmarks.pbtxt",
    ],
)
84
mediapipe/tasks/testdata/vision/expected_left_down_hand_rotated_landmarks.prototxt
vendored
Normal file
@@ -0,0 +1,84 @@
landmark {
  x: 0.9259716
  y: 0.18969846
}
landmark {
  x: 0.88135517
  y: 0.28856543
}
landmark {
  x: 0.7600651
  y: 0.3578236
}
landmark {
  x: 0.62631166
  y: 0.40490413
}
landmark {
  x: 0.5374573
  y: 0.45170194
}
landmark {
  x: 0.57372385
  y: 0.29924914
}
landmark {
  x: 0.36731184
  y: 0.33081773
}
landmark {
  x: 0.24132833
  y: 0.34759054
}
landmark {
  x: 0.13690609
  y: 0.35727677
}
landmark {
  x: 0.5535803
  y: 0.2398035
}
landmark {
  x: 0.31834763
  y: 0.24999242
}
landmark {
  x: 0.16748133
  y: 0.25625145
}
landmark {
  x: 0.050747424
  y: 0.25991398
}
landmark {
  x: 0.56593156
  y: 0.1867483
}
landmark {
  x: 0.3543046
  y: 0.17923892
}
landmark {
  x: 0.21360746
  y: 0.17454882
}
landmark {
  x: 0.11110917
  y: 0.17232567
}
landmark {
  x: 0.5948908
  y: 0.14024714
}
landmark {
  x: 0.42692152
  y: 0.11949824
}
landmark {
  x: 0.32239118
  y: 0.106370345
}
landmark {
  x: 0.23672739
  y: 0.09432885
}
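Expected-landmark files like the one above are plain protobuf text format, so a test can load them and compare against live output under a tolerance. A minimal sketch, assuming the data parses into NormalizedLandmarkList from MediaPipe's landmark.proto (this diff does not show which message type the hand-landmarker tests actually use), with a hypothetical tolerance:

import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmark;
import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList;
import com.google.protobuf.TextFormat;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

final class ExpectedLandmarksLoader {
  private static final float TOLERANCE = 1e-2f; // hypothetical tolerance

  /** Parses an expected_*_landmarks.prototxt file into a NormalizedLandmarkList. */
  static NormalizedLandmarkList load(String path) throws IOException {
    String pbtxt = new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8);
    NormalizedLandmarkList.Builder builder = NormalizedLandmarkList.newBuilder();
    TextFormat.merge(pbtxt, builder);
    return builder.build();
  }

  /** Element-wise comparison with a tolerance, since landmark outputs are approximate. */
  static boolean approximatelyEqual(NormalizedLandmarkList expected, NormalizedLandmarkList actual) {
    if (expected.getLandmarkCount() != actual.getLandmarkCount()) {
      return false;
    }
    for (int i = 0; i < expected.getLandmarkCount(); i++) {
      NormalizedLandmark e = expected.getLandmark(i);
      NormalizedLandmark a = actual.getLandmark(i);
      if (Math.abs(e.getX() - a.getX()) > TOLERANCE
          || Math.abs(e.getY() - a.getY()) > TOLERANCE) {
        return false;
      }
    }
    return true;
  }
}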
84
mediapipe/tasks/testdata/vision/expected_left_up_hand_rotated_landmarks.prototxt
vendored
Normal file
@@ -0,0 +1,84 @@
landmark {
  x: 0.06676084
  y: 0.8095678
}
landmark {
  x: 0.11359626
  y: 0.71148247
}
landmark {
  x: 0.23572624
  y: 0.6414506
}
landmark {
  x: 0.37323278
  y: 0.5959156
}
landmark {
  x: 0.46243322
  y: 0.55125874
}
landmark {
  x: 0.4205411
  y: 0.69531494
}
landmark {
  x: 0.62798893
  y: 0.66715276
}
landmark {
  x: 0.7568023
  y: 0.65208924
}
landmark {
  x: 0.86370826
  y: 0.6437276
}
landmark {
  x: 0.445136
  y: 0.75394773
}
landmark {
  x: 0.6787485
  y: 0.745853
}
landmark {
  x: 0.8290694
  y: 0.7412988
}
landmark {
  x: 0.94454145
  y: 0.7384017
}
landmark {
  x: 0.43516788
  y: 0.8082166
}
landmark {
  x: 0.6459554
  y: 0.81768996
}
landmark {
  x: 0.7875173
  y: 0.825062
}
landmark {
  x: 0.89249825
  y: 0.82850707
}
landmark {
  x: 0.40665048
  y: 0.8567925
}
landmark {
  x: 0.57228816
  y: 0.8802181
}
landmark {
  x: 0.6762071
  y: 0.8941581
}
landmark {
  x: 0.76453924
  y: 0.90583205
}
33
mediapipe/tasks/testdata/vision/hand_detector_result_one_hand_rotated.pbtxt
vendored
Normal file
@@ -0,0 +1,33 @@
detections {
  label: "Palm"
  score: 0.97115
  location_data {
    format: RELATIVE_BOUNDING_BOX
    relative_bounding_box {
      xmin: 0.5198178
      ymin: 0.6467485
      width: 0.42467535
      height: 0.22546273
    }
  }
}
detections {
  label: "Palm"
  score: 0.96701413
  location_data {
    format: RELATIVE_BOUNDING_BOX
    relative_bounding_box {
      xmin: 0.024490356
      ymin: 0.12620124
      width: 0.43832153
      height: 0.23269764
    }
  }
}
hand_rects {
  x_center: 0.5760683
  y_center: 0.6829921
  height: 0.5862031
  width: 1.1048855
  rotation: -0.8250832
}
BIN
mediapipe/tasks/testdata/vision/hand_landmark.task
vendored
Normal file
Binary file not shown.
223
mediapipe/tasks/testdata/vision/pointing_up_rotated_landmarks.pbtxt
vendored
Normal file
@@ -0,0 +1,223 @@
classifications {
  classification {
    score: 1.0
    label: "Left"
    display_name: "Left"
  }
}

landmarks {
  landmark {
    x: 0.25546086
    y: 0.47584262
    z: 1.835341e-07
  }
  landmark {
    x: 0.3363011
    y: 0.54135
    z: -0.041144375
  }
  landmark {
    x: 0.4375146
    y: 0.57881975
    z: -0.06807727
  }
  landmark {
    x: 0.49603376
    y: 0.5263966
    z: -0.09387612
  }
  landmark {
    x: 0.5022822
    y: 0.4413827
    z: -0.1189948
  }
  landmark {
    x: 0.5569452
    y: 0.4724485
    z: -0.05138246
  }
  landmark {
    x: 0.6687125
    y: 0.47918057
    z: -0.09121969
  }
  landmark {
    x: 0.73666537
    y: 0.48318353
    z: -0.11703273
  }
  landmark {
    x: 0.7998315
    y: 0.4741413
    z: -0.1386424
  }
  landmark {
    x: 0.5244063
    y: 0.39292705
    z: -0.061040796
  }
  landmark {
    x: 0.57215345
    y: 0.41514704
    z: -0.11967233
  }
  landmark {
    x: 0.4724468
    y: 0.45553637
    z: -0.13287684
  }
  landmark {
    x: 0.43794966
    y: 0.45210314
    z: -0.13210714
  }
  landmark {
    x: 0.47838163
    y: 0.33329
    z: -0.07421263
  }
  landmark {
    x: 0.51081127
    y: 0.35479474
    z: -0.13596693
  }
  landmark {
    x: 0.42433846
    y: 0.40486792
    z: -0.121291734
  }
  landmark {
    x: 0.40280548
    y: 0.39977497
    z: -0.09928809
  }
  landmark {
    x: 0.42269367
    y: 0.2798249
    z: -0.09064263
  }
  landmark {
    x: 0.45849988
    y: 0.3069861
    z: -0.12894689
  }
  landmark {
    x: 0.40754712
    y: 0.35153976
    z: -0.109160855
  }
  landmark {
    x: 0.38855004
    y: 0.3467068
    z: -0.08820164
  }
}

world_landmarks {
  landmark {
    x: -0.08568013
    y: 0.016593203
    z: 0.036527164
  }
  landmark {
    x: -0.0565372
    y: 0.041761592
    z: 0.019493781
  }
  landmark {
    x: -0.031365488
    y: 0.05031186
    z: 0.0025481891
  }
  landmark {
    x: -0.008534161
    y: 0.04286737
    z: -0.024755282
  }
  landmark {
    x: -0.0047254
    y: 0.015748458
    z: -0.035581928
  }
  landmark {
    x: 0.013083893
    y: 0.024668094
    z: 0.0035934823
  }
  landmark {
    x: 0.04149521
    y: 0.024621274
    z: -0.0030611698
  }
  landmark {
    x: 0.06257473
    y: 0.025388625
    z: -0.010340984
  }
  landmark {
    x: 0.08009179
    y: 0.023082614
    z: -0.03162942
  }
  landmark {
    x: 0.006135068
    y: 0.000696786
    z: 0.0048212176
  }
  landmark {
    x: 0.01678449
    y: 0.0067061195
    z: -0.029920919
  }
  landmark {
    x: -0.008948593
    y: 0.016808286
    z: -0.03755109
  }
  landmark {
    x: -0.01789449
    y: 0.0153161455
    z: -0.012059977
  }
  landmark {
    x: -0.0061980113
    y: -0.017872887
    z: -0.002366997
  }
  landmark {
    x: -0.004643807
    y: -0.0108282855
    z: -0.034515083
  }
  landmark {
    x: -0.027603384
    y: 0.003529715
    z: -0.033665676
  }
  landmark {
    x: -0.035679806
    y: 0.0038255951
    z: -0.008094264
  }
  landmark {
    x: -0.02957782
    y: -0.031701155
    z: -0.008180461
  }
  landmark {
    x: -0.020741666
    y: -0.02506058
    z: -0.026839724
  }
  landmark {
    x: -0.0310834
    y: -0.009496164
    z: -0.032422185
  }
  landmark {
    x: -0.037420202
    y: -0.012883307
    z: -0.017971724
  }
}
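The file above pairs "landmarks" (image-normalized coordinates, x/y in [0,1] relative to image width and height) with "world_landmarks" (metric 3D coordinates, origin approximately at the hand's geometric center). A small sketch of the normalized-to-pixel conversion; the 1280x720 frame size is a hypothetical example value:

/** Converts normalized landmark coordinates to pixel space. */
final class LandmarkUnits {
  /** Simple pixel-coordinate holder. */
  static final class Pixel {
    final int x;
    final int y;

    Pixel(int x, int y) {
      this.x = x;
      this.y = y;
    }
  }

  /** Scales normalized [0,1] coordinates by the image dimensions. */
  static Pixel toPixel(float normalizedX, float normalizedY, int imageWidth, int imageHeight) {
    return new Pixel(
        Math.round(normalizedX * imageWidth), Math.round(normalizedY * imageHeight));
  }

  public static void main(String[] args) {
    // First landmark (the wrist) from pointing_up_rotated_landmarks.pbtxt above,
    // mapped onto a hypothetical 1280x720 frame.
    Pixel wrist = toPixel(0.25546086f, 0.47584262f, 1280, 720);
    System.out.println("wrist at (" + wrist.x + ", " + wrist.y + ")");
  }
}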
223
mediapipe/tasks/testdata/vision/thumb_up_rotated_landmarks.pbtxt
vendored
Normal file
@@ -0,0 +1,223 @@
classifications {
  classification {
    score: 1.0
    label: "Left"
    display_name: "Left"
  }
}

landmarks {
  landmark {
    x: 0.3283601
    y: 0.63773525
    z: -3.2280354e-07
  }
  landmark {
    x: 0.46280807
    y: 0.6339767
    z: -0.06408348
  }
  landmark {
    x: 0.5831279
    y: 0.57430106
    z: -0.08583106
  }
  landmark {
    x: 0.6689471
    y: 0.49959752
    z: -0.09886064
  }
  landmark {
    x: 0.74378216
    y: 0.47357544
    z: -0.09680563
  }
  landmark {
    x: 0.5233122
    y: 0.41020474
    z: -0.038088404
  }
  landmark {
    x: 0.5296913
    y: 0.3372598
    z: -0.08874837
  }
  landmark {
    x: 0.49039274
    y: 0.43994758
    z: -0.102315836
  }
  landmark {
    x: 0.4824569
    y: 0.47969607
    z: -0.1030014
  }
  landmark {
    x: 0.4451338
    y: 0.39520803
    z: -0.02177739
  }
  landmark {
    x: 0.4410001
    y: 0.34107083
    z: -0.07294245
  }
  landmark {
    x: 0.4162798
    y: 0.46102384
    z: -0.07746907
  }
  landmark {
    x: 0.43492994
    y: 0.47154287
    z: -0.07404131
  }
  landmark {
    x: 0.37671578
    y: 0.39535576
    z: -0.016277775
  }
  landmark {
    x: 0.36978847
    y: 0.34265152
    z: -0.07346253
  }
  landmark {
    x: 0.3559884
    y: 0.44905427
    z: -0.057693005
  }
  landmark {
    x: 0.37711847
    y: 0.46414754
    z: -0.03662908
  }
  landmark {
    x: 0.3142985
    y: 0.3942253
    z: -0.0152847925
  }
  landmark {
    x: 0.30000874
    y: 0.35543376
    z: -0.046002634
  }
  landmark {
    x: 0.30002704
    y: 0.42357764
    z: -0.032671776
  }
  landmark {
    x: 0.31079838
    y: 0.44218025
    z: -0.016200554
  }
}

world_landmarks {
  landmark {
    x: -0.030687196
    y: 0.0678545
    z: 0.051061403
  }
  landmark {
    x: 0.0047719833
    y: 0.06330968
    z: 0.018945374
  }
  landmark {
    x: 0.039799504
    y: 0.054109577
    z: 0.007930638
  }
  landmark {
    x: 0.069374144
    y: 0.035063196
    z: 2.2522348e-05
  }
  landmark {
    x: 0.087818466
    y: 0.018390425
    z: 0.004055788
  }
  landmark {
    x: 0.02810654
    y: 0.0043561812
    z: -0.0038672548
  }
  landmark {
    x: 0.025270049
    y: -0.0039896416
    z: -0.032991238
  }
  landmark {
    x: 0.020414166
    y: 0.006768506
    z: -0.032724563
  }
  landmark {
    x: 0.016415983
    y: 0.024563588
    z: -0.0058115427
  }
  landmark {
    x: 0.0038743173
    y: -0.0044466974
    z: 0.0024876352
  }
  landmark {
    x: 0.0041790796
    y: -0.0115309935
    z: -0.03532454
  }
  landmark {
    x: -0.0016900161
    y: 0.015519895
    z: -0.03596156
  }
  landmark {
    x: 0.004309217
    y: 0.01917039
    z: 0.003907912
  }
  landmark {
    x: -0.016969737
    y: -0.005584497
    z: 0.0034258277
  }
  landmark {
    x: -0.016737012
    y: -0.01159037
    z: -0.02876696
  }
  landmark {
    x: -0.018165365
    y: 0.01376111
    z: -0.026835402
  }
  landmark {
    x: -0.012430167
    y: 0.02064222
    z: -0.00087265146
  }
  landmark {
    x: -0.043247573
    y: 0.0011161827
    z: 0.0056269006
  }
  landmark {
    x: -0.038128495
    y: -0.011477032
    z: -0.016374081
  }
  landmark {
    x: -0.034920715
    y: 0.005510211
    z: -0.029714659
  }
  landmark {
    x: -0.03815982
    y: 0.011989757
    z: -0.014853194
  }
}
56
third_party/external_files.bzl
vendored
@@ -151,7 +151,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_dummy_gesture_recognizer_task",
        sha256 = "18e54586bda33300d459ca140cd045f6daf43d897224ba215a16db3423eae18e",
        urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665524417056146"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/dummy_gesture_recognizer.task?generation=1665707319890725"],
    )

    http_file(

@@ -166,12 +166,24 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_landmarks.prototxt?generation=1661875720230540"],
    )

    http_file(
        name = "com_google_mediapipe_expected_left_down_hand_rotated_landmarks_prototxt",
        sha256 = "a16d6cb8dd07d60f0678ddeb6a7447b73b9b03d4ddde365c8770b472205bb6cf",
        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_down_hand_rotated_landmarks.prototxt?generation=1666037061297507"],
    )

    http_file(
        name = "com_google_mediapipe_expected_left_up_hand_landmarks_prototxt",
        sha256 = "1353ba617c4f048083618587cd23a8a22115f634521c153d4e1bd1ebd4f49dd7",
        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_landmarks.prototxt?generation=1661875726008879"],
    )

    http_file(
        name = "com_google_mediapipe_expected_left_up_hand_rotated_landmarks_prototxt",
        sha256 = "a9b9789c274d48a7cb9cc10af7bc644eb2512bb934529790d0a5404726daa86a",
        urls = ["https://storage.googleapis.com/mediapipe-assets/expected_left_up_hand_rotated_landmarks.prototxt?generation=1666037063443676"],
    )

    http_file(
        name = "com_google_mediapipe_expected_right_down_hand_landmarks_prototxt",
        sha256 = "f281b745175aaa7f458def6cf4c89521fb56302dd61a05642b3b4a4f237ffaa3",

@@ -250,6 +262,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand.pbtxt?generation=1662745351291628"],
    )

    http_file(
        name = "com_google_mediapipe_hand_detector_result_one_hand_rotated_pbtxt",
        sha256 = "ff5ca0654028d78a3380df90054273cae79abe1b7369b164063fd1d5758ec370",
        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_detector_result_one_hand_rotated.pbtxt?generation=1666037065601724"],
    )

    http_file(
        name = "com_google_mediapipe_hand_detector_result_two_hands_pbtxt",
        sha256 = "2589cb08b0ee027dc24649fe597adcfa2156a21d12ea2480f83832714ebdf95f",

@@ -268,6 +286,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark_lite.tflite?generation=1661875766398729"],
    )

    http_file(
        name = "com_google_mediapipe_hand_landmark_task",
        sha256 = "dd830295598e48e6bbbdf22fd9e69538fa07768106cd9ceb04d5462ca7e38c95",
        urls = ["https://storage.googleapis.com/mediapipe-assets/hand_landmark.task?generation=1665707323647357"],
    )

    http_file(
        name = "com_google_mediapipe_hand_recrop_tflite",
        sha256 = "67d996ce96f9d36fe17d2693022c6da93168026ab2f028f9e2365398d8ac7d5d",

@@ -346,6 +370,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands.jpg?generation=1661875796949017"],
    )

    http_file(
        name = "com_google_mediapipe_left_hands_rotated_jpg",
        sha256 = "8609c6202bca43a99bbf23fa8e687e49fa525e89481152e4c0987f46d60d7931",
        urls = ["https://storage.googleapis.com/mediapipe-assets/left_hands_rotated.jpg?generation=1666037068103465"],
    )

    http_file(
        name = "com_google_mediapipe_mobilebert_embedding_with_metadata_tflite",
        sha256 = "fa47142dcc6f446168bc672f2df9605b6da5d0c0d6264e9be62870282365b95c",

@@ -538,6 +568,18 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_landmarks.pbtxt?generation=1665174976408451"],
    )

    http_file(
        name = "com_google_mediapipe_pointing_up_rotated_jpg",
        sha256 = "50ff66f50281207072a038e5bb6648c43f4aacbfb8204a4d2591868756aaeff1",
        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated.jpg?generation=1666037072219697"],
    )

    http_file(
        name = "com_google_mediapipe_pointing_up_rotated_landmarks_pbtxt",
        sha256 = "ccf67e5867094ffb6c465a4dfbf2ef1eb3f9db2465803fc25a0b84c958e050de",
        urls = ["https://storage.googleapis.com/mediapipe-assets/pointing_up_rotated_landmarks.pbtxt?generation=1666037074376515"],
    )

    http_file(
        name = "com_google_mediapipe_pose_detection_tflite",
        sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85",

@@ -574,6 +616,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands.jpg?generation=1661875908672404"],
    )

    http_file(
        name = "com_google_mediapipe_right_hands_rotated_jpg",
        sha256 = "b3bdf692f0d54b86c8b67e6d1286dd0078fbe6e9dfcd507b187e3bd8b398c0f9",
        urls = ["https://storage.googleapis.com/mediapipe-assets/right_hands_rotated.jpg?generation=1666037076873345"],
    )

    http_file(
        name = "com_google_mediapipe_score_calibration_file_meta_json",
        sha256 = "6a3c305620371f662419a496f75be5a10caebca7803b1e99d8d5d22ba51cda94",

@@ -718,6 +766,12 @@ def external_files():
        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_landmarks.pbtxt?generation=1665174979747784"],
    )

    http_file(
        name = "com_google_mediapipe_thumb_up_rotated_landmarks_pbtxt",
        sha256 = "5d0a465959cacbd201ac8dd8fc8a66c5997a172b71809b12d27296db6a28a102",
        urls = ["https://storage.googleapis.com/mediapipe-assets/thumb_up_rotated_landmarks.pbtxt?generation=1666037079490527"],
    )

    http_file(
        name = "com_google_mediapipe_two_heads_16000_hz_mono_wav",
        sha256 = "a291a9c22c39bba30138a26915e154a96286ba6ca3b413053123c504a58cce3b",