Merge branch 'master' into ios-text-embedder
Commit cd1cb87ff6

WORKSPACE | 11
@@ -22,21 +22,20 @@ bazel_skylib_workspace()
load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "3.7.2")

# ABSL cpp library lts_2021_03_24, patch 2.
# ABSL cpp library lts_2023_01_25.
http_archive(
    name = "com_google_absl",
    urls = [
        "https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz",
        "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz",
    ],
    # Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved.
    patches = [
        "@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff"
        "@//third_party:com_google_absl_windows_patch.diff"
    ],
    patch_args = [
        "-p1",
    ],
    strip_prefix = "abseil-cpp-20220623.1",
    sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8"
    strip_prefix = "abseil-cpp-20230125.0",
    sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21"
)

http_archive(
@@ -4,12 +4,10 @@ py_binary(
    name = "build_py_api_docs",
    srcs = ["build_py_api_docs.py"],
    deps = [
        "//mediapipe",
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/tensorflow_docs",
        "//third_party/py/mediapipe",
        "//third_party/py/tensorflow_docs/api_generator:generate_lib",
        "//third_party/py/tensorflow_docs/api_generator:public_api",
    ],
)
@@ -35,7 +35,7 @@ install --user six`.

```bash
$ cd $HOME
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git

# Change directory into MediaPipe root directory
$ cd mediapipe
@@ -287,7 +287,7 @@ build issues.
2. Checkout MediaPipe repository.

```bash
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git

# Change directory into MediaPipe root directory
$ cd mediapipe
@@ -416,7 +416,7 @@ build issues.
3. Checkout MediaPipe repository.

```bash
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git

$ cd mediapipe
```
@@ -590,7 +590,7 @@ next section.
7. Checkout MediaPipe repository.

```
C:\Users\Username\mediapipe_repo> git clone https://github.com/google/mediapipe.git
C:\Users\Username\mediapipe_repo> git clone --depth 1 https://github.com/google/mediapipe.git

# Change directory into MediaPipe root directory
C:\Users\Username\mediapipe_repo> cd mediapipe
@@ -680,7 +680,7 @@ cameras. Alternatively, you use a video file as input.
6. Checkout MediaPipe repository.

```bash
username@DESKTOP-TMVLBJ1:~$ git clone https://github.com/google/mediapipe.git
username@DESKTOP-TMVLBJ1:~$ git clone --depth 1 https://github.com/google/mediapipe.git

username@DESKTOP-TMVLBJ1:~$ cd mediapipe
```
@@ -771,7 +771,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
2. Build a docker image with tag "mediapipe".

```bash
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git
$ cd mediapipe
$ docker build --tag=mediapipe .
@@ -1329,6 +1329,7 @@ cc_library(
    hdrs = ["merge_to_vector_calculator.h"],
    deps = [
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:packet",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:port",
        "//mediapipe/framework/formats:detection_cc_proto",
@@ -49,7 +49,7 @@ namespace mediapipe {
//   calculator: "EndLoopWithOutputCalculator"
//   input_stream: "ITEM:output_of_loop_body"     # ItemU @loop_internal_ts
//   input_stream: "BATCH_END:ext_ts"             # Timestamp @loop_internal_ts
//   output_stream: "OUTPUT:aggregated_result"    # IterableU @ext_ts
//   output_stream: "ITERABLE:aggregated_result"  # IterableU @ext_ts
// }
//
// Input streams tagged with "CLONE" are cloned to the corresponding output
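The comment fix above replaces the documented "OUTPUT" tag with "ITERABLE", which is the tag the end-loop calculators actually expose. A minimal sketch of the corrected usage as a text proto embedded in C++; "EndLoopWithOutputCalculator" is the placeholder name from the documentation example, so a real graph would substitute a registered EndLoop* calculator.

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Sketch of the documented loop-closing node with the corrected ITERABLE tag.
mediapipe::CalculatorGraphConfig MakeEndLoopNodeConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    node {
      calculator: "EndLoopWithOutputCalculator"
      input_stream: "ITEM:output_of_loop_body"     # ItemU @loop_internal_ts
      input_stream: "BATCH_END:ext_ts"             # Timestamp @loop_internal_ts
      output_stream: "ITERABLE:aggregated_result"  # IterableU @ext_ts
    }
  )pb");
}
```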
@@ -1054,6 +1054,7 @@ cc_test(
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/util:packet_test_util",
        "@org_tensorflow//tensorflow/core:framework",
        "@org_tensorflow//tensorflow/core:protos_all_cc",
    ],
@@ -102,7 +102,7 @@ absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) {
  }
  auto output =
      absl::make_unique<std::vector<float>>(input_tensor.NumElements());
  const auto& tensor_values = input_tensor.flat<float>();
  const auto& tensor_values = input_tensor.unaligned_flat<float>();
  for (int i = 0; i < input_tensor.NumElements(); ++i) {
    output->at(i) = tensor_values(i);
  }
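The change from `flat<float>()` to `unaligned_flat<float>()` lets Process() read tensors whose buffer does not satisfy Eigen's alignment requirement, for example views created by `Tensor::SubSlice` as exercised by the new test below. A minimal standalone sketch of the distinction, assuming TensorFlow's C++ `Tensor` API; it is illustrative and not part of the diff.

```cpp
#include <vector>

#include "tensorflow/core/framework/tensor.h"

namespace tf = tensorflow;

// Copies a float tensor into a std::vector without requiring alignment.
// flat<float>() expects an aligned buffer, so it can reject tensors produced
// by Slice()/SubSlice(); unaligned_flat<float>() accepts any alignment.
std::vector<float> FlattenTensor(const tf::Tensor& tensor) {
  const auto& values = tensor.unaligned_flat<float>();
  std::vector<float> out(tensor.NumElements());
  for (int i = 0; i < tensor.NumElements(); ++i) {
    out[i] = values(i);
  }
  return out;
}
```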
@@ -16,6 +16,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/packet_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
@@ -129,5 +130,28 @@ TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) {
  }
}

TEST_F(TensorToVectorFloatCalculatorTest, AcceptsUnalignedTensors) {
  SetUpRunner(/*tensor_is_2d=*/false, /*flatten_nd=*/false);

  const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 5});
  tf::Tensor tensor(tf::DT_FLOAT, tensor_shape);
  auto slice = tensor.Slice(1, 1).flat<float>();
  for (int i = 0; i < 5; ++i) {
    slice(i) = i;
  }

  auto input_tensor = tensor.SubSlice(1);
  // Ensure that the input tensor is unaligned.
  ASSERT_FALSE(input_tensor.IsAligned());
  runner_->MutableInputs()->Index(0).packets.push_back(
      MakePacket<tf::Tensor>(input_tensor).At(Timestamp(5)));

  ASSERT_TRUE(runner_->Run().ok());

  EXPECT_THAT(runner_->Outputs().Index(0).packets,
              ElementsAre(PacketContainsTimestampAndPayload<std::vector<float>>(
                  Timestamp(5), std::vector<float>({0, 1, 2, 3, 4}))));
}

}  // namespace
}  // namespace mediapipe
@@ -1051,6 +1051,7 @@ cc_library(
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
)
@@ -26,12 +26,21 @@ using ::mediapipe::api2::test::FooBar1;

TEST(BuilderTest, BuildGraph) {
  Graph graph;
  // Graph inputs.
  Stream<AnyType> base = graph.In("IN").SetName("base");
  SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");

  auto& foo = graph.AddNode("Foo");
  base >> foo.In("BASE");
  side >> foo.SideIn("SIDE");
  Stream<AnyType> foo_out = foo.Out("OUT");

  auto& bar = graph.AddNode("Bar");
  graph.In("IN").SetName("base") >> foo.In("BASE");
  graph.SideIn("SIDE").SetName("side") >> foo.SideIn("SIDE");
  foo.Out("OUT") >> bar.In("IN");
  bar.Out("OUT").SetName("out") >> graph.Out("OUT");
  foo_out >> bar.In("IN");
  Stream<AnyType> bar_out = bar.Out("OUT");

  // Graph outputs.
  bar_out.SetName("out") >> graph.Out("OUT");

  CalculatorGraphConfig expected =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
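The rewritten test illustrates the new builder style: graph inputs and node outputs are first bound to `Stream<T>`/`SidePacket<T>` variables and then wired with `>>`, instead of repeating `graph.In(...)`/`node.Out(...)` expressions at every connection. A condensed sketch of that flow, using the test's placeholder "Foo" and "Bar" calculators and assuming the usual builder using-declarations; `GetConfig()` is the builder call that produces the final `CalculatorGraphConfig`.

```cpp
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"

using ::mediapipe::api2::AnyType;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::SidePacket;
using ::mediapipe::api2::builder::Stream;

mediapipe::CalculatorGraphConfig BuildExampleGraph() {
  Graph graph;
  // Graph inputs, named once and reused below.
  Stream<AnyType> base = graph.In("IN").SetName("base");
  SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");

  auto& foo = graph.AddNode("Foo");  // placeholder calculator from the test
  base >> foo.In("BASE");
  side >> foo.SideIn("SIDE");
  Stream<AnyType> foo_out = foo.Out("OUT");

  auto& bar = graph.AddNode("Bar");  // placeholder calculator from the test
  foo_out >> bar.In("IN");
  Stream<AnyType> bar_out = bar.Out("OUT");

  // Graph outputs.
  bar_out.SetName("out") >> graph.Out("OUT");
  return graph.GetConfig();
}
```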
@@ -87,6 +96,7 @@ TEST(BuilderTest, CopyableStream) {
|
|||
TEST(BuilderTest, BuildGraphWithFunctions) {
|
||||
Graph graph;
|
||||
|
||||
// Graph inputs.
|
||||
Stream<int> base = graph.In("IN").SetName("base").Cast<int>();
|
||||
SidePacket<float> side = graph.SideIn("SIDE").SetName("side").Cast<float>();
|
||||
|
||||
|
@@ -105,6 +115,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
|
|||
};
|
||||
Stream<double> bar_out = bar_fn(foo_out, graph);
|
||||
|
||||
// Graph outputs.
|
||||
bar_out.SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
|
@@ -130,12 +141,21 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
|
|||
template <class FooT>
|
||||
void BuildGraphTypedTest() {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
|
||||
|
||||
auto& foo = graph.AddNode<FooT>();
|
||||
base >> foo.In(MPP_TAG("BASE"));
|
||||
side >> foo.SideIn(MPP_TAG("BIAS"));
|
||||
Stream<float> foo_out = foo.Out(MPP_TAG("OUT"));
|
||||
|
||||
auto& bar = graph.AddNode<Bar>();
|
||||
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
|
||||
graph.SideIn("SIDE").SetName("side") >> foo.SideIn(MPP_TAG("BIAS"));
|
||||
foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN"));
|
||||
bar.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
|
||||
foo_out >> bar.In(MPP_TAG("IN"));
|
||||
Stream<AnyType> bar_out = bar.Out(MPP_TAG("OUT"));
|
||||
|
||||
// Graph outputs.
|
||||
bar_out.SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||
|
@@ -165,12 +185,20 @@ TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<test::Foo2>(); }
|
|||
|
||||
TEST(BuilderTest, FanOut) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
base >> foo.In("BASE");
|
||||
Stream<AnyType> foo_out = foo.Out("OUT");
|
||||
|
||||
auto& adder = graph.AddNode("FloatAdder");
|
||||
graph.In("IN").SetName("base") >> foo.In("BASE");
|
||||
foo.Out("OUT") >> adder.In("IN")[0];
|
||||
foo.Out("OUT") >> adder.In("IN")[1];
|
||||
adder.Out("OUT").SetName("out") >> graph.Out("OUT");
|
||||
foo_out >> adder.In("IN")[0];
|
||||
foo_out >> adder.In("IN")[1];
|
||||
Stream<AnyType> out = adder.Out("OUT");
|
||||
|
||||
// Graph outputs.
|
||||
out.SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -193,12 +221,20 @@ TEST(BuilderTest, FanOut) {
|
|||
|
||||
TEST(BuilderTest, TypedMultiple) {
|
||||
Graph graph;
|
||||
auto& foo = graph.AddNode<test::Foo>();
|
||||
auto& adder = graph.AddNode<test::FloatAdder>();
|
||||
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
|
||||
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
|
||||
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
|
||||
adder.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
|
||||
// Graph inputs.
|
||||
Stream<AnyType> base = graph.In("IN").SetName("base");
|
||||
|
||||
auto& foo = graph.AddNode<Foo>();
|
||||
base >> foo.In(MPP_TAG("BASE"));
|
||||
Stream<float> foo_out = foo.Out(MPP_TAG("OUT"));
|
||||
|
||||
auto& adder = graph.AddNode<FloatAdder>();
|
||||
foo_out >> adder.In(MPP_TAG("IN"))[0];
|
||||
foo_out >> adder.In(MPP_TAG("IN"))[1];
|
||||
Stream<float> out = adder.Out(MPP_TAG("OUT"));
|
||||
|
||||
// Graph outputs.
|
||||
out.SetName("out") >> graph.Out("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -221,13 +257,20 @@ TEST(BuilderTest, TypedMultiple) {
|
|||
|
||||
TEST(BuilderTest, TypedByPorts) {
|
||||
Graph graph;
|
||||
auto& foo = graph.AddNode<test::Foo>();
|
||||
auto& adder = graph.AddNode<FloatAdder>();
|
||||
// Graph inputs.
|
||||
Stream<int> base = graph.In(FooBar1::kIn).SetName("base");
|
||||
|
||||
graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase];
|
||||
foo[Foo::kOut] >> adder[FloatAdder::kIn][0];
|
||||
foo[Foo::kOut] >> adder[FloatAdder::kIn][1];
|
||||
adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut);
|
||||
auto& foo = graph.AddNode<test::Foo>();
|
||||
base >> foo[Foo::kBase];
|
||||
Stream<float> foo_out = foo[Foo::kOut];
|
||||
|
||||
auto& adder = graph.AddNode<FloatAdder>();
|
||||
foo_out >> adder[FloatAdder::kIn][0];
|
||||
foo_out >> adder[FloatAdder::kIn][1];
|
||||
Stream<float> out = adder[FloatAdder::kOut];
|
||||
|
||||
// Graph outputs.
|
||||
out.SetName("out") >> graph.Out(FooBar1::kOut);
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -250,9 +293,15 @@ TEST(BuilderTest, TypedByPorts) {
|
|||
|
||||
TEST(BuilderTest, PacketGenerator) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
SidePacket<AnyType> side_in = graph.SideIn("IN");
|
||||
|
||||
auto& generator = graph.AddPacketGenerator("FloatGenerator");
|
||||
graph.SideIn("IN") >> generator.SideIn("IN");
|
||||
generator.SideOut("OUT") >> graph.SideOut("OUT");
|
||||
side_in >> generator.SideIn("IN");
|
||||
SidePacket<AnyType> side_out = generator.SideOut("OUT");
|
||||
|
||||
// Graph outputs.
|
||||
side_out >> graph.SideOut("OUT");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -269,12 +318,21 @@ TEST(BuilderTest, PacketGenerator) {
|
|||
|
||||
TEST(BuilderTest, EmptyTag) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> a = graph.In("A").SetName("a");
|
||||
Stream<AnyType> c = graph.In("C").SetName("c");
|
||||
Stream<AnyType> b = graph.In("B").SetName("b");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
graph.In("A").SetName("a") >> foo.In("")[0];
|
||||
graph.In("C").SetName("c") >> foo.In("")[2];
|
||||
graph.In("B").SetName("b") >> foo.In("")[1];
|
||||
foo.Out("")[0].SetName("x") >> graph.Out("ONE");
|
||||
foo.Out("")[1].SetName("y") >> graph.Out("TWO");
|
||||
a >> foo.In("")[0];
|
||||
c >> foo.In("")[2];
|
||||
b >> foo.In("")[1];
|
||||
Stream<AnyType> x = foo.Out("")[0];
|
||||
Stream<AnyType> y = foo.Out("")[1];
|
||||
|
||||
// Graph outputs.
|
||||
x.SetName("x") >> graph.Out("ONE");
|
||||
y.SetName("y") >> graph.Out("TWO");
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -301,10 +359,17 @@ TEST(BuilderTest, StringLikeTags) {
|
|||
constexpr absl::string_view kC = "C";
|
||||
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> a = graph.In(kA).SetName("a");
|
||||
Stream<AnyType> b = graph.In(kB).SetName("b");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
graph.In(kA).SetName("a") >> foo.In(kA);
|
||||
graph.In(kB).SetName("b") >> foo.In(kB);
|
||||
foo.Out(kC).SetName("c") >> graph.Out(kC);
|
||||
a >> foo.In(kA);
|
||||
b >> foo.In(kB);
|
||||
Stream<AnyType> c = foo.Out(kC);
|
||||
|
||||
// Graph outputs.
|
||||
c.SetName("c") >> graph.Out(kC);
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -323,12 +388,21 @@ TEST(BuilderTest, GraphIndexes) {
|
|||
|
||||
TEST(BuilderTest, GraphIndexes) {
|
||||
Graph graph;
|
||||
// Graph inputs.
|
||||
Stream<AnyType> a = graph.In(0).SetName("a");
|
||||
Stream<AnyType> c = graph.In(1).SetName("c");
|
||||
Stream<AnyType> b = graph.In(2).SetName("b");
|
||||
|
||||
auto& foo = graph.AddNode("Foo");
|
||||
graph.In(0).SetName("a") >> foo.In("")[0];
|
||||
graph.In(1).SetName("c") >> foo.In("")[2];
|
||||
graph.In(2).SetName("b") >> foo.In("")[1];
|
||||
foo.Out("")[0].SetName("x") >> graph.Out(1);
|
||||
foo.Out("")[1].SetName("y") >> graph.Out(0);
|
||||
a >> foo.In("")[0];
|
||||
c >> foo.In("")[2];
|
||||
b >> foo.In("")[1];
|
||||
Stream<AnyType> x = foo.Out("")[0];
|
||||
Stream<AnyType> y = foo.Out("")[1];
|
||||
|
||||
// Graph outputs.
|
||||
x.SetName("x") >> graph.Out(1);
|
||||
y.SetName("y") >> graph.Out(0);
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
@@ -381,21 +455,20 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
|
|||
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
|
||||
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
|
||||
int_input >> node[AnyAndSameTypeCalculator::kIntInput];
|
||||
|
||||
Stream<AnyType> any_type_output =
|
||||
node[AnyAndSameTypeCalculator::kAnyTypeOutput];
|
||||
any_type_output.SetName("any_type_output");
|
||||
|
||||
Stream<AnyType> same_type_output =
|
||||
node[AnyAndSameTypeCalculator::kSameTypeOutput];
|
||||
same_type_output.SetName("same_type_output");
|
||||
Stream<AnyType> recursive_same_type_output =
|
||||
node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
|
||||
recursive_same_type_output.SetName("recursive_same_type_output");
|
||||
Stream<int> same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput];
|
||||
same_int_output.SetName("same_int_output");
|
||||
Stream<int> recursive_same_int_type_output =
|
||||
node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
|
||||
|
||||
any_type_output.SetName("any_type_output");
|
||||
same_type_output.SetName("same_type_output");
|
||||
recursive_same_type_output.SetName("recursive_same_type_output");
|
||||
same_int_output.SetName("same_int_output");
|
||||
recursive_same_int_type_output.SetName("recursive_same_int_type_output");
|
||||
|
||||
CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie<
|
||||
|
@@ -424,11 +497,10 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
|
|||
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
|
||||
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
|
||||
Stream<double> any_type_output =
|
||||
node[AnyAndSameTypeCalculator::kAnyTypeOutput]
|
||||
.SetName("any_type_output")
|
||||
.Cast<double>();
|
||||
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
|
||||
|
||||
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
|
||||
any_type_output.SetName("any_type_output") >>
|
||||
graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
|
||||
|
||||
CalculatorGraphConfig expected =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
|
|
|
@@ -486,3 +486,15 @@ cc_test(
|
|||
"//mediapipe/gpu:disable_gpu": [],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "frame_buffer",
|
||||
srcs = ["frame_buffer.cc"],
|
||||
hdrs = ["frame_buffer.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
|
mediapipe/framework/formats/frame_buffer.cc (new file) | 176

@@ -0,0 +1,176 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/framework/formats/frame_buffer.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns whether the input `format` is a supported YUV format.
|
||||
bool IsSupportedYuvFormat(FrameBuffer::Format format) {
|
||||
return format == FrameBuffer::Format::kNV21 ||
|
||||
format == FrameBuffer::Format::kNV12 ||
|
||||
format == FrameBuffer::Format::kYV12 ||
|
||||
format == FrameBuffer::Format::kYV21;
|
||||
}
|
||||
|
||||
// Returns supported 1-plane FrameBuffer in YuvData structure.
|
||||
absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromOnePlaneFrameBuffer(
|
||||
const FrameBuffer& source) {
|
||||
if (!IsSupportedYuvFormat(source.format())) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The source FrameBuffer format is not part of YUV420 family.");
|
||||
}
|
||||
|
||||
FrameBuffer::YuvData result;
|
||||
const int y_buffer_size =
|
||||
source.plane(0).stride.row_stride_bytes * source.dimension().height;
|
||||
const int uv_buffer_size =
|
||||
((source.plane(0).stride.row_stride_bytes + 1) / 2) *
|
||||
((source.dimension().height + 1) / 2);
|
||||
result.y_buffer = source.plane(0).buffer;
|
||||
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
|
||||
result.uv_row_stride = result.y_row_stride;
|
||||
|
||||
if (source.format() == FrameBuffer::Format::kNV21) {
|
||||
result.v_buffer = result.y_buffer + y_buffer_size;
|
||||
result.u_buffer = result.v_buffer + 1;
|
||||
result.uv_pixel_stride = 2;
|
||||
// If y_row_stride equals to the frame width and is an odd value,
|
||||
// uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride.
|
||||
if (result.y_row_stride == source.dimension().width &&
|
||||
result.y_row_stride % 2 == 1) {
|
||||
result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
|
||||
}
|
||||
} else if (source.format() == FrameBuffer::Format::kNV12) {
|
||||
result.u_buffer = result.y_buffer + y_buffer_size;
|
||||
result.v_buffer = result.u_buffer + 1;
|
||||
result.uv_pixel_stride = 2;
|
||||
// If y_row_stride equals to the frame width and is an odd value,
|
||||
// uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride.
|
||||
if (result.y_row_stride == source.dimension().width &&
|
||||
result.y_row_stride % 2 == 1) {
|
||||
result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
|
||||
}
|
||||
} else if (source.format() == FrameBuffer::Format::kYV21) {
|
||||
result.u_buffer = result.y_buffer + y_buffer_size;
|
||||
result.v_buffer = result.u_buffer + uv_buffer_size;
|
||||
result.uv_pixel_stride = 1;
|
||||
result.uv_row_stride = (result.y_row_stride + 1) / 2;
|
||||
} else if (source.format() == FrameBuffer::Format::kYV12) {
|
||||
result.v_buffer = result.y_buffer + y_buffer_size;
|
||||
result.u_buffer = result.v_buffer + uv_buffer_size;
|
||||
result.uv_pixel_stride = 1;
|
||||
result.uv_row_stride = (result.y_row_stride + 1) / 2;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns supported 2-plane FrameBuffer in YuvData structure.
|
||||
absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromTwoPlaneFrameBuffer(
|
||||
const FrameBuffer& source) {
|
||||
if (source.format() != FrameBuffer::Format::kNV12 &&
|
||||
source.format() != FrameBuffer::Format::kNV21) {
|
||||
return absl::InvalidArgumentError("Unsupported YUV planar format.");
|
||||
}
|
||||
|
||||
FrameBuffer::YuvData result;
|
||||
// Y plane
|
||||
result.y_buffer = source.plane(0).buffer;
|
||||
// All plane strides
|
||||
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
|
||||
result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
|
||||
result.uv_pixel_stride = 2;
|
||||
|
||||
if (source.format() == FrameBuffer::Format::kNV12) {
|
||||
// Y and UV interleaved format
|
||||
result.u_buffer = source.plane(1).buffer;
|
||||
result.v_buffer = result.u_buffer + 1;
|
||||
} else {
|
||||
// Y and VU interleaved format
|
||||
result.v_buffer = source.plane(1).buffer;
|
||||
result.u_buffer = result.v_buffer + 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns supported 3-plane FrameBuffer in YuvData structure. Note that NV21
|
||||
// and NV12 are included in the supported Yuv formats. Technically, NV21 and
|
||||
// NV12 should not be described by the 3-plane format. Historically, NV21 is
|
||||
// used loosely such that it can also be used to describe YV21 format. For
|
||||
// backwards compatibility, FrameBuffer supports NV21/NV12 with 3-plane format
|
||||
// but such usage is discouraged
|
||||
absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromThreePlaneFrameBuffer(
|
||||
const FrameBuffer& source) {
|
||||
if (!IsSupportedYuvFormat(source.format())) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The source FrameBuffer format is not part of YUV420 family.");
|
||||
}
|
||||
|
||||
if (source.plane(1).stride.row_stride_bytes !=
|
||||
source.plane(2).stride.row_stride_bytes ||
|
||||
source.plane(1).stride.pixel_stride_bytes !=
|
||||
source.plane(2).stride.pixel_stride_bytes) {
|
||||
return absl::InternalError("Unsupported YUV planar format.");
|
||||
}
|
||||
FrameBuffer::YuvData result;
|
||||
if (source.format() == FrameBuffer::Format::kNV21 ||
|
||||
source.format() == FrameBuffer::Format::kYV12) {
|
||||
// Y follow by VU order. The VU chroma planes can be interleaved or
|
||||
// planar.
|
||||
result.y_buffer = source.plane(0).buffer;
|
||||
result.v_buffer = source.plane(1).buffer;
|
||||
result.u_buffer = source.plane(2).buffer;
|
||||
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
|
||||
result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
|
||||
result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
|
||||
} else {
|
||||
// Y follow by UV order. The UV chroma planes can be interleaved or
|
||||
// planar.
|
||||
result.y_buffer = source.plane(0).buffer;
|
||||
result.u_buffer = source.plane(1).buffer;
|
||||
result.v_buffer = source.plane(2).buffer;
|
||||
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
|
||||
result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
|
||||
result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::StatusOr<FrameBuffer::YuvData> FrameBuffer::GetYuvDataFromFrameBuffer(
|
||||
const FrameBuffer& source) {
|
||||
if (!IsSupportedYuvFormat(source.format())) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The source FrameBuffer format is not part of YUV420 family.");
|
||||
}
|
||||
|
||||
if (source.plane_count() == 1) {
|
||||
return GetYuvDataFromOnePlaneFrameBuffer(source);
|
||||
} else if (source.plane_count() == 2) {
|
||||
return GetYuvDataFromTwoPlaneFrameBuffer(source);
|
||||
} else if (source.plane_count() == 3) {
|
||||
return GetYuvDataFromThreePlaneFrameBuffer(source);
|
||||
}
|
||||
return absl::InvalidArgumentError(
|
||||
"The source FrameBuffer must be consisted by 1, 2, or 3 planes");
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
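In the one-plane path above, the chroma pointers are derived purely from the Y plane: the V/U data is assumed to start right after `y_row_stride * height` bytes, with a special case when the row stride equals an odd frame width. A small arithmetic sketch of that layout for NV21; the width, height, and stride values are placeholders, not values from the diff.

```cpp
#include <cstdint>

// Illustrative offset math for a one-plane NV21 buffer (Y plane followed by
// interleaved VU), mirroring GetYuvDataFromOnePlaneFrameBuffer above.
struct Nv21Offsets {
  int y_size;         // bytes occupied by the Y plane
  int v_offset;       // first V byte (NV21 stores V before U)
  int u_offset;       // first U byte, interleaved one byte after V
  int uv_row_stride;  // stride of the interleaved chroma rows
};

Nv21Offsets ComputeNv21Offsets(int width, int height, int y_row_stride) {
  Nv21Offsets o;
  o.y_size = y_row_stride * height;
  o.v_offset = o.y_size;      // chroma starts right after the Y plane
  o.u_offset = o.y_size + 1;  // U interleaved with V, pixel stride 2
  o.uv_row_stride = y_row_stride;
  // Matches the special case above: an odd, unpadded Y row stride is rounded
  // up to an even chroma row stride.
  if (y_row_stride == width && y_row_stride % 2 == 1) {
    o.uv_row_stride = (y_row_stride + 1) / 2 * 2;
  }
  return o;
}
```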
mediapipe/framework/formats/frame_buffer.h (new file) | 246

@@ -0,0 +1,246 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// A `FrameBuffer` provides a view into the provided backing buffer (e.g. camera
|
||||
// frame or still image) with buffer format information. FrameBuffer doesn't
|
||||
// take ownership of the provided backing buffer. The caller is responsible to
|
||||
// manage the backing buffer lifecycle for the lifetime of the FrameBuffer.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// // Create an metadata instance with no backing buffer.
|
||||
// auto buffer = FrameBuffer::Create(/*planes=*/{}, dimension, kRGBA,
|
||||
// KTopLeft);
|
||||
//
|
||||
// // Create an RGBA instance with backing buffer on single plane.
|
||||
// FrameBuffer::Plane plane =
|
||||
// {rgba_buffer, /*stride=*/{dimension.width * 4, 4}};
|
||||
// auto buffer = FrameBuffer::Create({plane}, dimension, kRGBA, kTopLeft);
|
||||
//
|
||||
// // Create an YUV instance with planar backing buffer.
|
||||
// FrameBuffer::Plane y_plane = {y_buffer, /*stride=*/{dimension.width , 1}};
|
||||
// FrameBuffer::Plane uv_plane = {u_buffer, /*stride=*/{dimension.width, 2}};
|
||||
// auto buffer = FrameBuffer::Create({y_plane, uv_plane}, dimension, kNV21,
|
||||
// kLeftTop);
|
||||
class FrameBuffer {
|
||||
public:
|
||||
// Colorspace formats.
|
||||
enum class Format {
|
||||
kRGBA,
|
||||
kRGB,
|
||||
kNV12,
|
||||
kNV21,
|
||||
kYV12,
|
||||
kYV21,
|
||||
kGRAY,
|
||||
kUNKNOWN
|
||||
};
|
||||
|
||||
// Stride information.
|
||||
struct Stride {
|
||||
// The row stride in bytes. This is the distance between the start pixels of
|
||||
// two consecutive rows in the image.
|
||||
int row_stride_bytes;
|
||||
// This is the distance between two consecutive pixel values in a row of
|
||||
// pixels in bytes. It may be larger than the size of a single pixel to
|
||||
// account for interleaved image data or padded formats.
|
||||
int pixel_stride_bytes;
|
||||
|
||||
bool operator==(const Stride& other) const {
|
||||
return row_stride_bytes == other.row_stride_bytes &&
|
||||
pixel_stride_bytes == other.pixel_stride_bytes;
|
||||
}
|
||||
|
||||
bool operator!=(const Stride& other) const { return !operator==(other); }
|
||||
};
|
||||
|
||||
// YUV data structure.
|
||||
struct YuvData {
|
||||
const uint8* y_buffer;
|
||||
const uint8* u_buffer;
|
||||
const uint8* v_buffer;
|
||||
// Y buffer row stride in bytes.
|
||||
int y_row_stride;
|
||||
// U/V buffer row stride in bytes.
|
||||
int uv_row_stride;
|
||||
// U/V pixel stride in bytes. This is the distance between two consecutive
|
||||
// u/v pixel values in a row.
|
||||
int uv_pixel_stride;
|
||||
};
|
||||
|
||||
// FrameBuffer content orientation follows EXIF specification. The name of
|
||||
// each enum value defines the position of the 0th row and the 0th column of
|
||||
// the image content. See http://jpegclub.org/exif_orientation.html for
|
||||
// details.
|
||||
enum class Orientation {
|
||||
kTopLeft = 1,
|
||||
kTopRight = 2,
|
||||
kBottomRight = 3,
|
||||
kBottomLeft = 4,
|
||||
kLeftTop = 5,
|
||||
kRightTop = 6,
|
||||
kRightBottom = 7,
|
||||
kLeftBottom = 8
|
||||
};
|
||||
|
||||
// Plane encapsulates buffer and stride information.
|
||||
struct Plane {
|
||||
const uint8* buffer;
|
||||
Stride stride;
|
||||
};
|
||||
|
||||
// Dimension information for the whole frame or a cropped portion of it.
|
||||
struct Dimension {
|
||||
// The width dimension in pixel unit.
|
||||
int width;
|
||||
// The height dimension in pixel unit.
|
||||
int height;
|
||||
|
||||
bool operator==(const Dimension& other) const {
|
||||
return width == other.width && height == other.height;
|
||||
}
|
||||
|
||||
bool operator!=(const Dimension& other) const {
|
||||
return width != other.width || height != other.height;
|
||||
}
|
||||
|
||||
bool operator>=(const Dimension& other) const {
|
||||
return width >= other.width && height >= other.height;
|
||||
}
|
||||
|
||||
bool operator<=(const Dimension& other) const {
|
||||
return width <= other.width && height <= other.height;
|
||||
}
|
||||
|
||||
// Swaps width and height.
|
||||
void Swap() {
|
||||
using std::swap;
|
||||
swap(width, height);
|
||||
}
|
||||
|
||||
// Returns area represented by width * height.
|
||||
int Size() const { return width * height; }
|
||||
};
|
||||
|
||||
// Factory method for creating a FrameBuffer object from row-major backing
|
||||
// buffers.
|
||||
static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
|
||||
Dimension dimension, Format format,
|
||||
Orientation orientation) {
|
||||
return absl::make_unique<FrameBuffer>(planes, dimension, format,
|
||||
orientation);
|
||||
}
|
||||
|
||||
// Factory method for creating a FrameBuffer object from row-major movable
|
||||
// backing buffers.
|
||||
static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
|
||||
Dimension dimension, Format format,
|
||||
Orientation orientation) {
|
||||
return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
|
||||
orientation);
|
||||
}
|
||||
|
||||
// Returns YuvData which contains the Y, U, and V buffer and their
|
||||
// stride info from the input `source` FrameBuffer which is in the YUV family
|
||||
// formats (e.g NV12, NV21, YV12, and YV21).
|
||||
static absl::StatusOr<YuvData> GetYuvDataFromFrameBuffer(
|
||||
const FrameBuffer& source);
|
||||
|
||||
// Builds a FrameBuffer object from a row-major backing buffer.
|
||||
//
|
||||
// The FrameBuffer does not take ownership of the backing buffer. The backing
|
||||
// buffer is read-only and the caller is responsible for maintaining the
|
||||
// backing buffer lifecycle for the lifetime of FrameBuffer.
|
||||
FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
|
||||
Format format, Orientation orientation)
|
||||
: planes_(planes),
|
||||
dimension_(dimension),
|
||||
format_(format),
|
||||
orientation_(orientation) {}
|
||||
|
||||
// Builds a FrameBuffer object from a movable row-major backing buffer.
|
||||
//
|
||||
// The FrameBuffer does not take ownership of the backing buffer. The backing
|
||||
// buffer is read-only and the caller is responsible for maintaining the
|
||||
// backing buffer lifecycle for the lifetime of FrameBuffer.
|
||||
FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
|
||||
Orientation orientation)
|
||||
: planes_(std::move(planes)),
|
||||
dimension_(dimension),
|
||||
format_(format),
|
||||
orientation_(orientation) {}
|
||||
|
||||
// Copy constructor.
|
||||
//
|
||||
// FrameBuffer does not take ownership of the backing buffer. The copy
|
||||
// constructor behaves the same way to only copy the buffer pointer and not
|
||||
// take ownership of the backing buffer.
|
||||
FrameBuffer(const FrameBuffer& frame_buffer) {
|
||||
planes_.clear();
|
||||
for (int i = 0; i < frame_buffer.plane_count(); i++) {
|
||||
planes_.push_back(
|
||||
FrameBuffer::Plane{.buffer = frame_buffer.plane(i).buffer,
|
||||
.stride = frame_buffer.plane(i).stride});
|
||||
}
|
||||
dimension_ = frame_buffer.dimension();
|
||||
format_ = frame_buffer.format();
|
||||
orientation_ = frame_buffer.orientation();
|
||||
}
|
||||
|
||||
// Returns number of planes.
|
||||
int plane_count() const { return planes_.size(); }
|
||||
|
||||
// Returns plane indexed by the input `index`.
|
||||
Plane plane(int index) const {
|
||||
if (index > -1 && static_cast<size_t>(index) < planes_.size()) {
|
||||
return planes_[index];
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
// Returns FrameBuffer dimension.
|
||||
Dimension dimension() const { return dimension_; }
|
||||
|
||||
// Returns FrameBuffer format.
|
||||
Format format() const { return format_; }
|
||||
|
||||
// Returns FrameBuffer orientation.
|
||||
Orientation orientation() const { return orientation_; }
|
||||
|
||||
private:
|
||||
std::vector<Plane> planes_;
|
||||
Dimension dimension_;
|
||||
Format format_;
|
||||
Orientation orientation_;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
|
|
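frame_buffer.h above defines the plane, stride, and format metadata, and frame_buffer.cc adds the `GetYuvDataFromFrameBuffer` accessor. A short sketch of how the two fit together for an NV21 image stored in one contiguous buffer; the data pointer, width, and height are placeholders owned by the caller, since FrameBuffer never takes ownership of the backing memory.

```cpp
#include <cstdint>

#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/frame_buffer.h"

// Sketch: wrap an existing NV21 buffer (Y plane followed by interleaved VU)
// in a FrameBuffer and pull out the per-plane pointers and strides.
absl::StatusOr<mediapipe::FrameBuffer::YuvData> DescribeNv21(
    const uint8_t* nv21_data, int width, int height) {
  using ::mediapipe::FrameBuffer;
  FrameBuffer::Plane plane = {nv21_data,
                              /*stride=*/{/*row_stride_bytes=*/width,
                                          /*pixel_stride_bytes=*/1}};
  auto buffer = FrameBuffer::Create({plane}, {width, height},
                                    FrameBuffer::Format::kNV21,
                                    FrameBuffer::Orientation::kTopLeft);
  return FrameBuffer::GetYuvDataFromFrameBuffer(*buffer);
}
```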
@@ -92,7 +92,7 @@ bool GraphRegistry::IsRegistered(const std::string& ns,
}

absl::StatusOr<CalculatorGraphConfig> GraphRegistry::CreateByName(
    const std::string& ns, const std::string& type_name,
    absl::string_view ns, absl::string_view type_name,
    SubgraphContext* context) const {
  absl::StatusOr<std::unique_ptr<Subgraph>> maker =
      local_factories_.IsRegistered(ns, type_name)
@@ -20,6 +20,7 @@
#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/registration.h"
@@ -187,7 +188,7 @@ class GraphRegistry {

  // Returns the specified graph config.
  absl::StatusOr<CalculatorGraphConfig> CreateByName(
      const std::string& ns, const std::string& type_name,
      absl::string_view ns, absl::string_view type_name,
      SubgraphContext* context = nullptr) const;

  static GraphRegistry global_graph_registry;
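With both parameters now `absl::string_view`, callers can pass string literals or views without building `std::string` temporaries. A minimal call-site sketch, assuming `GraphRegistry` is declared in `mediapipe/framework/subgraph.h` and that a subgraph named "MySubgraph" has been registered elsewhere; both of those details are illustrative assumptions.

```cpp
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/subgraph.h"

// Sketch: look up a registered subgraph's expanded config by name.
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetMySubgraphConfig() {
  constexpr absl::string_view kNamespace = "";           // default namespace
  constexpr absl::string_view kTypeName = "MySubgraph";  // hypothetical name
  return mediapipe::GraphRegistry::global_graph_registry.CreateByName(
      kNamespace, kTypeName);
}
```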
@@ -441,6 +441,21 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_buffer_storage_yuv_image",
|
||||
srcs = ["gpu_buffer_storage_yuv_image.cc"],
|
||||
hdrs = ["gpu_buffer_storage_yuv_image.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gpu_buffer_format",
|
||||
":gpu_buffer_storage",
|
||||
"//mediapipe/framework/formats:yuv_image",
|
||||
"//third_party/libyuv",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gpu_buffer_storage_ahwb",
|
||||
srcs = ["gpu_buffer_storage_ahwb.cc"],
|
||||
|
@@ -1187,3 +1202,17 @@ mediapipe_cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_cc_test(
|
||||
name = "gpu_buffer_storage_yuv_image_test",
|
||||
size = "small",
|
||||
srcs = ["gpu_buffer_storage_yuv_image_test.cc"],
|
||||
exclude_platforms = [
|
||||
"ios",
|
||||
],
|
||||
deps = [
|
||||
":gpu_buffer_format",
|
||||
":gpu_buffer_storage_yuv_image",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
@@ -212,6 +212,10 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
|
|||
case GpuBufferFormat::kTwoComponentHalf16:
|
||||
case GpuBufferFormat::kRGBAHalf64:
|
||||
case GpuBufferFormat::kRGBAFloat128:
|
||||
case GpuBufferFormat::kNV12:
|
||||
case GpuBufferFormat::kNV21:
|
||||
case GpuBufferFormat::kI420:
|
||||
case GpuBufferFormat::kYV12:
|
||||
case GpuBufferFormat::kUnknown:
|
||||
return ImageFormat::UNKNOWN;
|
||||
}
|
||||
|
|
|
@@ -52,6 +52,14 @@ enum class GpuBufferFormat : uint32_t {
  kRGB24 = 0x00000018,  // Note: prefer BGRA32 whenever possible.
  kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'),
  kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'),
  // 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling.
  kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'),
  // 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling.
  kNV21 = MEDIAPIPE_FOURCC('N', 'V', '2', '1'),
  // 8-bit Y plane + non-interleaved 8-bit U/V planes with 2x2 subsampling.
  kI420 = MEDIAPIPE_FOURCC('I', '4', '2', '0'),
  // 8-bit Y plane + non-interleaved 8-bit V/U planes with 2x2 subsampling.
  kYV12 = MEDIAPIPE_FOURCC('Y', 'V', '1', '2'),
};

#if !MEDIAPIPE_DISABLE_GPU
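The new planar YUV entries are four-character codes produced by `MEDIAPIPE_FOURCC`, so the enum value itself encodes a readable tag such as 'NV12'. A small helper sketch for logging such values; it assumes the macro packs the first character into the most significant byte, which should be checked against the macro's actual definition before relying on the order.

```cpp
#include <cstdint>
#include <string>

// Sketch: turn a FOURCC-style GpuBufferFormat value back into its four
// characters, e.g. 0x4E563132 -> "NV12" under the assumed byte order.
std::string FourccToString(uint32_t fourcc) {
  return std::string{static_cast<char>((fourcc >> 24) & 0xFF),
                     static_cast<char>((fourcc >> 16) & 0xFF),
                     static_cast<char>((fourcc >> 8) & 0xFF),
                     static_cast<char>(fourcc & 0xFF)};
}
```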
@@ -111,6 +119,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
|
|||
return kCVPixelFormatType_64RGBAHalf;
|
||||
case GpuBufferFormat::kRGBAFloat128:
|
||||
return kCVPixelFormatType_128RGBAFloat;
|
||||
case GpuBufferFormat::kNV12:
|
||||
case GpuBufferFormat::kNV21:
|
||||
case GpuBufferFormat::kI420:
|
||||
case GpuBufferFormat::kYV12:
|
||||
case GpuBufferFormat::kUnknown:
|
||||
return -1;
|
||||
}
|
||||
|
|
|
@@ -158,7 +158,7 @@ public class GraphTextureFrame implements TextureFrame {

  @Override
  protected void finalize() throws Throwable {
    if (refCount >= 0 || nativeBufferHandle != 0) {
    if (refCount > 0 || nativeBufferHandle != 0) {
      logger.atWarning().log("release was not called before finalize");
    }
    if (!activeConsumerContextHandleSet.isEmpty()) {
@@ -199,6 +199,28 @@ public final class PacketGetter {
|
|||
return nativeGetImageData(packet.getNativeHandle(), buffer);
|
||||
}
|
||||
|
||||
/** Returns the size of Image list. This helps to determine size of allocated ByteBuffer array. */
|
||||
public static int getImageListSize(final Packet packet) {
|
||||
return nativeGetImageListSize(packet.getNativeHandle());
|
||||
}
|
||||
|
||||
/**
|
||||
* Assign the native image buffer array in given ByteBuffer array. It assumes given ByteBuffer
|
||||
* array has the the same size of image list packet, and assumes the output buffer stores pixels
|
||||
* contiguously. It returns false if this assumption does not hold.
|
||||
*
|
||||
* <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of
|
||||
* ByteBuffer to copy image data to. If false, the ByteBuffer will wrap the memory address of
|
||||
* MediaPipe ImageFrame of graph output, and the ByteBuffer data is available only when MediaPipe
|
||||
* graph is alive.
|
||||
*
|
||||
* <p>Note: this function does not assume the pixel format.
|
||||
*/
|
||||
public static boolean getImageList(
|
||||
final Packet packet, ByteBuffer[] buffersArray, boolean deepCopy) {
|
||||
return nativeGetImageList(packet.getNativeHandle(), buffersArray, deepCopy);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an RGB mediapipe image frame packet to an RGBA Byte buffer.
|
||||
*
|
||||
|
@@ -316,7 +338,8 @@ public final class PacketGetter {
|
|||
public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) {
|
||||
return new GraphTextureFrame(
|
||||
nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false),
|
||||
packet.getTimestamp(), /* deferredSync= */true);
|
||||
packet.getTimestamp(),
|
||||
/* deferredSync= */ true);
|
||||
}
|
||||
|
||||
private static native long nativeGetPacketFromReference(long nativePacketHandle);
|
||||
|
@@ -363,6 +386,11 @@ public final class PacketGetter {
|
|||
|
||||
private static native boolean nativeGetImageData(long nativePacketHandle, ByteBuffer buffer);
|
||||
|
||||
private static native int nativeGetImageListSize(long nativePacketHandle);
|
||||
|
||||
private static native boolean nativeGetImageList(
|
||||
long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy);
|
||||
|
||||
private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer);
|
||||
// Retrieves the values that are in the VideoHeader.
|
||||
private static native int nativeGetVideoHeaderWidth(long nativepackethandle);
|
||||
|
|
|
@@ -50,7 +50,10 @@ public class ByteBufferExtractor {
|
|||
switch (container.getImageProperties().getStorageType()) {
|
||||
case MPImage.STORAGE_TYPE_BYTEBUFFER:
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||
return byteBufferImageContainer
|
||||
.getByteBuffer()
|
||||
.asReadOnlyBuffer()
|
||||
.order(ByteOrder.nativeOrder());
|
||||
default:
|
||||
throw new IllegalArgumentException(
|
||||
"Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
|
||||
|
@@ -74,7 +77,7 @@ public class ByteBufferExtractor {
|
|||
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
|
||||
* conversions.
|
||||
*/
|
||||
static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
|
||||
public static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
|
||||
MPImageContainer container;
|
||||
MPImageProperties byteBufferProperties =
|
||||
MPImageProperties.builder()
|
||||
|
@@ -83,12 +86,16 @@ public class ByteBufferExtractor {
|
|||
.build();
|
||||
if ((container = image.getContainer(byteBufferProperties)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
|
||||
return byteBufferImageContainer
|
||||
.getByteBuffer()
|
||||
.asReadOnlyBuffer()
|
||||
.order(ByteOrder.nativeOrder());
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
|
||||
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
|
||||
@MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
|
||||
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
|
||||
.asReadOnlyBuffer();
|
||||
.asReadOnlyBuffer()
|
||||
.order(ByteOrder.nativeOrder());
|
||||
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
|
||||
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
|
||||
ByteBuffer byteBuffer =
|
||||
|
|
|
@@ -67,6 +67,8 @@ public class MPImage implements Closeable {
|
|||
IMAGE_FORMAT_YUV_420_888,
|
||||
IMAGE_FORMAT_ALPHA,
|
||||
IMAGE_FORMAT_JPEG,
|
||||
IMAGE_FORMAT_VEC32F1,
|
||||
IMAGE_FORMAT_VEC32F2,
|
||||
})
|
||||
@Retention(RetentionPolicy.SOURCE)
|
||||
public @interface MPImageFormat {}
|
||||
|
@@ -81,6 +83,8 @@ public class MPImage implements Closeable {
|
|||
public static final int IMAGE_FORMAT_YUV_420_888 = 7;
|
||||
public static final int IMAGE_FORMAT_ALPHA = 8;
|
||||
public static final int IMAGE_FORMAT_JPEG = 9;
|
||||
public static final int IMAGE_FORMAT_VEC32F1 = 10;
|
||||
public static final int IMAGE_FORMAT_VEC32F2 = 11;
|
||||
|
||||
/** Specifies the image container type. Would be useful for choosing extractors. */
|
||||
@IntDef({
|
||||
|
|
|
@@ -14,6 +14,7 @@
|
|||
|
||||
#include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
|
@@ -39,6 +40,52 @@ template <typename T>
|
|||
const T& GetFromNativeHandle(int64_t packet_handle) {
|
||||
return mediapipe::android::Graph::GetPacketFromHandle(packet_handle).Get<T>();
|
||||
}
|
||||
|
||||
bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image,
|
||||
jobject byte_buffer) {
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
if (buffer_data == nullptr || buffer_size < 0) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access"));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() * image.NumberOfChannels();
|
||||
if (buffer_size != expected_buffer_size) {
|
||||
ThrowIfError(
|
||||
env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected buffer size ", expected_buffer_size,
|
||||
" got: ", buffer_size, ", width ", image.Width(), ", height ",
|
||||
image.Height(), ", channels ", image.NumberOfChannels())));
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (image.ByteDepth()) {
|
||||
case 1: {
|
||||
uint8* data = static_cast<uint8*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
uint16* data = static_cast<uint16*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
float* data = static_cast<float*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetPacketFromReference)(
|
||||
|
@@ -298,46 +345,51 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
|
|||
.GetImageFrameSharedPtr()
|
||||
.get()
|
||||
: GetFromNativeHandle<mediapipe::ImageFrame>(packet);
|
||||
return CopyImageDataToByteBuffer(env, image, byte_buffer);
|
||||
}
|
||||
|
||||
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
|
||||
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
|
||||
if (buffer_data == nullptr || buffer_size < 0) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(
|
||||
"input buffer does not support direct access"));
|
||||
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)(
|
||||
JNIEnv* env, jobject thiz, jlong packet) {
|
||||
const auto& image_list =
|
||||
GetFromNativeHandle<std::vector<mediapipe::Image>>(packet);
|
||||
return image_list.size();
|
||||
}
|
||||
|
||||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
|
||||
JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array,
|
||||
jboolean deep_copy) {
|
||||
const auto& image_list =
|
||||
GetFromNativeHandle<std::vector<mediapipe::Image>>(packet);
|
||||
if (env->GetArrayLength(byte_buffer_array) != image_list.size()) {
|
||||
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected ByteBuffer array size: ", image_list.size(),
|
||||
" but get ByteBuffer array size: ",
|
||||
env->GetArrayLength(byte_buffer_array))));
|
||||
return false;
|
||||
}
|
||||
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() * image.NumberOfChannels();
|
||||
if (buffer_size != expected_buffer_size) {
|
||||
ThrowIfError(
|
||||
env, absl::InvalidArgumentError(absl::StrCat(
|
||||
"Expected buffer size ", expected_buffer_size,
|
||||
" got: ", buffer_size, ", width ", image.Width(), ", height ",
|
||||
image.Height(), ", channels ", image.NumberOfChannels())));
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (image.ByteDepth()) {
|
||||
case 1: {
|
||||
uint8* data = static_cast<uint8*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
uint16* data = static_cast<uint16*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
case 4: {
|
||||
float* data = static_cast<float*>(buffer_data);
|
||||
image.CopyToBuffer(data, expected_buffer_size);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
for (int i = 0; i < image_list.size(); ++i) {
|
||||
auto& image = *image_list[i].GetImageFrameSharedPtr().get();
|
||||
if (!image.IsContiguous()) {
|
||||
ThrowIfError(
|
||||
env, absl::InternalError("ImageFrame must store data contiguously to "
|
||||
"be allocated as ByteBuffer."));
|
||||
return false;
|
||||
}
|
||||
if (deep_copy) {
|
||||
jobject byte_buffer = reinterpret_cast<jobject>(
|
||||
env->GetObjectArrayElement(byte_buffer_array, i));
|
||||
if (!CopyImageDataToByteBuffer(env, image, byte_buffer)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Assume byte buffer stores pixel data contiguously.
|
||||
const int expected_buffer_size = image.Width() * image.Height() *
|
||||
image.ByteDepth() *
|
||||
image.NumberOfChannels();
|
||||
jobject image_data_byte_buffer = env->NewDirectByteBuffer(
|
||||
image.MutablePixelData(), expected_buffer_size);
|
||||
env->SetObjectArrayElement(byte_buffer_array, i, image_data_byte_buffer);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@@ -415,7 +467,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)(
|
|||
int16 value =
|
||||
static_cast<int16>(audio_mat(channel, sample) * kMultiplier);
|
||||
// The java and native has the same byte order, by default is little
|
||||
// Endian, we can safely copy data directly, we have tests to cover this.
|
||||
// Endian, we can safely copy data directly, we have tests to cover
|
||||
// this.
|
||||
env->SetByteArrayRegion(byte_data, offset, 2,
|
||||
reinterpret_cast<const jbyte*>(&value));
|
||||
offset += 2;
|
||||
|
|
|
@@ -106,6 +106,17 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageHeight)(JNIEnv* env,
|
|||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
|
||||
JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
|
||||
|
||||
// Return the vector size of std::vector<Image>.
|
||||
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)(
|
||||
JNIEnv* env, jobject thiz, jlong packet);
|
||||
|
||||
// Fill ByteBuffer[] from the Packet of std::vector<Image>.
|
||||
// Before calling this, the byte_buffer_array needs to have the correct
|
||||
// allocated size.
|
||||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
|
||||
JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array,
|
||||
jboolean deep_copy);
|
||||
|
||||
// Before calling this, the byte_buffer needs to have the correct allocated
|
||||
// size.
|
||||
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(
|
||||
|
|
|
@@ -61,6 +61,7 @@ py_test(
|
|||
name = "file_util_test",
|
||||
srcs = ["file_util_test.py"],
|
||||
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [":file_util"],
|
||||
)
|
||||
|
||||
|
|
|
@@ -13,11 +13,93 @@
|
|||
# limitations under the License.
|
||||
"""Utilities for files."""
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import requests
|
||||
|
||||
# resources dependency
|
||||
|
||||
|
||||
_TEMPDIR_FOLDER = 'model_maker'
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DownloadedFiles:
|
||||
"""File(s) that are downloaded from a url into a local directory.
|
||||
|
||||
If `is_folder` is True:
|
||||
1. `path` should be a folder
|
||||
2. `url` should point to a .tar.gz file which contains a single folder at
|
||||
the root level.
|
||||
|
||||
Attributes:
|
||||
path: Relative path in local directory.
|
||||
url: GCS url to download the file(s).
|
||||
is_folder: Whether the path and url represents a folder.
|
||||
"""
|
||||
|
||||
path: str
|
||||
url: str
|
||||
is_folder: bool = False
|
||||
|
||||
def get_path(self) -> str:
|
||||
"""Gets the path of files saved in a local directory.
|
||||
|
||||
If the path doesn't exist, this method will download the file(s) from the
|
||||
provided url. The path is not cleaned up so it can be reused for subsequent
|
||||
calls to the same path.
|
||||
Folders are expected to be zipped in a .tar.gz file which will be extracted
|
||||
into self.path in the local directory.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the extracted folder does not have a singular root
|
||||
directory.
|
||||
|
||||
Returns:
|
||||
The absolute path to the downloaded file(s)
|
||||
"""
|
||||
tmpdir = tempfile.gettempdir()
|
||||
absolute_path = pathlib.Path(
|
||||
os.path.join(tmpdir, _TEMPDIR_FOLDER, self.path)
|
||||
)
|
||||
if not absolute_path.exists():
|
||||
print(f'Downloading {self.url} to {absolute_path}')
|
||||
r = requests.get(self.url, allow_redirects=True)
|
||||
if self.is_folder:
|
||||
# Use tempf to store the downloaded .tar.gz file
|
||||
tempf = tempfile.NamedTemporaryFile(suffix='.tar.gz', mode='wb')
|
||||
tempf.write(r.content)
|
||||
tarf = tarfile.open(tempf.name)
|
||||
# Use tmpdir to store the extracted contents of the .tar.gz file
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tarf.extractall(tmpdir)
|
||||
tarf.close()
|
||||
tempf.close()
|
||||
subdirs = os.listdir(tmpdir)
|
||||
# Make sure tmpdir only has one subdirectory
|
||||
if len(subdirs) > 1 or not os.path.isdir(
|
||||
os.path.join(tmpdir, subdirs[0])
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Extracted folder from {self.url} doesn't contain a "
|
||||
f'single root directory: {subdirs}'
|
||||
)
|
||||
# Create the parent dir of absolute_path and copy the contents of the
|
||||
# top level dir in the .tar.gz file into absolute_path.
|
||||
pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
|
||||
shutil.copytree(os.path.join(tmpdir, subdirs[0]), absolute_path)
|
||||
else:
|
||||
pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
|
||||
with open(absolute_path, 'wb') as f:
|
||||
f.write(r.content)
|
||||
return str(absolute_path)
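# Illustrative usage (editor's sketch, not part of this change): DownloadedFiles
# caches assets under the system temp directory. The path/url below mirror the
# assets used in file_util_test.py later in this diff.
#
#   downloaded_files = DownloadedFiles(
#       path='gesture_recognizer/hand_landmark_full.tflite',
#       url='https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite',
#       is_folder=False,
#   )
#   model_path = downloaded_files.get_path()  # downloads on first call, reused afterwards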
|
||||
|
||||
|
||||
# TODO Remove after text_classifier supports downloading on demand.
|
||||
def get_absolute_path(file_path: str) -> str:
|
||||
"""Gets the absolute path of a file in the model_maker directory.
|
||||
|
||||
|
|
|
@ -12,13 +12,68 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
from absl.testing import absltest
|
||||
import requests
|
||||
|
||||
from mediapipe.model_maker.python.core.utils import file_util
|
||||
|
||||
|
||||
class FileUtilTest(absltest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
mock_gettempdir = unittest_mock.patch.object(
|
||||
tempfile,
|
||||
'gettempdir',
|
||||
return_value=self.create_tempdir(),
|
||||
autospec=True,
|
||||
)
|
||||
self.mock_gettempdir = mock_gettempdir.start()
|
||||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_get_path(self):
|
||||
path = 'gesture_recognizer/hand_landmark_full.tflite'
|
||||
url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite'
|
||||
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False)
|
||||
model_path = downloaded_files.get_path()
|
||||
self.assertTrue(os.path.exists(model_path))
|
||||
self.assertGreater(os.path.getsize(model_path), 0)
|
||||
|
||||
def test_get_path_folder(self):
|
||||
folder_contents = [
|
||||
'keras_metadata.pb',
|
||||
'saved_model.pb',
|
||||
'assets/vocab.txt',
|
||||
'variables/variables.data-00000-of-00001',
|
||||
'variables/variables.index',
|
||||
]
|
||||
path = 'text_classifier/mobilebert_tiny'
|
||||
url = (
|
||||
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz'
|
||||
)
|
||||
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=True)
|
||||
model_path = downloaded_files.get_path()
|
||||
self.assertTrue(os.path.exists(model_path))
|
||||
for file_name in folder_contents:
|
||||
file_path = os.path.join(model_path, file_name)
|
||||
self.assertTrue(os.path.exists(file_path))
|
||||
self.assertGreater(os.path.getsize(file_path), 0)
|
||||
|
||||
@unittest_mock.patch.object(requests, 'get', wraps=requests.get)
|
||||
def test_get_path_multiple_calls(self, mock_get):
|
||||
path = 'gesture_recognizer/hand_landmark_full.tflite'
|
||||
url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite'
|
||||
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False)
|
||||
model_path = downloaded_files.get_path()
|
||||
self.assertTrue(os.path.exists(model_path))
|
||||
self.assertGreater(os.path.getsize(model_path), 0)
|
||||
model_path_2 = downloaded_files.get_path()
|
||||
self.assertEqual(model_path, model_path_2)
|
||||
self.assertEqual(mock_get.call_count, 1)
|
||||
|
||||
def test_get_absolute_path(self):
|
||||
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
|
||||
absolute_path = file_util.get_absolute_path(test_file)
|
||||
|
|
|
@ -48,7 +48,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
|
@ -90,7 +89,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"@com_google_absl//absl/status",
|
||||
|
|
|
@ -40,7 +40,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
|
||||
|
@ -68,7 +67,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
|||
using ::tflite::ProcessUnit;
|
||||
using ::tflite::TensorMetadata;
|
||||
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
||||
using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
||||
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
|
||||
|
||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||
|
||||
|
@ -455,12 +454,13 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
}
|
||||
|
||||
// If output tensors are quantized, they must be dequantized first.
|
||||
TensorsSource dequantized_tensors(&tensors_in);
|
||||
TensorsSource dequantized_tensors = tensors_in;
|
||||
if (options.has_quantized_outputs()) {
|
||||
GenericNode* tensors_dequantization_node =
|
||||
&graph.AddNode("TensorsDequantizationCalculator");
|
||||
tensors_in >> tensors_dequantization_node->In(kTensorsTag);
|
||||
dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
|
||||
dequantized_tensors = tensors_dequantization_node->Out(kTensorsTag)
|
||||
.Cast<std::vector<Tensor>>();
|
||||
}
|
||||
|
||||
// If there are multiple classification heads, the output tensors need to be
|
||||
|
@ -477,7 +477,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
auto* range = split_tensor_vector_options.add_ranges();
|
||||
range->set_begin(i);
|
||||
range->set_end(i + 1);
|
||||
split_tensors.emplace_back(split_tensor_vector_node, i);
|
||||
split_tensors.push_back(
|
||||
split_tensor_vector_node->Out(i).Cast<std::vector<Tensor>>());
|
||||
}
|
||||
dequantized_tensors >> split_tensor_vector_node->In(0);
|
||||
} else {
|
||||
|
@ -494,8 +495,9 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
|||
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
|
||||
.CopyFrom(options.score_calibration_options().at(i));
|
||||
split_tensors[i] >> score_calibration_node->In(kScoresTag);
|
||||
calibrated_tensors.emplace_back(score_calibration_node,
|
||||
kCalibratedScoresTag);
|
||||
calibrated_tensors.push_back(
|
||||
score_calibration_node->Out(kCalibratedScoresTag)
|
||||
.Cast<std::vector<Tensor>>());
|
||||
} else {
|
||||
calibrated_tensors.emplace_back(split_tensors[i]);
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
@ -51,8 +50,6 @@ using ::mediapipe::api2::builder::Graph;
|
|||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using TensorsSource =
|
||||
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
|
||||
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
|
||||
|
@ -229,12 +226,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
|
|||
Source<std::vector<Tensor>> tensors_in,
|
||||
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
||||
// If output tensors are quantized, they must be dequantized first.
|
||||
TensorsSource dequantized_tensors(&tensors_in);
|
||||
Source<std::vector<Tensor>> dequantized_tensors = tensors_in;
|
||||
if (options.has_quantized_outputs()) {
|
||||
GenericNode& tensors_dequantization_node =
|
||||
graph.AddNode("TensorsDequantizationCalculator");
|
||||
tensors_in >> tensors_dequantization_node.In(kTensorsTag);
|
||||
dequantized_tensors = {&tensors_dequantization_node, kTensorsTag};
|
||||
dequantized_tensors = tensors_dequantization_node.Out(kTensorsTag)
|
||||
.Cast<std::vector<Tensor>>();
|
||||
}
|
||||
|
||||
// Adds TensorsToEmbeddingsCalculator.
|
||||
|
|
|
@ -14,12 +14,6 @@
|
|||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
cc_library(
|
||||
name = "source_or_node_output",
|
||||
hdrs = ["source_or_node_output.h"],
|
||||
deps = ["//mediapipe/framework/api2:builder"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cosine_similarity",
|
||||
srcs = ["cosine_similarity.cc"],
|
||||
|
|
|
@ -1,66 +0,0 @@
|
|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
|
||||
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
|
||||
// Helper class representing either a Source object or a GenericNode output.
|
||||
//
|
||||
// Source and MultiSource (the output of a GenericNode) are widely incompatible,
|
||||
// but being able to represent either of these in temporary variables and
|
||||
// connect them later on facilitates graph building.
|
||||
template <typename T>
|
||||
class SourceOrNodeOutput {
|
||||
public:
|
||||
SourceOrNodeOutput() = delete;
|
||||
// The caller is responsible for ensuring 'source' outlives this object.
|
||||
explicit SourceOrNodeOutput(mediapipe::api2::builder::Source<T>* source)
|
||||
: source_(source) {}
|
||||
// The caller is responsible for ensuring 'node' outlives this object.
|
||||
SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node,
|
||||
std::string tag)
|
||||
: node_(node), tag_(tag) {}
|
||||
// The caller is responsible for ensuring 'node' outlives this object.
|
||||
SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index)
|
||||
: node_(node), index_(index) {}
|
||||
|
||||
// Connects the source or node output to the provided destination.
|
||||
template <typename U>
|
||||
void operator>>(const U& dest) {
|
||||
if (source_ != nullptr) {
|
||||
*source_ >> dest;
|
||||
} else {
|
||||
if (index_ < 0) {
|
||||
node_->Out(tag_) >> dest;
|
||||
} else {
|
||||
node_->Out(index_) >> dest;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
mediapipe::api2::builder::Source<T>* source_ = nullptr;
|
||||
mediapipe::api2::builder::GenericNode* node_ = nullptr;
|
||||
std::string tag_ = "";
|
||||
int index_ = -1;
|
||||
};
|
||||
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
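// Editor's sketch (not part of this change): the typed-Source pattern that
// replaces SourceOrNodeOutput in the postprocessing graphs above. A calculator
// output is cast to Source<std::vector<Tensor>> so one variable can hold
// either the graph input or the node output.
//
//   Source<std::vector<Tensor>> dequantized_tensors = tensors_in;
//   if (options.has_quantized_outputs()) {
//     auto& node = graph.AddNode("TensorsDequantizationCalculator");
//     tensors_in >> node.In("TENSORS");
//     dequantized_tensors = node.Out("TENSORS").Cast<std::vector<Tensor>>();
//   }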
|
|
@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
|||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kMobileBert);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result0,
|
||||
text_embedder->Embed("When you go to this restaurant, they hold the "
|
||||
"pancake upside-down before they hand it "
|
||||
"to you. It's a great gimmick."));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result1,
|
||||
text_embedder->Embed(
|
||||
"Let's make a plan to steal the declaration of independence."));
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
// TODO: The similarity should likely be lower
|
||||
EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
||||
|
|
|
@ -140,6 +140,7 @@ cc_library(
|
|||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
|
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
|
@ -68,6 +69,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
|
|||
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
|
||||
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
|
||||
constexpr char kRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
|
||||
constexpr char kPalmRectsTag[] = "PALM_RECTS";
|
||||
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
|
||||
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
|
||||
constexpr char kHandGestureRecognizerBundleAssetName[] =
|
||||
"hand_gesture_recognizer.task";
|
||||
|
@ -77,6 +81,9 @@ struct GestureRecognizerOutputs {
|
|||
Source<std::vector<ClassificationList>> handedness;
|
||||
Source<std::vector<NormalizedLandmarkList>> hand_landmarks;
|
||||
Source<std::vector<LandmarkList>> hand_world_landmarks;
|
||||
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
|
||||
Source<std::vector<NormalizedRect>> palm_rects;
|
||||
Source<std::vector<Detection>> palm_detections;
|
||||
Source<Image> image;
|
||||
};
|
||||
|
||||
|
@ -135,9 +142,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform hand gesture recognition on.
|
||||
// NORM_RECT - NormalizedRect
|
||||
// NORM_RECT - NormalizedRect @Optional
|
||||
// Describes image rotation and region of image to perform landmarks
|
||||
// detection on.
|
||||
// detection on. If not provided, whole image is used for gesture
|
||||
// recognition.
|
||||
//
|
||||
// Outputs:
|
||||
// HAND_GESTURES - std::vector<ClassificationList>
|
||||
|
@ -208,11 +216,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||
.IsAvailable()));
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
|
||||
BuildGestureRecognizerGraph(
|
||||
*sc->MutableOptions<GestureRecognizerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kNormRectTag)], graph));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_gesture_recognition_output,
|
||||
BuildGestureRecognizerGraph(
|
||||
*sc->MutableOptions<GestureRecognizerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
hand_gesture_recognition_output.gesture >>
|
||||
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
|
||||
hand_gesture_recognition_output.handedness >>
|
||||
|
@ -222,6 +231,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
hand_gesture_recognition_output.hand_world_landmarks >>
|
||||
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
|
||||
hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)];
|
||||
hand_gesture_recognition_output.hand_rects_next_frame >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kRectNextFrameTag)];
|
||||
hand_gesture_recognition_output.palm_rects >>
|
||||
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
|
||||
hand_gesture_recognition_output.palm_detections >>
|
||||
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
||||
return graph.GetConfig();
|
||||
}
|
||||
|
||||
|
@ -279,7 +294,17 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
|
|||
/*handedness=*/handedness,
|
||||
/*hand_landmarks=*/hand_landmarks,
|
||||
/*hand_world_landmarks=*/hand_world_landmarks,
|
||||
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
|
||||
/*hand_rects_next_frame =*/
|
||||
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
|
||||
kRectNextFrameTag)],
|
||||
/*palm_rects =*/
|
||||
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
|
||||
kPalmRectsTag)],
|
||||
/*palm_detections =*/
|
||||
hand_landmarker_graph[Output<std::vector<Detection>>(
|
||||
kPalmDetectionsTag)],
|
||||
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -150,9 +150,9 @@ void ConfigureRectTransformationCalculator(
|
|||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform detection on.
|
||||
// NORM_RECT - NormalizedRect
|
||||
// Describes image rotation and region of image to perform detection
|
||||
// on.
|
||||
// NORM_RECT - NormalizedRect @Optional
|
||||
// Describes image rotation and region of image to perform detection on. If
|
||||
// not provided, whole image is used for hand detection.
|
||||
//
|
||||
// Outputs:
|
||||
// PALM_DETECTIONS - std::vector<Detection>
|
||||
|
@ -197,11 +197,12 @@ class HandDetectorGraph : public core::ModelTaskGraph {
|
|||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<HandDetectorGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(auto hand_detection_outs,
|
||||
BuildHandDetectionSubgraph(
|
||||
sc->Options<HandDetectorGraphOptions>(),
|
||||
*model_resources, graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kNormRectTag)], graph));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_detection_outs,
|
||||
BuildHandDetectionSubgraph(
|
||||
sc->Options<HandDetectorGraphOptions>(), *model_resources,
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
hand_detection_outs.palm_detections >>
|
||||
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
|
||||
hand_detection_outs.hand_rects >>
|
||||
|
|
|
@ -136,9 +136,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
|
|||
// Inputs:
|
||||
// IMAGE - Image
|
||||
// Image to perform hand landmarks detection on.
|
||||
// NORM_RECT - NormalizedRect
|
||||
// NORM_RECT - NormalizedRect @Optional
|
||||
// Describes image rotation and region of image to perform landmarks
|
||||
// detection on.
|
||||
// detection on. If not provided, whole image is used for hand landmarks
|
||||
// detection.
|
||||
//
|
||||
// Outputs:
|
||||
// LANDMARKS: - std::vector<NormalizedLandmarkList>
|
||||
|
@ -218,11 +219,12 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
|||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||
.IsAvailable()));
|
||||
}
|
||||
ASSIGN_OR_RETURN(auto hand_landmarker_outputs,
|
||||
BuildHandLandmarkerGraph(
|
||||
sc->Options<HandLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kNormRectTag)], graph));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_landmarker_outputs,
|
||||
BuildHandLandmarkerGraph(
|
||||
sc->Options<HandLandmarkerGraphOptions>(),
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
hand_landmarker_outputs.landmark_lists >>
|
||||
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
|
||||
hand_landmarker_outputs.world_landmark_lists >>
|
||||
|
|
|
@ -243,11 +243,12 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
|
|||
const auto* model_resources,
|
||||
CreateModelResources<HandLandmarksDetectorGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(auto hand_landmark_detection_outs,
|
||||
BuildSingleHandLandmarksDetectorGraph(
|
||||
sc->Options<HandLandmarksDetectorGraphOptions>(),
|
||||
*model_resources, graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>(kHandRectTag)], graph));
|
||||
ASSIGN_OR_RETURN(
|
||||
auto hand_landmark_detection_outs,
|
||||
BuildSingleHandLandmarksDetectorGraph(
|
||||
sc->Options<HandLandmarksDetectorGraphOptions>(), *model_resources,
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kHandRectTag)], graph));
|
||||
hand_landmark_detection_outs.hand_landmarks >>
|
||||
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
|
||||
hand_landmark_detection_outs.world_hand_landmarks >>
|
||||
|
|
|
@ -257,10 +257,12 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
|
|||
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
|
||||
}
|
||||
|
||||
TEST_F(ImageModeTest, SucceedsWithRotation) {
|
||||
// TODO: fix this unit test after image segmenter handles post
|
||||
// processing correctly with rotated images.
|
||||
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
Image image, DecodeImageFromFile(
|
||||
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
|
||||
Image image,
|
||||
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
|
||||
auto options = std::make_unique<ImageSegmenterOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
|
||||
|
@ -271,7 +273,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
|
|||
ImageSegmenter::Create(std::move(options)));
|
||||
ImageProcessingOptions image_processing_options;
|
||||
image_processing_options.rotation_degrees = -90;
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
|
||||
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
|
||||
segmenter->Segment(image, image_processing_options));
|
||||
EXPECT_EQ(confidence_masks.size(), 21);
|
||||
|
||||
cv::Mat expected_mask =
|
||||
|
|
|
@ -74,7 +74,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
|
|
|
@ -34,7 +34,6 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
|
||||
|
@ -69,7 +68,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
|
|||
using ObjectDetectorOptionsProto =
|
||||
object_detector::proto::ObjectDetectorOptions;
|
||||
using TensorsSource =
|
||||
mediapipe::tasks::SourceOrNodeOutput<std::vector<mediapipe::Tensor>>;
|
||||
mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;
|
||||
|
||||
constexpr int kDefaultLocationsIndex = 0;
|
||||
constexpr int kDefaultCategoriesIndex = 1;
|
||||
|
@ -584,7 +583,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
|
|||
auto post_processing_specs,
|
||||
BuildPostProcessingSpecs(task_options, metadata_extractor));
|
||||
// Calculators to perform score calibration, if specified in the metadata.
|
||||
TensorsSource calibrated_tensors = {&inference, kTensorTag};
|
||||
TensorsSource calibrated_tensors =
|
||||
inference.Out(kTensorTag).Cast<std::vector<Tensor>>();
|
||||
if (post_processing_specs.score_calibration_options.has_value()) {
|
||||
// Split tensors.
|
||||
auto* split_tensor_vector_node =
|
||||
|
@ -623,7 +623,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
|
|||
concatenate_tensor_vector_node->In(i);
|
||||
}
|
||||
}
|
||||
calibrated_tensors = {concatenate_tensor_vector_node, 0};
|
||||
calibrated_tensors =
|
||||
concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
|
||||
}
|
||||
// Calculator to convert output tensors to a detection proto vector.
|
||||
// Connects TensorsToDetectionsCalculator's input stream to the output
|
||||
|
|
|
@ -26,16 +26,16 @@ NS_SWIFT_NAME(Embedding)
|
|||
@interface MPPEmbedding : NSObject
|
||||
|
||||
/**
|
||||
* @brief The Floating-point embedding.
|
||||
* @brief The embedding represented as an `NSArray` of `Float` values.
|
||||
* Empty if the embedder was configured to perform scalar quantization.
|
||||
*/
|
||||
@property(nonatomic, readonly, nullable) NSArray<NSNumber *> *floatEmbedding;
|
||||
|
||||
/**
|
||||
* @brief The Quantized embedding.
|
||||
* @brief The embedding represented as an `NSArray` of `UInt8` values.
|
||||
* Empty if the embedder was not configured to perform scalar quantization.
|
||||
*/
|
||||
@property(nonatomic, readonly, nullable) NSData *quantizedEmbedding;
|
||||
@property(nonatomic, readonly, nullable) NSArray<NSNumber *> *quantizedEmbedding;
|
||||
|
||||
/** The index of the embedder head these entries refer to. This is useful for multi-head models. */
|
||||
@property(nonatomic, readonly) NSInteger headIndex;
|
||||
|
@ -56,7 +56,7 @@ NS_SWIFT_NAME(Embedding)
|
|||
* embedding, head index and head name.
|
||||
*/
|
||||
- (instancetype)initWithFloatEmbedding:(nullable NSArray<NSNumber *> *)floatEmbedding
|
||||
quantizedEmbedding:(nullable NSData *)quantizedEmbedding
|
||||
quantizedEmbedding:(nullable NSArray<NSNumber *> *)quantizedEmbedding
|
||||
headIndex:(NSInteger)headIndex
|
||||
headName:(nullable NSString *)headName NS_DESIGNATED_INITIALIZER;
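// Editor's sketch (not part of this change): constructing an MPPEmbedding with
// the updated initializer, where the quantized values are now boxed NSNumbers
// rather than raw NSData. The values below are placeholders.
//
//   MPPEmbedding *embedding =
//       [[MPPEmbedding alloc] initWithFloatEmbedding:nil
//                                 quantizedEmbedding:@[ @(127), @(-128), @(5) ]
//                                          headIndex:0
//                                           headName:@"default"];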
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
@implementation MPPEmbedding
|
||||
|
||||
- (instancetype)initWithFloatEmbedding:(nullable NSArray<NSNumber *> *)floatEmbedding
|
||||
quantizedEmbedding:(nullable NSData *)quantizedEmbedding
|
||||
quantizedEmbedding:(nullable NSArray<NSNumber *> *)quantizedEmbedding
|
||||
headIndex:(NSInteger)headIndex
|
||||
headName:(nullable NSString *)headName {
|
||||
self = [super init];
|
||||
|
|
|
@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
/**
|
||||
* Options for setting up a `MPPTextEmbedder`.
|
||||
*/
|
||||
NS_SWIFT_NAME(TextEmbedderptions)
|
||||
NS_SWIFT_NAME(TextEmbedderOptions)
|
||||
@interface MPPTextEmbedderOptions : MPPTaskOptions <NSCopying>
|
||||
|
||||
/**
|
||||
|
|
|
@ -44,6 +44,7 @@ cc_binary(
|
|||
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
|
||||
],
|
||||
|
@ -176,6 +177,30 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "imagesegmenter",
|
||||
srcs = [
|
||||
"imagesegmenter/ImageSegmenter.java",
|
||||
"imagesegmenter/ImageSegmenterResult.java",
|
||||
],
|
||||
javacopts = [
|
||||
"-Xep:AndroidJdkLibsChecker:OFF",
|
||||
],
|
||||
manifest = "imagesegmenter/AndroidManifest.xml",
|
||||
deps = [
|
||||
":core",
|
||||
"//mediapipe/framework:calculator_options_java_proto_lite",
|
||||
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
|
||||
"//mediapipe/java/com/google/mediapipe/framework/image",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "imageembedder",
|
||||
srcs = [
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.google.mediapipe.tasks.vision.imagesegmenter">
|
||||
|
||||
<uses-sdk android:minSdkVersion="24"
|
||||
android:targetSdkVersion="30" />
|
||||
|
||||
</manifest>
|
|
@ -0,0 +1,462 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import android.content.Context;
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
|
||||
import com.google.mediapipe.framework.AndroidPacketGetter;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.framework.PacketGetter;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.core.ErrorListener;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler;
|
||||
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
|
||||
import com.google.mediapipe.tasks.core.TaskInfo;
|
||||
import com.google.mediapipe.tasks.core.TaskOptions;
|
||||
import com.google.mediapipe.tasks.core.TaskRunner;
|
||||
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
|
||||
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Performs image segmentation on images.
|
||||
*
|
||||
* <p>Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a
|
||||
* user-defined callback function even for the synchronous API. This makes it possible for
|
||||
* ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in
|
||||
* the {@link ImageSegmenterOptions} for all {@link RunningMode}.
|
||||
*
|
||||
 * <p>The API expects a TFLite model with <a
|
||||
 * href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
|
||||
*
|
||||
* <ul>
|
||||
* <li>Input image {@link MPImage}
|
||||
* <ul>
|
||||
* <li>The image that image segmenter runs on.
|
||||
* </ul>
|
||||
* <li>Output ImageSegmenterResult {@link ImageSgmenterResult}
|
||||
* <ul>
|
||||
* <li>An ImageSegmenterResult containing segmented masks.
|
||||
* </ul>
|
||||
* </ul>
|
||||
*/
|
||||
public final class ImageSegmenter extends BaseVisionTaskApi {
|
||||
private static final String TAG = ImageSegmenter.class.getSimpleName();
|
||||
private static final String IMAGE_IN_STREAM_NAME = "image_in";
|
||||
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
|
||||
private static final List<String> INPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList(
|
||||
"GROUPED_SEGMENTATION:segmented_mask_out",
|
||||
"IMAGE:image_out",
|
||||
"SEGMENTATION:0:segmentation"));
|
||||
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
|
||||
*
|
||||
* @param context an Android {@link Context}.
|
||||
* @param segmenterOptions an {@link ImageSegmenterOptions} instance.
|
||||
* @throws MediaPipeException if there is an error during {@link ImageSegmenter} creation.
|
||||
*/
|
||||
public static ImageSegmenter createFromOptions(
|
||||
Context context, ImageSegmenterOptions segmenterOptions) {
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
new OutputHandler.OutputPacketConverter<ImageSegmenterResult, MPImage>() {
|
||||
@Override
|
||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||
throws MediaPipeException {
|
||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return ImageSegmenterResult.create(
|
||||
new ArrayList<>(),
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
||||
}
|
||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
||||
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int imageFormat =
|
||||
segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK
|
||||
? MPImage.IMAGE_FORMAT_VEC32F1
|
||||
: MPImage.IMAGE_FORMAT_ALPHA;
|
||||
int imageListSize =
|
||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting segmented masks. It usually results from incorrect"
|
||||
+ " options of unsupported OutputType of given model.");
|
||||
}
|
||||
for (ByteBuffer buffer : buffersArray) {
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
|
||||
segmentedMasks.add(builder.build());
|
||||
}
|
||||
|
||||
return ImageSegmenterResult.create(
|
||||
segmentedMasks,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
segmenterOptions.runningMode(),
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public MPImage convertToTaskInput(List<Packet> packets) {
|
||||
return new BitmapImageBuilder(
|
||||
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
|
||||
.build();
|
||||
}
|
||||
});
|
||||
handler.setResultListener(segmenterOptions.resultListener());
|
||||
segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
|
||||
TaskRunner runner =
|
||||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageSegmenterOptions>builder()
|
||||
.setTaskName(ImageSegmenter.class.getSimpleName())
|
||||
.setTaskRunningModeName(segmenterOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setTaskOptions(segmenterOptions)
|
||||
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
|
||||
.build(),
|
||||
handler);
|
||||
return new ImageSegmenter(runner, segmenterOptions.runningMode());
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructor to initialize an {@link ImageSegmenter} from a {@link TaskRunner} and a {@link
|
||||
* RunningMode}.
|
||||
*
|
||||
* @param taskRunner a {@link TaskRunner}.
|
||||
* @param runningMode a mediapipe vision task {@link RunningMode}.
|
||||
*/
|
||||
private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) {
|
||||
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image with default image processing options,
|
||||
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
* {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java
|
||||
* doc for input image format.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segment(MPImage image) {
|
||||
segment(image, ImageProcessingOptions.builder().build());
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
|
||||
* when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO
|
||||
* update java doc for input image format.
|
||||
*
|
||||
 * <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
ImageSegmenterResult unused =
|
||||
(ImageSegmenterResult) processImageData(image, imageProcessingOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame with default image processing options,
|
||||
* i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
 * {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentForVideo(MPImage image, long timestampMs) {
|
||||
segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame, and the results will be available via
|
||||
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
|
||||
* when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
|
||||
*
|
||||
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
* must be monotonically increasing.
|
||||
*
|
||||
 * <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentForVideo(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
ImageSegmenterResult unused =
|
||||
(ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
 * Sends live image data to perform image segmentation with default image processing
|
||||
* options, i.e. without any rotation applied, and the results will be available via the {@link
|
||||
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
|
||||
 * {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the image segmenter. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentAsync(MPImage image, long timestampMs) {
|
||||
segmentAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends live image data to perform image segmentation, and the results will be available via the
|
||||
* {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when
|
||||
* the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
|
||||
*
|
||||
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
|
||||
* sent to the image segmenter. The input timestamps must be monotonically increasing.
|
||||
*
|
||||
* <p>{@link ImageSegmenter} supports the following color space types:
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link Bitmap.Config.ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
* input image before running inference. Note that region-of-interest is <b>not</b> supported
|
||||
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
|
||||
* this method throwing an IllegalArgumentException.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
* region-of-interest.
|
||||
* @throws MediaPipeException if there is an internal error.
|
||||
*/
|
||||
public void segmentAsync(
|
||||
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
|
||||
validateImageProcessingOptions(imageProcessingOptions);
|
||||
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
|
||||
}
|
||||
|
||||
/** Options for setting up an {@link ImageSegmenter}. */
|
||||
@AutoValue
|
||||
public abstract static class ImageSegmenterOptions extends TaskOptions {
|
||||
|
||||
/** Builder for {@link ImageSegmenterOptions}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder {
|
||||
/** Sets the base options for the image segmenter task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/**
|
||||
       * Sets the running mode for the image segmenter task. Defaults to the image mode. Image
|
||||
* segmenter has three modes:
|
||||
*
|
||||
* <ul>
|
||||
       *   <li>IMAGE: The mode for running segmentation on single image inputs.
|
||||
       *   <li>VIDEO: The mode for running segmentation on the decoded frames of a video.
|
||||
       *   <li>LIVE_STREAM: The mode for running segmentation on a live stream of input data, such
|
||||
* as from camera. In this mode, {@code setResultListener} must be called to set up a
|
||||
       *       listener to receive the segmentation results asynchronously.
|
||||
* </ul>
|
||||
*/
|
||||
public abstract Builder setRunningMode(RunningMode value);
|
||||
|
||||
/**
|
||||
* The locale to use for display names specified through the TFLite Model Metadata, if any.
|
||||
* Defaults to English.
|
||||
*/
|
||||
public abstract Builder setDisplayNamesLocale(String value);
|
||||
|
||||
      /** Sets the output type of the image segmenter. */
|
||||
public abstract Builder setOutputType(OutputType value);
|
||||
|
||||
/**
|
||||
* Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline
|
||||
* is done processing an image.
|
||||
*/
|
||||
public abstract Builder setResultListener(
|
||||
ResultListener<ImageSegmenterResult, MPImage> value);
|
||||
|
||||
      /** Sets an optional {@link ErrorListener}. */
|
||||
public abstract Builder setErrorListener(ErrorListener value);
|
||||
|
||||
abstract ImageSegmenterOptions autoBuild();
|
||||
|
||||
/**
|
||||
* Validates and builds the {@link ImageSegmenterOptions} instance.
|
||||
*
|
||||
* @throws IllegalArgumentException if the result listener and the running mode are not
|
||||
* properly configured. The result listener should only be set when the image segmenter is
|
||||
* in the live stream mode.
|
||||
*/
|
||||
public final ImageSegmenterOptions build() {
|
||||
ImageSegmenterOptions options = autoBuild();
|
||||
return options;
|
||||
}
|
||||
}
|
||||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract RunningMode runningMode();
|
||||
|
||||
abstract String displayNamesLocale();
|
||||
|
||||
abstract OutputType outputType();
|
||||
|
||||
abstract ResultListener<ImageSegmenterResult, MPImage> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
/** The output type of segmentation results. */
|
||||
public enum OutputType {
|
||||
// Gives a single output mask where each pixel represents the class which
|
||||
// the pixel in the original image was predicted to belong to.
|
||||
CATEGORY_MASK,
|
||||
// Gives a list of output masks where, for each mask, each pixel represents
|
||||
// the prediction confidence, usually in the [0, 1] range.
|
||||
CONFIDENCE_MASK
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setDisplayNamesLocale("en")
|
||||
.setOutputType(OutputType.CATEGORY_MASK)
|
||||
.setResultListener((result, image) -> {});
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts an {@link ImageSegmenterOptions} to a {@link CalculatorOptions} protobuf message.
|
||||
*/
|
||||
@Override
|
||||
public CalculatorOptions convertToCalculatorOptionsProto() {
|
||||
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder =
|
||||
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder()
|
||||
.setBaseOptions(
|
||||
BaseOptionsProto.BaseOptions.newBuilder()
|
||||
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
|
||||
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
|
||||
.build())
|
||||
.setDisplayNamesLocale(displayNamesLocale());
|
||||
|
||||
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
||||
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
||||
if (outputType() == OutputType.CONFIDENCE_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
|
||||
} else if (outputType() == OutputType.CATEGORY_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
      // TODO: remove this once activation is handled at the metadata and graph level.
|
||||
segmenterOptionsBuilder.setActivation(
|
||||
SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
|
||||
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext,
|
||||
taskOptionsBuilder.build())
|
||||
.build();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
|
||||
* region-of-interest.
|
||||
*/
|
||||
private static void validateImageProcessingOptions(
|
||||
ImageProcessingOptions imageProcessingOptions) {
|
||||
if (imageProcessingOptions.regionOfInterest().isPresent()) {
|
||||
throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest.");
|
||||
}
|
||||
}
|
||||
}
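// Editor's sketch (not part of this change): typical wiring of the new Java
// ImageSegmenter API defined above. The model file name is a placeholder;
// BaseOptions and BitmapImageBuilder are the MediaPipe Tasks helpers already
// imported by this file.
//
//   ImageSegmenter.ImageSegmenterOptions options =
//       ImageSegmenter.ImageSegmenterOptions.builder()
//           .setBaseOptions(
//               BaseOptions.builder().setModelAssetPath("deeplabv3.tflite").build())
//           .setOutputType(
//               ImageSegmenter.ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
//           .setResultListener(
//               (result, inputImage) -> {
//                 // result.segmentations() holds one MPImage per confidence mask.
//               })
//           .build();
//   ImageSegmenter segmenter = ImageSegmenter.createFromOptions(context, options);
//   segmenter.segment(new BitmapImageBuilder(bitmap).build());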
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.TaskResult;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/** Represents the segmentation results generated by {@link ImageSegmenter}. */
|
||||
@AutoValue
|
||||
public abstract class ImageSegmenterResult implements TaskResult {
|
||||
|
||||
/**
|
||||
* Creates an {@link ImageSegmenterResult} instance from a list of segmentation MPImage.
|
||||
*
|
||||
* @param segmentations a {@link List} of MPImage representing the segmented masks. If OutputType
|
||||
* is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is
|
||||
* CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_ALPHA format.
|
||||
* @param timestampMs a timestamp for this result.
|
||||
*/
|
||||
// TODO: consolidate output formats across platforms.
|
||||
static ImageSegmenterResult create(List<MPImage> segmentations, long timestampMs) {
|
||||
return new AutoValue_ImageSegmenterResult(
|
||||
Collections.unmodifiableList(segmentations), timestampMs);
|
||||
}
|
||||
|
||||
public abstract List<MPImage> segmentations();
|
||||
|
||||
@Override
|
||||
public abstract long timestampMs();
|
||||
}
|
|
@ -95,4 +95,24 @@ public class TextEmbedderTest {
|
|||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void classify_succeedsWithBertAndDifferentThemes() throws Exception {
|
||||
TextEmbedder textEmbedder =
|
||||
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
|
||||
|
||||
TextEmbedderResult result0 =
|
||||
textEmbedder.embed(
|
||||
"When you go to this restaurant, they hold the pancake upside-down before they hand "
|
||||
+ "it to you. It's a great gimmick.");
|
||||
TextEmbedderResult result1 =
|
||||
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
|
||||
|
||||
// Check cosine similarity.
|
||||
double similarity =
|
||||
TextEmbedder.cosineSimilarity(
|
||||
result0.embeddingResult().embeddings().get(0),
|
||||
result1.embeddingResult().embeddings().get(0));
|
||||
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
    package="com.google.mediapipe.tasks.vision.imagesegmentertest"
    android:versionCode="1"
    android:versionName="1.0" >

  <uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
  <uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>

  <uses-sdk android:minSdkVersion="24"
      android:targetSdkVersion="30" />

  <application
      android:label="imagesegmentertest"
      android:name="android.support.multidex.MultiDexApplication"
      android:taskAffinity="">
    <uses-library android:name="android.test.runner" />
  </application>

  <instrumentation
      android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
      android:targetPackage="com.google.mediapipe.tasks.vision.imagesegmentertest" />

</manifest>

@@ -0,0 +1,19 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_visibility = ["//mediapipe/tasks:internal"])

licenses(["notice"])

# TODO: Enable this in OSS

@ -0,0 +1,427 @@
|
|||
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.vision.imagesegmenter;
|
||||
|
||||
import static com.google.common.truth.Truth.assertThat;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import android.content.res.AssetManager;
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.graphics.Color;
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.image.BitmapExtractor;
|
||||
import com.google.mediapipe.framework.image.BitmapImageBuilder;
|
||||
import com.google.mediapipe.framework.image.ByteBufferExtractor;
|
||||
import com.google.mediapipe.framework.image.MPImage;
|
||||
import com.google.mediapipe.tasks.core.BaseOptions;
|
||||
import com.google.mediapipe.tasks.vision.core.RunningMode;
|
||||
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
|
||||
import java.io.InputStream;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.junit.runners.Suite;
|
||||
import org.junit.runners.Suite.SuiteClasses;
|
||||
|
||||
/** Test for {@link ImageSegmenter}. */
|
||||
@RunWith(Suite.class)
|
||||
@SuiteClasses({ImageSegmenterTest.General.class, ImageSegmenterTest.RunningModeTest.class})
|
||||
public class ImageSegmenterTest {
|
||||
private static final String DEEPLAB_MODEL_FILE = "deeplabv3.tflite";
|
||||
private static final String SELFIE_128x128_MODEL_FILE = "selfie_segm_128_128_3.tflite";
|
||||
private static final String SELFIE_144x256_MODEL_FILE = "selfie_segm_144_256_3.tflite";
|
||||
private static final String CAT_IMAGE = "cat.jpg";
|
||||
private static final float GOLDEN_MASK_SIMILARITY = 0.96f;
|
||||
private static final int MAGNIFICATION_FACTOR = 10;
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class General extends ImageSegmenterTest {
|
||||
|
||||
@Test
|
||||
public void segment_successWithCategoryMask() throws Exception {
|
||||
final String inputImageName = "segmentation_input_rotation0.jpg";
|
||||
final String goldenImageName = "segmentation_golden_rotation0.png";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(1);
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
verifyCategoryMask(
|
||||
actualMaskBuffer,
|
||||
expectedMaskBuffer,
|
||||
GOLDEN_MASK_SIMILARITY,
|
||||
MAGNIFICATION_FACTOR);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithConfidenceMask() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWith128x128Segmentation() throws Exception {
|
||||
final String inputImageName = "mozart_square.jpg";
|
||||
final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(
|
||||
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(2);
|
||||
// Selfie category index 1.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
// TODO: enable this unit test once activation option is supported in metadata.
|
||||
// @Test
|
||||
// public void segment_successWith144x256Segmentation() throws Exception {
|
||||
// final String inputImageName = "mozart_square.jpg";
|
||||
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
|
||||
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
// ImageSegmenterOptions options =
|
||||
// ImageSegmenterOptions.builder()
|
||||
// .setBaseOptions(
|
||||
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
|
||||
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
// .setActivation(ImageSegmenterOptions.Activation.NONE)
|
||||
// .setResultListener(
|
||||
// (actualResult, inputImage) -> {
|
||||
// List<MPImage> segmentations = actualResult.segmentations();
|
||||
// assertThat(segmentations.size()).isEqualTo(1);
|
||||
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
|
||||
// verifyConfidenceMask(
|
||||
// actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
// })
|
||||
// .build();
|
||||
// ImageSegmenter imageSegmenter =
|
||||
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(),
|
||||
// options);
|
||||
// imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
// }
|
||||
}
|
||||
|
||||
@RunWith(AndroidJUnit4.class)
|
||||
public static final class RunningModeTest extends ImageSegmenterTest {
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInImageMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentForVideo(
|
||||
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInVideoMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener((result, inputImage) -> {})
|
||||
.build();
|
||||
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
|
||||
exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() ->
|
||||
imageSegmenter.segmentForVideo(
|
||||
getImageFromAsset(CAT_IMAGE), /* timestampsMs= */ 0));
|
||||
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithImageMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.IMAGE)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithVideoMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.VIDEO)
|
||||
.setResultListener(
|
||||
(actualResult, inputImage) -> {
|
||||
List<MPImage> segmentations = actualResult.segmentations();
|
||||
assertThat(segmentations.size()).isEqualTo(21);
|
||||
// Cat category index 8.
|
||||
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
|
||||
verifyConfidenceMask(
|
||||
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
|
||||
for (int i = 0; i < 3; i++) {
|
||||
imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampsMs= */ i);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_successWithLiveStreamMode() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
MPImage expectedResult = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(segmenterResult, inputImage) -> {
|
||||
verifyConfidenceMask(
|
||||
segmenterResult.segmentations().get(8),
|
||||
expectedResult,
|
||||
GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
try (ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
imageSegmenter.segmentAsync(image, /* timestampsMs= */ i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void segment_failsWithOutOfOrderInputTimestamps() throws Exception {
|
||||
final String inputImageName = "cat.jpg";
|
||||
final String goldenImageName = "cat_mask.jpg";
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
MPImage expectedResult = getImageFromAsset(goldenImageName);
|
||||
ImageSegmenterOptions options =
|
||||
ImageSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setRunningMode(RunningMode.LIVE_STREAM)
|
||||
.setResultListener(
|
||||
(segmenterResult, inputImage) -> {
|
||||
verifyConfidenceMask(
|
||||
segmenterResult.segmentations().get(8),
|
||||
expectedResult,
|
||||
GOLDEN_MASK_SIMILARITY);
|
||||
})
|
||||
.build();
|
||||
try (ImageSegmenter imageSegmenter =
|
||||
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
|
||||
imageSegmenter.segmentAsync(image, /* timestampsMs= */ 1);
|
||||
MediaPipeException exception =
|
||||
assertThrows(
|
||||
MediaPipeException.class,
|
||||
() -> imageSegmenter.segmentAsync(image, /* timestampsMs= */ 0));
|
||||
assertThat(exception)
|
||||
.hasMessageThat()
|
||||
.contains("having a smaller timestamp than the processed timestamp");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void verifyCategoryMask(
|
||||
MPImage actualMask, MPImage goldenMask, float similarityThreshold, int magnificationFactor) {
|
||||
assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth());
|
||||
assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight());
|
||||
ByteBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask);
|
||||
Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask);
|
||||
int consistentPixels = 0;
|
||||
final int numPixels = actualMask.getWidth() * actualMask.getHeight();
|
||||
actualMaskBuffer.rewind();
|
||||
for (int y = 0; y < actualMask.getHeight(); y++) {
|
||||
for (int x = 0; x < actualMask.getWidth(); x++) {
|
||||
// RGB values are the same in the golden mask image.
|
||||
consistentPixels +=
|
||||
actualMaskBuffer.get() * magnificationFactor
|
||||
== Color.red(goldenMaskBitmap.getPixel(x, y))
|
||||
? 1
|
||||
: 0;
|
||||
}
|
||||
}
|
||||
assertThat((float) consistentPixels / numPixels).isGreaterThan(similarityThreshold);
|
||||
}
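Concretely, `verifyCategoryMask` asserts a per-pixel agreement ratio against the golden image; the predicted category values are multiplied by the magnification factor (10) so they line up with the golden mask's red channel:

$$ \text{consistency} = \frac{1}{N}\sum_{x,y} \mathbf{1}\big[\,k \cdot m(x,y) = g_R(x,y)\,\big] $$

where $m$ is the predicted mask, $g_R$ the red channel of the golden mask, $k = 10$, and $N$ the pixel count; the ratio must exceed `GOLDEN_MASK_SIMILARITY` (0.96).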
|
||||
|
||||
private static void verifyConfidenceMask(
|
||||
MPImage actualMask, MPImage goldenMask, float similarityThreshold) {
|
||||
assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth());
|
||||
assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight());
|
||||
FloatBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask).asFloatBuffer();
|
||||
Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask);
|
||||
FloatBuffer goldenMaskBuffer = getByteBufferFromBitmap(goldenMaskBitmap).asFloatBuffer();
|
||||
assertThat(
|
||||
calculateSoftIOU(
|
||||
actualMaskBuffer, goldenMaskBuffer, actualMask.getWidth() * actualMask.getHeight()))
|
||||
.isGreaterThan((double) similarityThreshold);
|
||||
}
|
||||
|
||||
private static MPImage getImageFromAsset(String filePath) throws Exception {
|
||||
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
|
||||
InputStream istr = assetManager.open(filePath);
|
||||
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
|
||||
}
|
||||
|
||||
private static ByteBuffer getByteBufferFromBitmap(Bitmap bitmap) {
|
||||
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(bitmap.getWidth() * bitmap.getHeight() * 4);
|
||||
for (int y = 0; y < bitmap.getHeight(); y++) {
|
||||
for (int x = 0; x < bitmap.getWidth(); x++) {
|
||||
byteBuffer.putFloat((float) Color.red(bitmap.getPixel(x, y)) / 255.f);
|
||||
}
|
||||
}
|
||||
byteBuffer.rewind();
|
||||
return byteBuffer;
|
||||
}
|
||||
|
||||
private static double calculateSum(FloatBuffer m) {
|
||||
m.rewind();
|
||||
double sum = 0;
|
||||
while (m.hasRemaining()) {
|
||||
sum += m.get();
|
||||
}
|
||||
m.rewind();
|
||||
return sum;
|
||||
}
|
||||
|
||||
private static FloatBuffer multiply(FloatBuffer m1, FloatBuffer m2, int bufferSize) {
|
||||
m1.rewind();
|
||||
m2.rewind();
|
||||
FloatBuffer buffer = FloatBuffer.allocate(bufferSize);
|
||||
while (m1.hasRemaining()) {
|
||||
buffer.put(m1.get() * m2.get());
|
||||
}
|
||||
m1.rewind();
|
||||
m2.rewind();
|
||||
buffer.rewind();
|
||||
return buffer;
|
||||
}
|
||||
|
||||
private static double calculateSoftIOU(FloatBuffer m1, FloatBuffer m2, int bufferSize) {
|
||||
double intersectionSum = calculateSum(multiply(m1, m2, bufferSize));
|
||||
double m1m1 = calculateSum(multiply(m1, m1.duplicate(), bufferSize));
|
||||
double m2m2 = calculateSum(multiply(m2, m2.duplicate(), bufferSize));
|
||||
double unionSum = m1m1 + m2m2 - intersectionSum;
|
||||
return unionSum > 0.0 ? intersectionSum / unionSum : 0.0;
|
||||
}
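`calculateSoftIOU` computes a soft intersection-over-union of the two confidence maps, with the union formed from the squared sums:

$$ \text{softIoU}(m_1, m_2) = \frac{\sum_i m_{1,i}\, m_{2,i}}{\sum_i m_{1,i}^2 + \sum_i m_{2,i}^2 - \sum_i m_{1,i}\, m_{2,i}} $$

defined as 0 when the denominator is not positive; `verifyConfidenceMask` requires it to exceed the same 0.96 threshold.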
|
||||
}
|
|
@@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase):
    self._check_embedding_value(result1, expected_result1_value)
    self._check_cosine_similarity(result0, result1, expected_similarity)

  def test_embed_with_mobile_bert_and_different_themes(self):
    # Creates embedder.
    model_path = test_utils.get_test_data_path(
        os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)
    )
    base_options = _BaseOptions(model_asset_path=model_path)
    options = _TextEmbedderOptions(base_options=base_options)
    embedder = _TextEmbedder.create_from_options(options)

    # Extracts both embeddings.
    text0 = (
        'When you go to this restaurant, they hold the pancake upside-down '
        "before they hand it to you. It's a great gimmick."
    )
    result0 = embedder.embed(text0)

    text1 = "Let's make a plan to steal the declaration of independence."
    result1 = embedder.embed(text1)

    similarity = _TextEmbedder.cosine_similarity(
        result0.embeddings[0], result1.embeddings[0]
    )

    # TODO: The similarity should likely be lower
    self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE)

    # Closes the embedder explicitly when the embedder is not used in
    # a context.
    embedder.close()


if __name__ == '__main__':
  absltest.main()

|
|
@@ -33,6 +33,7 @@ export declare interface Embedding {
   * perform scalar quantization.
   */
  quantizedEmbedding?: Uint8Array;

  /**
   * The index of the classifier head these categories refer to. This is
   * useful for multi-head models.

@@ -70,12 +70,12 @@ describe('computeCosineSimilarity', () => {

  it('succeeds with quantized embeddings', () => {
    const u: Embedding = {
-     quantizedEmbedding: new Uint8Array([255, 128, 128, 128]),
+     quantizedEmbedding: new Uint8Array([127, 0, 0, 0]),
      headIndex: 0,
      headName: ''
    };
    const v: Embedding = {
-     quantizedEmbedding: new Uint8Array([0, 128, 128, 128]),
+     quantizedEmbedding: new Uint8Array([128, 0, 0, 0]),
      headIndex: 0,
      headName: ''
    };

@@ -38,7 +38,7 @@ export function computeCosineSimilarity(u: Embedding, v: Embedding): number {
}

function convertToBytes(data: Uint8Array): number[] {
-  return Array.from(data, v => v - 128);
+  return Array.from(data, v => v > 127 ? v - 256 : v);
}

function compute(u: number[], v: number[]) {

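The change above stops recentering quantized values by 128 and instead reinterprets each stored byte as a signed int8 (two's complement), which is the range scalar-quantized embeddings actually represent. A minimal sketch of the difference, using illustrative values (not the ones from the tests):

```typescript
// Quantized embeddings are stored as unsigned bytes (0..255) but stand for
// signed int8 values (-128..127).
const raw = new Uint8Array([127, 128, 255, 0]);

// Old behavior: recenter by 128, which does not preserve the bit pattern.
const shifted = Array.from(raw, v => v - 128);                       // [-1, 0, 127, -128]

// New behavior: reinterpret each byte as int8 (two's complement).
const reinterpreted = Array.from(raw, v => v > 127 ? v - 256 : v);   // [127, -128, -1, 0]

// Equivalent built-in reinterpretation of the same bytes.
const viaInt8Array = Array.from(new Int8Array(raw.buffer));          // [127, -128, -1, 0]
```

This also explains the updated test vectors above: under the new conversion their first components map to 127 and −128 (zeros elsewhere), i.e. two vectors pointing in exactly opposite directions, whose cosine similarity is −1.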
@@ -23,6 +23,7 @@ VISION_LIBS = [
    "//mediapipe/tasks/web/vision/hand_landmarker",
    "//mediapipe/tasks/web/vision/image_classifier",
    "//mediapipe/tasks/web/vision/image_embedder",
+   "//mediapipe/tasks/web/vision/image_segmenter",
    "//mediapipe/tasks/web/vision/object_detector",
]

@@ -39,6 +39,23 @@ const classifications = imageClassifier.classify(image);

For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.

## Image Segmentation

The MediaPipe Image Segmenter lets you segment an image into categories.

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const imageSegmenter = await ImageSegmenter.createFromModelPath(vision,
    "model.tflite"
);
const image = document.getElementById("image") as HTMLImageElement;
imageSegmenter.segment(image, (masks, width, height) => {
  ...
});
```
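In the default `CATEGORY_MASK` mode the callback receives a single `Uint8Array` whose values are the predicted class index for each pixel; the buffers are only valid while the callback runs, so copy them if you need them afterwards. A small sketch extending the example above (row-major indexing assumed):

```
imageSegmenter.segment(image, (masks, width, height) => {
  // Copy the category mask before the callback returns; the buffer is reused.
  const categoryMask = Uint8Array.from(masks[0] as Uint8Array);
  // Class index predicted for pixel (x, y), assuming row-major layout.
  const classAt = (x: number, y: number) => categoryMask[y * width + x];
});
```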

## Gesture Recognition

The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
|
||||
|
|
mediapipe/tasks/web/vision/image_segmenter/BUILD (new file, 58 lines)
|
@ -0,0 +1,58 @@
|
|||
# This contains the MediaPipe Image Segmenter Task.
|
||||
|
||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
|
||||
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "image_segmenter",
|
||||
srcs = ["image_segmenter.ts"],
|
||||
deps = [
|
||||
":image_segmenter_types",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_declaration(
|
||||
name = "image_segmenter_types",
|
||||
srcs = ["image_segmenter_options.d.ts"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_options",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "image_segmenter_test_lib",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"image_segmenter_test.ts",
|
||||
],
|
||||
deps = [
|
||||
":image_segmenter",
|
||||
":image_segmenter_types",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:task_runner_test_utils",
|
||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
||||
],
|
||||
)
|
||||
|
||||
jasmine_node_test(
|
||||
name = "image_segmenter_test",
|
||||
tags = ["nomsan"],
|
||||
deps = [":image_segmenter_test_lib"],
|
||||
)
|
mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts (new file, 300 lines)
|
@ -0,0 +1,300 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
|
||||
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
|
||||
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
|
||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||
|
||||
export * from './image_segmenter_options';
|
||||
export {ImageSource}; // Used in the public API
|
||||
|
||||
/**
|
||||
* The ImageSegmenter returns the segmentation result as a Uint8Array (when
|
||||
* the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
|
||||
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
|
||||
* for future usage.
|
||||
*/
|
||||
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
|
||||
|
||||
/**
|
||||
* A callback that receives the computed masks from the image segmenter. The
|
||||
* callback either receives a single element array with a category mask (as a
|
||||
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
|
||||
* The returned data is only valid for the duration of the callback. If
|
||||
* asynchronous processing is needed, all data needs to be copied before the
|
||||
* callback returns.
|
||||
*/
|
||||
export type SegmentationMaskCallback =
|
||||
(masks: SegmentationMask[], width: number, height: number) => void;
|
||||
|
||||
const IMAGE_STREAM = 'image_in';
|
||||
const NORM_RECT_STREAM = 'norm_rect';
|
||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/** Performs image segmentation on images. */
|
||||
export class ImageSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new image segmenter from the
|
||||
* provided options.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param imageSegmenterOptions The options for the Image Segmenter. Note
|
||||
* that either a path to the model asset or a model buffer needs to be
|
||||
* provided (via `baseOptions`).
|
||||
*/
|
||||
static createFromOptions(
|
||||
wasmFileset: WasmFileset,
|
||||
imageSegmenterOptions: ImageSegmenterOptions): Promise<ImageSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
imageSegmenterOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new image segmenter based on
|
||||
* the provided model asset buffer.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param modelAssetBuffer A binary representation of the model.
|
||||
*/
|
||||
static createFromModelBuffer(
|
||||
wasmFileset: WasmFileset,
|
||||
modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
{baseOptions: {modelAssetBuffer}});
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new image segmenter based on
|
||||
* the path to the model asset.
|
||||
* @param wasmFileset A configuration object that provides the location of
|
||||
* the Wasm binary and its loader.
|
||||
* @param modelAssetPath The path to the model asset.
|
||||
*/
|
||||
static createFromModelPath(
|
||||
wasmFileset: WasmFileset,
|
||||
modelAssetPath: string): Promise<ImageSegmenter> {
|
||||
return VisionTaskRunner.createInstance(
|
||||
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
|
||||
{baseOptions: {modelAssetPath}});
|
||||
}
|
||||
|
||||
/** @hideconstructor */
|
||||
constructor(
|
||||
wasmModule: WasmModule,
|
||||
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
|
||||
super(
|
||||
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
|
||||
NORM_RECT_STREAM, /* roiAllowed= */ false);
|
||||
this.options = new ImageSegmenterGraphOptionsProto();
|
||||
this.segmenterOptions = new SegmenterOptionsProto();
|
||||
this.options.setSegmenterOptions(this.segmenterOptions);
|
||||
this.options.setBaseOptions(new BaseOptionsProto());
|
||||
}
|
||||
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
return this.options.getBaseOptions()!;
|
||||
}
|
||||
|
||||
protected override set baseOptions(proto: BaseOptionsProto) {
|
||||
this.options.setBaseOptions(proto);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets new options for the image segmenter.
|
||||
*
|
||||
* Calling `setOptions()` with a subset of options only affects those
|
||||
* options. You can reset an option back to its default value by
|
||||
* explicitly setting it to `undefined`.
|
||||
*
|
||||
* @param options The options for the image segmenter.
|
||||
*/
|
||||
override setOptions(options: ImageSegmenterOptions): Promise<void> {
|
||||
// Note that we have to support both JSPB and ProtobufJS, hence we
|
||||
// have to explicitly clear the values instead of setting them to
|
||||
// `undefined`.
|
||||
if (options.displayNamesLocale !== undefined) {
|
||||
this.options.setDisplayNamesLocale(options.displayNamesLocale);
|
||||
} else if ('displayNamesLocale' in options) { // Check for undefined
|
||||
this.options.clearDisplayNamesLocale();
|
||||
}
|
||||
|
||||
if (options.outputType === 'CONFIDENCE_MASK') {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
||||
} else {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
|
||||
return super.applyOptions(options);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided single image and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `image`.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(image: ImageSource, callback: SegmentationMaskCallback): void;
|
||||
/**
|
||||
* Performs image segmentation on the provided single image and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `image`.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(
|
||||
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
segment(
|
||||
image: ImageSource,
|
||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||
SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
const imageProcessingOptions =
|
||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource, timestamp: number,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||
timestamp: number, callback: SegmentationMaskCallback): void;
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource,
|
||||
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
|
||||
timestampOrCallback: number|SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
const imageProcessingOptions =
|
||||
typeof timestampOrImageProcessingOptions !== 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
{};
|
||||
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
timestampOrCallback as number;
|
||||
|
||||
this.userCallback = typeof timestampOrCallback === 'function' ?
|
||||
timestampOrCallback :
|
||||
callback!;
|
||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||
this.userCallback = () => {};
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
protected override refreshGraph(): void {
|
||||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
ImageSegmenterGraphOptionsProto.ext, this.options);
|
||||
|
||||
const segmenterNode = new CalculatorGraphConfig.Node();
|
||||
segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH);
|
||||
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
|
||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
|
||||
segmenterNode.addOutputStream(
|
||||
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
|
||||
segmenterNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(segmenterNode);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
|
||||
if (masks.length === 0) {
|
||||
this.userCallback([], 0, 0);
|
||||
} else {
|
||||
this.userCallback(
|
||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||
}
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
}
|
||||
}
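As a rough end-to-end sketch of the video overloads above (the variable names, the `'VIDEO'` running mode inherited from the shared `VisionTaskOptions`, and the copy-before-store pattern are illustrative assumptions, not part of this file):

```typescript
// Sketch only: assumes `WasmFileset` and `ImageSegmenter` are imported as in this file,
// and that a <video id="video"> element exists on the page.
async function runVideoSegmentation(vision: WasmFileset): Promise<void> {
  const segmenter = await ImageSegmenter.createFromModelPath(vision, 'model.tflite');
  await segmenter.setOptions({runningMode: 'VIDEO'});

  const video = document.getElementById('video') as HTMLVideoElement;
  const copies: Uint8Array[] = [];

  const onFrame = () => {
    segmenter.segmentForVideo(video, performance.now(), (masks) => {
      // Mask buffers are only valid during the callback; copy before storing.
      copies.push(Uint8Array.from(masks[0] as Uint8Array));
    });
    requestAnimationFrame(onFrame);
  };
  onFrame();
}
```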
|
||||
|
||||
|
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts (new file, 41 lines)
|
@@ -0,0 +1,41 @@
/**
 * Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';

/** Options to configure the MediaPipe Image Segmenter Task */
export interface ImageSegmenterOptions extends VisionTaskOptions {
  /**
   * The locale to use for display names specified through the TFLite Model
   * Metadata, if any. Defaults to English.
   */
  displayNamesLocale?: string|undefined;

  /**
   * The output type of segmentation results.
   *
   * The two supported modes are:
   * - Category Mask: Gives a single output mask where each pixel represents
   *     the class which the pixel in the original image was predicted to
   *     belong to.
   * - Confidence Mask: Gives a list of output masks (one for each class). For
   *     each mask, the pixel represents the prediction confidence, usually in
   *     the [0.0, 1.0] range.
   *
   * Defaults to `CATEGORY_MASK`.
   */
  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
}
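A minimal illustration of these options (the fileset variable and model path are placeholders), requesting per-class confidence masks instead of the default category mask:

```typescript
// Inside an async function; `vision` is a WasmFileset from FilesetResolver.forVisionTasks().
const segmenter = await ImageSegmenter.createFromOptions(vision, {
  baseOptions: {modelAssetPath: 'model.tflite'},
  outputType: 'CONFIDENCE_MASK',  // one Float32Array per class, values in [0, 1]
});
```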
|
|
@ -0,0 +1,215 @@
|
|||
/**
|
||||
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import 'jasmine';
|
||||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||
|
||||
import {ImageSegmenter} from './image_segmenter';
|
||||
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||
|
||||
class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||
calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
attachListenerSpies: jasmine.Spy[] = [];
|
||||
graph: CalculatorGraphConfig|undefined;
|
||||
|
||||
fakeWasmModule: SpyWasmModule;
|
||||
imageVectorListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
|
||||
constructor() {
|
||||
super(createSpyWasmModule(), /* glCanvas= */ null);
|
||||
this.fakeWasmModule =
|
||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||
|
||||
this.attachListenerSpies[0] =
|
||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('segmented_masks');
|
||||
this.imageVectorListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
});
|
||||
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
|
||||
}
|
||||
}
|
||||
|
||||
describe('ImageSegmenter', () => {
|
||||
let imageSegmenter: ImageSegmenterFake;
|
||||
|
||||
beforeEach(async () => {
|
||||
addJasmineCustomFloatEqualityTester();
|
||||
imageSegmenter = new ImageSegmenterFake();
|
||||
await imageSegmenter.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
});
|
||||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(imageSegmenter);
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
|
||||
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
});
|
||||
|
||||
it('can use custom models', async () => {
|
||||
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
|
||||
const newModelBase64 = Buffer.from(newModel).toString('base64');
|
||||
await imageSegmenter.setOptions({
|
||||
baseOptions: {
|
||||
modelAssetBuffer: newModel,
|
||||
}
|
||||
});
|
||||
|
||||
verifyGraph(
|
||||
imageSegmenter,
|
||||
/* expectedCalculatorOptions= */ undefined,
|
||||
/* expectedBaseOptions= */
|
||||
[
|
||||
'modelAsset', {
|
||||
fileContent: newModelBase64,
|
||||
fileName: undefined,
|
||||
fileDescriptorMeta: undefined,
|
||||
filePointerMeta: undefined
|
||||
}
|
||||
]);
|
||||
});
|
||||
|
||||
it('merges options', async () => {
|
||||
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||
});
|
||||
|
||||
describe('setOptions()', () => {
|
||||
interface TestCase {
|
||||
optionName: keyof ImageSegmenterOptions;
|
||||
fieldPath: string[];
|
||||
userValue: unknown;
|
||||
graphValue: unknown;
|
||||
defaultValue: unknown;
|
||||
}
|
||||
|
||||
const testCases: TestCase[] = [
|
||||
{
|
||||
optionName: 'displayNamesLocale',
|
||||
fieldPath: ['displayNamesLocale'],
|
||||
userValue: 'en',
|
||||
graphValue: 'en',
|
||||
defaultValue: 'en'
|
||||
},
|
||||
{
|
||||
optionName: 'outputType',
|
||||
fieldPath: ['segmenterOptions', 'outputType'],
|
||||
userValue: 'CONFIDENCE_MASK',
|
||||
graphValue: 2,
|
||||
defaultValue: 1
|
||||
},
|
||||
];
|
||||
|
||||
for (const testCase of testCases) {
|
||||
it(`can set ${testCase.optionName}`, async () => {
|
||||
await imageSegmenter.setOptions(
|
||||
{[testCase.optionName]: testCase.userValue});
|
||||
verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
|
||||
});
|
||||
|
||||
it(`can clear ${testCase.optionName}`, async () => {
|
||||
await imageSegmenter.setOptions(
|
||||
{[testCase.optionName]: testCase.userValue});
|
||||
verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
|
||||
await imageSegmenter.setOptions({[testCase.optionName]: undefined});
|
||||
verifyGraph(
|
||||
imageSegmenter, [testCase.fieldPath, testCase.defaultValue]);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
it('doesn\'t support region of interest', () => {
|
||||
expect(() => {
|
||||
imageSegmenter.segment(
|
||||
{} as HTMLImageElement,
|
||||
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
|
||||
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||
});
|
||||
|
||||
it('supports category masks', (done) => {
|
||||
const mask = new Uint8Array([1, 2, 3, 4]);
|
||||
|
||||
// Pass the test data to our listener
|
||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
imageSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask, width: 2, height: 2},
|
||||
],
|
||||
/* timestamp= */ 1337);
|
||||
});
|
||||
|
||||
// Invoke the image segmenter
|
||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(1);
|
||||
expect(masks[0]).toEqual(mask);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
done();
|
||||
});
|
||||
});
|
||||
|
||||
it('supports confidence masks', async () => {
|
||||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||
|
||||
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
|
||||
// Pass the test data to our listener
|
||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
imageSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask1, width: 2, height: 2},
|
||||
{data: mask2, width: 2, height: 2},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(2);
|
||||
expect(masks[0]).toEqual(mask1);
|
||||
expect(masks[1]).toEqual(mask2);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
|
@ -19,6 +19,7 @@ import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vis
|
|||
import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
||||
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||
import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
|
||||
|
||||
// Declare the variables locally so that Rollup in OSS includes them explicitly
|
||||
|
@ -28,6 +29,7 @@ const GestureRecognizer = GestureRecognizerImpl;
|
|||
const HandLandmarker = HandLandmarkerImpl;
|
||||
const ImageClassifier = ImageClassifierImpl;
|
||||
const ImageEmbedder = ImageEmbedderImpl;
|
||||
const ImageSegmenter = ImageSegementerImpl;
|
||||
const ObjectDetector = ObjectDetectorImpl;
|
||||
|
||||
export {
|
||||
|
@ -36,5 +38,6 @@ export {
|
|||
HandLandmarker,
|
||||
ImageClassifier,
|
||||
ImageEmbedder,
|
||||
ImageSegmenter,
|
||||
ObjectDetector
|
||||
};
|
||||
|
|
|
@ -19,4 +19,5 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
|
|||
export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
|
||||
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
|
||||
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
|
||||
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
|
||||
export * from '../../../tasks/web/vision/object_detector/object_detector';
|
||||
|
|
|
@ -1,14 +0,0 @@
|
|||
diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel
|
||||
index 9fceffe..e7f9d01 100644
|
||||
--- a/absl/time/internal/cctz/BUILD.bazel
|
||||
+++ b/absl/time/internal/cctz/BUILD.bazel
|
||||
@@ -69,8 +69,5 @@ cc_library(
|
||||
"include/cctz/zone_info_source.h",
|
||||
],
|
||||
linkopts = select({
|
||||
- ":osx": [
|
||||
- "-framework Foundation",
|
||||
- ],
|
||||
":ios": [
|
||||
"-framework Foundation",
|
||||
],
|
third_party/com_google_absl_windows_patch.diff (new file, 13 lines)
|
@ -0,0 +1,13 @@
|
|||
diff --git a/absl/types/compare.h b/absl/types/compare.h
|
||||
index 19b076e..0201004 100644
|
||||
--- a/absl/types/compare.h
|
||||
+++ b/absl/types/compare.h
|
||||
@@ -84,7 +84,7 @@ enum class ncmp : value_type { unordered = -127 };
|
||||
// based on whether the feature is supported. Note: we can't use
|
||||
// ABSL_INTERNAL_INLINE_CONSTEXPR here because the variables here are of
|
||||
// incomplete types so they need to be defined after the types are complete.
|
||||
-#ifdef __cpp_inline_variables
|
||||
+#if defined(__cpp_inline_variables) && !(defined(_MSC_VER) && _MSC_VER <= 1916)
|
||||
|
||||
// A no-op expansion that can be followed by a semicolon at class level.
|
||||
#define ABSL_COMPARE_INLINE_BASECLASS_DECL(name) static_assert(true, "")
|