Merge branch 'master' into ios-text-embedder

This commit is contained in:
Prianka Liz Kariat 2023-02-01 19:11:01 +05:30
commit cd1cb87ff6
66 changed files with 2755 additions and 256 deletions

View File

@ -22,21 +22,20 @@ bazel_skylib_workspace()
load("@bazel_skylib//lib:versions.bzl", "versions")
versions.check(minimum_bazel_version = "3.7.2")
# ABSL cpp library lts_2021_03_24, patch 2.
# ABSL cpp library lts_2023_01_25.
http_archive(
name = "com_google_absl",
urls = [
"https://github.com/abseil/abseil-cpp/archive/refs/tags/20220623.1.tar.gz",
"https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz",
],
# Remove after https://github.com/abseil/abseil-cpp/issues/326 is solved.
patches = [
"@//third_party:com_google_absl_f863b622fe13612433fdf43f76547d5edda0c93001.diff"
"@//third_party:com_google_absl_windows_patch.diff"
],
patch_args = [
"-p1",
],
strip_prefix = "abseil-cpp-20220623.1",
sha256 = "91ac87d30cc6d79f9ab974c51874a704de9c2647c40f6932597329a282217ba8"
strip_prefix = "abseil-cpp-20230125.0",
sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21"
)
http_archive(

View File

@ -4,12 +4,10 @@ py_binary(
name = "build_py_api_docs",
srcs = ["build_py_api_docs.py"],
deps = [
"//mediapipe",
"//third_party/py/absl:app",
"//third_party/py/absl/flags",
"//third_party/py/tensorflow_docs",
"//third_party/py/mediapipe",
"//third_party/py/tensorflow_docs/api_generator:generate_lib",
"//third_party/py/tensorflow_docs/api_generator:public_api",
],
)

View File

@ -35,7 +35,7 @@ install --user six`.
```bash
$ cd $HOME
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git
# Change directory into MediaPipe root directory
$ cd mediapipe
@ -287,7 +287,7 @@ build issues.
2. Checkout MediaPipe repository.
```bash
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git
# Change directory into MediaPipe root directory
$ cd mediapipe
@ -416,7 +416,7 @@ build issues.
3. Checkout MediaPipe repository.
```bash
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git
$ cd mediapipe
```
@ -590,7 +590,7 @@ next section.
7. Checkout MediaPipe repository.
```
C:\Users\Username\mediapipe_repo> git clone https://github.com/google/mediapipe.git
C:\Users\Username\mediapipe_repo> git clone --depth 1 https://github.com/google/mediapipe.git
# Change directory into MediaPipe root directory
C:\Users\Username\mediapipe_repo> cd mediapipe
@ -680,7 +680,7 @@ cameras. Alternatively, you use a video file as input.
6. Checkout MediaPipe repository.
```bash
username@DESKTOP-TMVLBJ1:~$ git clone https://github.com/google/mediapipe.git
username@DESKTOP-TMVLBJ1:~$ git clone --depth 1 https://github.com/google/mediapipe.git
username@DESKTOP-TMVLBJ1:~$ cd mediapipe
```
@ -771,7 +771,7 @@ This will use a Docker image that will isolate mediapipe's installation from the
2. Build a docker image with tag "mediapipe".
```bash
$ git clone https://github.com/google/mediapipe.git
$ git clone --depth 1 https://github.com/google/mediapipe.git
$ cd mediapipe
$ docker build --tag=mediapipe .

View File

@ -1329,6 +1329,7 @@ cc_library(
hdrs = ["merge_to_vector_calculator.h"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",

View File

@ -49,7 +49,7 @@ namespace mediapipe {
// calculator: "EndLoopWithOutputCalculator"
// input_stream: "ITEM:output_of_loop_body" # ItemU @loop_internal_ts
// input_stream: "BATCH_END:ext_ts" # Timestamp @loop_internal_ts
// output_stream: "OUTPUT:aggregated_result" # IterableU @ext_ts
// output_stream: "ITERABLE:aggregated_result" # IterableU @ext_ts
// }
//
// Input streams tagged with "CLONE" are cloned to the corresponding output

View File

@ -1054,6 +1054,7 @@ cc_test(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/util:packet_test_util",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:protos_all_cc",
],

View File

@ -102,7 +102,7 @@ absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) {
}
auto output =
absl::make_unique<std::vector<float>>(input_tensor.NumElements());
const auto& tensor_values = input_tensor.flat<float>();
const auto& tensor_values = input_tensor.unaligned_flat<float>();
for (int i = 0; i < input_tensor.NumElements(); ++i) {
output->at(i) = tensor_values(i);
}

View File

@ -16,6 +16,7 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/util/packet_test_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
@ -129,5 +130,28 @@ TEST_F(TensorToVectorFloatCalculatorTest, FlattenShouldTakeAllDimensions) {
}
}
TEST_F(TensorToVectorFloatCalculatorTest, AcceptsUnalignedTensors) {
SetUpRunner(/*tensor_is_2d=*/false, /*flatten_nd=*/false);
const tf::TensorShape tensor_shape(std::vector<tf::int64>{2, 5});
tf::Tensor tensor(tf::DT_FLOAT, tensor_shape);
auto slice = tensor.Slice(1, 1).flat<float>();
for (int i = 0; i < 5; ++i) {
slice(i) = i;
}
auto input_tensor = tensor.SubSlice(1);
// Ensure that the input tensor is unaligned.
ASSERT_FALSE(input_tensor.IsAligned());
runner_->MutableInputs()->Index(0).packets.push_back(
MakePacket<tf::Tensor>(input_tensor).At(Timestamp(5)));
ASSERT_TRUE(runner_->Run().ok());
EXPECT_THAT(runner_->Outputs().Index(0).packets,
ElementsAre(PacketContainsTimestampAndPayload<std::vector<float>>(
Timestamp(5), std::vector<float>({0, 1, 2, 3, 4}))));
}
} // namespace
} // namespace mediapipe

View File

@ -1051,6 +1051,7 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -26,12 +26,21 @@ using ::mediapipe::api2::test::FooBar1;
TEST(BuilderTest, BuildGraph) {
Graph graph;
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
auto& foo = graph.AddNode("Foo");
base >> foo.In("BASE");
side >> foo.SideIn("SIDE");
Stream<AnyType> foo_out = foo.Out("OUT");
auto& bar = graph.AddNode("Bar");
graph.In("IN").SetName("base") >> foo.In("BASE");
graph.SideIn("SIDE").SetName("side") >> foo.SideIn("SIDE");
foo.Out("OUT") >> bar.In("IN");
bar.Out("OUT").SetName("out") >> graph.Out("OUT");
foo_out >> bar.In("IN");
Stream<AnyType> bar_out = bar.Out("OUT");
// Graph outputs.
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -87,6 +96,7 @@ TEST(BuilderTest, CopyableStream) {
TEST(BuilderTest, BuildGraphWithFunctions) {
Graph graph;
// Graph inputs.
Stream<int> base = graph.In("IN").SetName("base").Cast<int>();
SidePacket<float> side = graph.SideIn("SIDE").SetName("side").Cast<float>();
@ -105,6 +115,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
};
Stream<double> bar_out = bar_fn(foo_out, graph);
// Graph outputs.
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
@ -130,12 +141,21 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
template <class FooT>
void BuildGraphTypedTest() {
Graph graph;
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
SidePacket<AnyType> side = graph.SideIn("SIDE").SetName("side");
auto& foo = graph.AddNode<FooT>();
base >> foo.In(MPP_TAG("BASE"));
side >> foo.SideIn(MPP_TAG("BIAS"));
Stream<float> foo_out = foo.Out(MPP_TAG("OUT"));
auto& bar = graph.AddNode<Bar>();
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
graph.SideIn("SIDE").SetName("side") >> foo.SideIn(MPP_TAG("BIAS"));
foo.Out(MPP_TAG("OUT")) >> bar.In(MPP_TAG("IN"));
bar.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
foo_out >> bar.In(MPP_TAG("IN"));
Stream<AnyType> bar_out = bar.Out(MPP_TAG("OUT"));
// Graph outputs.
bar_out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
@ -165,12 +185,20 @@ TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<test::Foo2>(); }
TEST(BuilderTest, FanOut) {
Graph graph;
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
auto& foo = graph.AddNode("Foo");
base >> foo.In("BASE");
Stream<AnyType> foo_out = foo.Out("OUT");
auto& adder = graph.AddNode("FloatAdder");
graph.In("IN").SetName("base") >> foo.In("BASE");
foo.Out("OUT") >> adder.In("IN")[0];
foo.Out("OUT") >> adder.In("IN")[1];
adder.Out("OUT").SetName("out") >> graph.Out("OUT");
foo_out >> adder.In("IN")[0];
foo_out >> adder.In("IN")[1];
Stream<AnyType> out = adder.Out("OUT");
// Graph outputs.
out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -193,12 +221,20 @@ TEST(BuilderTest, FanOut) {
TEST(BuilderTest, TypedMultiple) {
Graph graph;
auto& foo = graph.AddNode<test::Foo>();
auto& adder = graph.AddNode<test::FloatAdder>();
graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
adder.Out(MPP_TAG("OUT")).SetName("out") >> graph.Out("OUT");
// Graph inputs.
Stream<AnyType> base = graph.In("IN").SetName("base");
auto& foo = graph.AddNode<Foo>();
base >> foo.In(MPP_TAG("BASE"));
Stream<float> foo_out = foo.Out(MPP_TAG("OUT"));
auto& adder = graph.AddNode<FloatAdder>();
foo_out >> adder.In(MPP_TAG("IN"))[0];
foo_out >> adder.In(MPP_TAG("IN"))[1];
Stream<float> out = adder.Out(MPP_TAG("OUT"));
// Graph outputs.
out.SetName("out") >> graph.Out("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -221,13 +257,20 @@ TEST(BuilderTest, TypedMultiple) {
TEST(BuilderTest, TypedByPorts) {
Graph graph;
auto& foo = graph.AddNode<test::Foo>();
auto& adder = graph.AddNode<FloatAdder>();
// Graph inputs.
Stream<int> base = graph.In(FooBar1::kIn).SetName("base");
graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase];
foo[Foo::kOut] >> adder[FloatAdder::kIn][0];
foo[Foo::kOut] >> adder[FloatAdder::kIn][1];
adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut);
auto& foo = graph.AddNode<test::Foo>();
base >> foo[Foo::kBase];
Stream<float> foo_out = foo[Foo::kOut];
auto& adder = graph.AddNode<FloatAdder>();
foo_out >> adder[FloatAdder::kIn][0];
foo_out >> adder[FloatAdder::kIn][1];
Stream<float> out = adder[FloatAdder::kOut];
// Graph outputs.
out.SetName("out") >> graph.Out(FooBar1::kOut);
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -250,9 +293,15 @@ TEST(BuilderTest, TypedByPorts) {
TEST(BuilderTest, PacketGenerator) {
Graph graph;
// Graph inputs.
SidePacket<AnyType> side_in = graph.SideIn("IN");
auto& generator = graph.AddPacketGenerator("FloatGenerator");
graph.SideIn("IN") >> generator.SideIn("IN");
generator.SideOut("OUT") >> graph.SideOut("OUT");
side_in >> generator.SideIn("IN");
SidePacket<AnyType> side_out = generator.SideOut("OUT");
// Graph outputs.
side_out >> graph.SideOut("OUT");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -269,12 +318,21 @@ TEST(BuilderTest, PacketGenerator) {
TEST(BuilderTest, EmptyTag) {
Graph graph;
// Graph inputs.
Stream<AnyType> a = graph.In("A").SetName("a");
Stream<AnyType> c = graph.In("C").SetName("c");
Stream<AnyType> b = graph.In("B").SetName("b");
auto& foo = graph.AddNode("Foo");
graph.In("A").SetName("a") >> foo.In("")[0];
graph.In("C").SetName("c") >> foo.In("")[2];
graph.In("B").SetName("b") >> foo.In("")[1];
foo.Out("")[0].SetName("x") >> graph.Out("ONE");
foo.Out("")[1].SetName("y") >> graph.Out("TWO");
a >> foo.In("")[0];
c >> foo.In("")[2];
b >> foo.In("")[1];
Stream<AnyType> x = foo.Out("")[0];
Stream<AnyType> y = foo.Out("")[1];
// Graph outputs.
x.SetName("x") >> graph.Out("ONE");
y.SetName("y") >> graph.Out("TWO");
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -301,10 +359,17 @@ TEST(BuilderTest, StringLikeTags) {
constexpr absl::string_view kC = "C";
Graph graph;
// Graph inputs.
Stream<AnyType> a = graph.In(kA).SetName("a");
Stream<AnyType> b = graph.In(kB).SetName("b");
auto& foo = graph.AddNode("Foo");
graph.In(kA).SetName("a") >> foo.In(kA);
graph.In(kB).SetName("b") >> foo.In(kB);
foo.Out(kC).SetName("c") >> graph.Out(kC);
a >> foo.In(kA);
b >> foo.In(kB);
Stream<AnyType> c = foo.Out(kC);
// Graph outputs.
c.SetName("c") >> graph.Out(kC);
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -323,12 +388,21 @@ TEST(BuilderTest, StringLikeTags) {
TEST(BuilderTest, GraphIndexes) {
Graph graph;
// Graph inputs.
Stream<AnyType> a = graph.In(0).SetName("a");
Stream<AnyType> c = graph.In(1).SetName("c");
Stream<AnyType> b = graph.In(2).SetName("b");
auto& foo = graph.AddNode("Foo");
graph.In(0).SetName("a") >> foo.In("")[0];
graph.In(1).SetName("c") >> foo.In("")[2];
graph.In(2).SetName("b") >> foo.In("")[1];
foo.Out("")[0].SetName("x") >> graph.Out(1);
foo.Out("")[1].SetName("y") >> graph.Out(0);
a >> foo.In("")[0];
c >> foo.In("")[2];
b >> foo.In("")[1];
Stream<AnyType> x = foo.Out("")[0];
Stream<AnyType> y = foo.Out("")[1];
// Graph outputs.
x.SetName("x") >> graph.Out(1);
y.SetName("y") >> graph.Out(0);
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@ -381,21 +455,20 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
int_input >> node[AnyAndSameTypeCalculator::kIntInput];
Stream<AnyType> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput];
any_type_output.SetName("any_type_output");
Stream<AnyType> same_type_output =
node[AnyAndSameTypeCalculator::kSameTypeOutput];
same_type_output.SetName("same_type_output");
Stream<AnyType> recursive_same_type_output =
node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
recursive_same_type_output.SetName("recursive_same_type_output");
Stream<int> same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput];
same_int_output.SetName("same_int_output");
Stream<int> recursive_same_int_type_output =
node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
any_type_output.SetName("any_type_output");
same_type_output.SetName("same_type_output");
recursive_same_type_output.SetName("recursive_same_type_output");
same_int_output.SetName("same_int_output");
recursive_same_int_type_output.SetName("recursive_same_int_type_output");
CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie<
@ -424,11 +497,10 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
auto& node = graph.AddNode("AnyAndSameTypeCalculator");
any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
Stream<double> any_type_output =
node[AnyAndSameTypeCalculator::kAnyTypeOutput]
.SetName("any_type_output")
.Cast<double>();
node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
any_type_output.SetName("any_type_output") >>
graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
CalculatorGraphConfig expected =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(

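The refactoring above follows one idiom throughout: capture graph inputs as `Stream`/`SidePacket` variables first, wire nodes through those variables, and connect graph outputs last. A minimal sketch of that idiom (not part of the commit; it reuses the test-only "Foo" calculator name from the tests above):

```cpp
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator.pb.h"

using ::mediapipe::api2::AnyType;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Stream;

mediapipe::CalculatorGraphConfig BuildExampleGraph() {
  Graph graph;
  // Graph inputs are declared up front as named Stream handles...
  Stream<AnyType> in = graph.In("IN").SetName("in");
  auto& foo = graph.AddNode("Foo");
  in >> foo.In("BASE");
  Stream<AnyType> foo_out = foo.Out("OUT");
  // ...and graph outputs are connected at the end.
  foo_out.SetName("out") >> graph.Out("OUT");
  return graph.GetConfig();
}
```

Keeping inputs, node wiring, and outputs in separate blocks is what the added `// Graph inputs.` / `// Graph outputs.` comments in the tests mark.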
View File

@ -486,3 +486,15 @@ cc_test(
"//mediapipe/gpu:disable_gpu": [],
}),
)
cc_library(
name = "frame_buffer",
srcs = ["frame_buffer.cc"],
hdrs = ["frame_buffer.h"],
deps = [
"//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

View File

@ -0,0 +1,176 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/framework/formats/frame_buffer.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
namespace mediapipe {
namespace {
// Returns whether the input `format` is a supported YUV format.
bool IsSupportedYuvFormat(FrameBuffer::Format format) {
return format == FrameBuffer::Format::kNV21 ||
format == FrameBuffer::Format::kNV12 ||
format == FrameBuffer::Format::kYV12 ||
format == FrameBuffer::Format::kYV21;
}
// Returns supported 1-plane FrameBuffer in YuvData structure.
absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromOnePlaneFrameBuffer(
const FrameBuffer& source) {
if (!IsSupportedYuvFormat(source.format())) {
return absl::InvalidArgumentError(
"The source FrameBuffer format is not part of YUV420 family.");
}
FrameBuffer::YuvData result;
const int y_buffer_size =
source.plane(0).stride.row_stride_bytes * source.dimension().height;
const int uv_buffer_size =
((source.plane(0).stride.row_stride_bytes + 1) / 2) *
((source.dimension().height + 1) / 2);
result.y_buffer = source.plane(0).buffer;
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
result.uv_row_stride = result.y_row_stride;
if (source.format() == FrameBuffer::Format::kNV21) {
result.v_buffer = result.y_buffer + y_buffer_size;
result.u_buffer = result.v_buffer + 1;
result.uv_pixel_stride = 2;
// If y_row_stride equals to the frame width and is an odd value,
// uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride.
if (result.y_row_stride == source.dimension().width &&
result.y_row_stride % 2 == 1) {
result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
}
} else if (source.format() == FrameBuffer::Format::kNV12) {
result.u_buffer = result.y_buffer + y_buffer_size;
result.v_buffer = result.u_buffer + 1;
result.uv_pixel_stride = 2;
// If y_row_stride equals to the frame width and is an odd value,
// uv_row_stride = y_row_stride + 1, otherwise uv_row_stride = y_row_stride.
if (result.y_row_stride == source.dimension().width &&
result.y_row_stride % 2 == 1) {
result.uv_row_stride = (result.y_row_stride + 1) / 2 * 2;
}
} else if (source.format() == FrameBuffer::Format::kYV21) {
result.u_buffer = result.y_buffer + y_buffer_size;
result.v_buffer = result.u_buffer + uv_buffer_size;
result.uv_pixel_stride = 1;
result.uv_row_stride = (result.y_row_stride + 1) / 2;
} else if (source.format() == FrameBuffer::Format::kYV12) {
result.v_buffer = result.y_buffer + y_buffer_size;
result.u_buffer = result.v_buffer + uv_buffer_size;
result.uv_pixel_stride = 1;
result.uv_row_stride = (result.y_row_stride + 1) / 2;
}
return result;
}
// Returns supported 2-plane FrameBuffer in YuvData structure.
absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromTwoPlaneFrameBuffer(
const FrameBuffer& source) {
if (source.format() != FrameBuffer::Format::kNV12 &&
source.format() != FrameBuffer::Format::kNV21) {
return absl::InvalidArgumentError("Unsupported YUV planar format.");
}
FrameBuffer::YuvData result;
// Y plane
result.y_buffer = source.plane(0).buffer;
// All plane strides
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
result.uv_pixel_stride = 2;
if (source.format() == FrameBuffer::Format::kNV12) {
// Y and UV interleaved format
result.u_buffer = source.plane(1).buffer;
result.v_buffer = result.u_buffer + 1;
} else {
// Y and VU interleaved format
result.v_buffer = source.plane(1).buffer;
result.u_buffer = result.v_buffer + 1;
}
return result;
}
// Returns supported 3-plane FrameBuffer in YuvData structure. Note that NV21
// and NV12 are included in the supported Yuv formats. Technically, NV21 and
// NV12 should not be described by the 3-plane format. Historically, NV21 is
// used loosely such that it can also be used to describe YV21 format. For
// backwards compatibility, FrameBuffer supports NV21/NV12 with 3-plane format
// but such usage is discouraged.
absl::StatusOr<FrameBuffer::YuvData> GetYuvDataFromThreePlaneFrameBuffer(
const FrameBuffer& source) {
if (!IsSupportedYuvFormat(source.format())) {
return absl::InvalidArgumentError(
"The source FrameBuffer format is not part of YUV420 family.");
}
if (source.plane(1).stride.row_stride_bytes !=
source.plane(2).stride.row_stride_bytes ||
source.plane(1).stride.pixel_stride_bytes !=
source.plane(2).stride.pixel_stride_bytes) {
return absl::InternalError("Unsupported YUV planar format.");
}
FrameBuffer::YuvData result;
if (source.format() == FrameBuffer::Format::kNV21 ||
source.format() == FrameBuffer::Format::kYV12) {
// Y follow by VU order. The VU chroma planes can be interleaved or
// planar.
result.y_buffer = source.plane(0).buffer;
result.v_buffer = source.plane(1).buffer;
result.u_buffer = source.plane(2).buffer;
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
} else {
// Y follow by UV order. The UV chroma planes can be interleaved or
// planar.
result.y_buffer = source.plane(0).buffer;
result.u_buffer = source.plane(1).buffer;
result.v_buffer = source.plane(2).buffer;
result.y_row_stride = source.plane(0).stride.row_stride_bytes;
result.uv_row_stride = source.plane(1).stride.row_stride_bytes;
result.uv_pixel_stride = source.plane(1).stride.pixel_stride_bytes;
}
return result;
}
} // namespace
absl::StatusOr<FrameBuffer::YuvData> FrameBuffer::GetYuvDataFromFrameBuffer(
const FrameBuffer& source) {
if (!IsSupportedYuvFormat(source.format())) {
return absl::InvalidArgumentError(
"The source FrameBuffer format is not part of YUV420 family.");
}
if (source.plane_count() == 1) {
return GetYuvDataFromOnePlaneFrameBuffer(source);
} else if (source.plane_count() == 2) {
return GetYuvDataFromTwoPlaneFrameBuffer(source);
} else if (source.plane_count() == 3) {
return GetYuvDataFromThreePlaneFrameBuffer(source);
}
return absl::InvalidArgumentError(
"The source FrameBuffer must be consisted by 1, 2, or 3 planes");
}
} // namespace mediapipe

View File

@ -0,0 +1,246 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
#define MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe {
// A `FrameBuffer` provides a view into the provided backing buffer (e.g. camera
// frame or still image) with buffer format information. FrameBuffer doesn't
// take ownership of the provided backing buffer. The caller is responsible for
// managing the backing buffer lifecycle for the lifetime of the FrameBuffer.
//
// Examples:
//
// // Create a metadata instance with no backing buffer.
// auto buffer = FrameBuffer::Create(/*planes=*/{}, dimension, kRGBA,
// kTopLeft);
//
// // Create an RGBA instance with backing buffer on single plane.
// FrameBuffer::Plane plane =
// {rgba_buffer, /*stride=*/{dimension.width * 4, 4}};
// auto buffer = FrameBuffer::Create({plane}, dimension, kRGBA, kTopLeft);
//
// // Create an YUV instance with planar backing buffer.
// FrameBuffer::Plane y_plane = {y_buffer, /*stride=*/{dimension.width , 1}};
// FrameBuffer::Plane uv_plane = {u_buffer, /*stride=*/{dimension.width, 2}};
// auto buffer = FrameBuffer::Create({y_plane, uv_plane}, dimension, kNV21,
// kLeftTop);
class FrameBuffer {
public:
// Colorspace formats.
enum class Format {
kRGBA,
kRGB,
kNV12,
kNV21,
kYV12,
kYV21,
kGRAY,
kUNKNOWN
};
// Stride information.
struct Stride {
// The row stride in bytes. This is the distance between the start pixels of
// two consecutive rows in the image.
int row_stride_bytes;
// This is the distance between two consecutive pixel values in a row of
// pixels in bytes. It may be larger than the size of a single pixel to
// account for interleaved image data or padded formats.
int pixel_stride_bytes;
bool operator==(const Stride& other) const {
return row_stride_bytes == other.row_stride_bytes &&
pixel_stride_bytes == other.pixel_stride_bytes;
}
bool operator!=(const Stride& other) const { return !operator==(other); }
};
// YUV data structure.
struct YuvData {
const uint8* y_buffer;
const uint8* u_buffer;
const uint8* v_buffer;
// Y buffer row stride in bytes.
int y_row_stride;
// U/V buffer row stride in bytes.
int uv_row_stride;
// U/V pixel stride in bytes. This is the distance between two consecutive
// u/v pixel values in a row.
int uv_pixel_stride;
};
// FrameBuffer content orientation follows EXIF specification. The name of
// each enum value defines the position of the 0th row and the 0th column of
// the image content. See http://jpegclub.org/exif_orientation.html for
// details.
enum class Orientation {
kTopLeft = 1,
kTopRight = 2,
kBottomRight = 3,
kBottomLeft = 4,
kLeftTop = 5,
kRightTop = 6,
kRightBottom = 7,
kLeftBottom = 8
};
// Plane encapsulates buffer and stride information.
struct Plane {
const uint8* buffer;
Stride stride;
};
// Dimension information for the whole frame or a cropped portion of it.
struct Dimension {
// The width dimension in pixel unit.
int width;
// The height dimension in pixel unit.
int height;
bool operator==(const Dimension& other) const {
return width == other.width && height == other.height;
}
bool operator!=(const Dimension& other) const {
return width != other.width || height != other.height;
}
bool operator>=(const Dimension& other) const {
return width >= other.width && height >= other.height;
}
bool operator<=(const Dimension& other) const {
return width <= other.width && height <= other.height;
}
// Swaps width and height.
void Swap() {
using std::swap;
swap(width, height);
}
// Returns area represented by width * height.
int Size() const { return width * height; }
};
// Factory method for creating a FrameBuffer object from row-major backing
// buffers.
static std::unique_ptr<FrameBuffer> Create(const std::vector<Plane>& planes,
Dimension dimension, Format format,
Orientation orientation) {
return absl::make_unique<FrameBuffer>(planes, dimension, format,
orientation);
}
// Factory method for creating a FrameBuffer object from row-major movable
// backing buffers.
static std::unique_ptr<FrameBuffer> Create(std::vector<Plane>&& planes,
Dimension dimension, Format format,
Orientation orientation) {
return absl::make_unique<FrameBuffer>(std::move(planes), dimension, format,
orientation);
}
// Returns YuvData which contains the Y, U, and V buffer and their
// stride info from the input `source` FrameBuffer which is in the YUV family
// formats (e.g. NV12, NV21, YV12, and YV21).
static absl::StatusOr<YuvData> GetYuvDataFromFrameBuffer(
const FrameBuffer& source);
// Builds a FrameBuffer object from a row-major backing buffer.
//
// The FrameBuffer does not take ownership of the backing buffer. The backing
// buffer is read-only and the caller is responsible for maintaining the
// backing buffer lifecycle for the lifetime of FrameBuffer.
FrameBuffer(const std::vector<Plane>& planes, Dimension dimension,
Format format, Orientation orientation)
: planes_(planes),
dimension_(dimension),
format_(format),
orientation_(orientation) {}
// Builds a FrameBuffer object from a movable row-major backing buffer.
//
// The FrameBuffer does not take ownership of the backing buffer. The backing
// buffer is read-only and the caller is responsible for maintaining the
// backing buffer lifecycle for the lifetime of FrameBuffer.
FrameBuffer(std::vector<Plane>&& planes, Dimension dimension, Format format,
Orientation orientation)
: planes_(std::move(planes)),
dimension_(dimension),
format_(format),
orientation_(orientation) {}
// Copy constructor.
//
// FrameBuffer does not take ownership of the backing buffer. The copy
// constructor behaves the same way to only copy the buffer pointer and not
// take ownership of the backing buffer.
FrameBuffer(const FrameBuffer& frame_buffer) {
planes_.clear();
for (int i = 0; i < frame_buffer.plane_count(); i++) {
planes_.push_back(
FrameBuffer::Plane{.buffer = frame_buffer.plane(i).buffer,
.stride = frame_buffer.plane(i).stride});
}
dimension_ = frame_buffer.dimension();
format_ = frame_buffer.format();
orientation_ = frame_buffer.orientation();
}
// Returns number of planes.
int plane_count() const { return planes_.size(); }
// Returns plane indexed by the input `index`.
Plane plane(int index) const {
if (index > -1 && static_cast<size_t>(index) < planes_.size()) {
return planes_[index];
}
return {};
}
// Returns FrameBuffer dimension.
Dimension dimension() const { return dimension_; }
// Returns FrameBuffer format.
Format format() const { return format_; }
// Returns FrameBuffer orientation.
Orientation orientation() const { return orientation_; }
private:
std::vector<Plane> planes_;
Dimension dimension_;
Format format_;
Orientation orientation_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_FRAMEWORK_FORMATS_FRAME_BUFFER_H_
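To make the new API concrete, here is a minimal usage sketch (not part of the commit) that wraps an existing single-plane NV21 buffer and reads the Y/U/V pointers back through `GetYuvDataFromFrameBuffer`; the `nv21_data`, `width`, and `height` parameters are illustrative assumptions:

```cpp
#include <cstdint>

#include "mediapipe/framework/formats/frame_buffer.h"

// Assumes `nv21_data` holds width*height Y bytes followed by the interleaved
// VU plane, with no row padding (row stride == width). FrameBuffer only views
// this memory; the caller keeps ownership for the FrameBuffer's lifetime.
void InspectNv21(const uint8_t* nv21_data, int width, int height) {
  mediapipe::FrameBuffer::Plane plane = {
      nv21_data,
      /*stride=*/{/*row_stride_bytes=*/width, /*pixel_stride_bytes=*/1}};
  auto buffer = mediapipe::FrameBuffer::Create(
      {plane}, /*dimension=*/{width, height},
      mediapipe::FrameBuffer::Format::kNV21,
      mediapipe::FrameBuffer::Orientation::kTopLeft);

  auto yuv = mediapipe::FrameBuffer::GetYuvDataFromFrameBuffer(*buffer);
  if (yuv.ok()) {
    // yuv->y_buffer, yuv->v_buffer, and yuv->u_buffer now point into
    // nv21_data, with uv_pixel_stride == 2 for the interleaved VU plane.
  }
}
```

Because FrameBuffer stores only plane pointers and stride metadata, the caller must keep `nv21_data` valid for as long as the returned YuvData pointers are used.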

View File

@ -92,7 +92,7 @@ bool GraphRegistry::IsRegistered(const std::string& ns,
}
absl::StatusOr<CalculatorGraphConfig> GraphRegistry::CreateByName(
const std::string& ns, const std::string& type_name,
absl::string_view ns, absl::string_view type_name,
SubgraphContext* context) const {
absl::StatusOr<std::unique_ptr<Subgraph>> maker =
local_factories_.IsRegistered(ns, type_name)

View File

@ -20,6 +20,7 @@
#include "absl/base/macros.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/deps/registration.h"
@ -187,7 +188,7 @@ class GraphRegistry {
// Returns the specified graph config.
absl::StatusOr<CalculatorGraphConfig> CreateByName(
const std::string& ns, const std::string& type_name,
absl::string_view ns, absl::string_view type_name,
SubgraphContext* context = nullptr) const;
static GraphRegistry global_graph_registry;
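A small illustrative sketch (not from the commit) of calling the updated `CreateByName` overload; `"ExampleSubgraph"` is a hypothetical registered subgraph type name and the empty registry namespace is an assumption:

```cpp
#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/subgraph.h"

// With absl::string_view parameters, string literals bind directly and no
// temporary std::string objects are created. "ExampleSubgraph" is a
// placeholder for any subgraph type registered in the default namespace.
absl::StatusOr<mediapipe::CalculatorGraphConfig> MakeExampleConfig() {
  return mediapipe::GraphRegistry::global_graph_registry.CreateByName(
      "", "ExampleSubgraph");
}
```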

View File

@ -441,6 +441,21 @@ cc_library(
],
)
cc_library(
name = "gpu_buffer_storage_yuv_image",
srcs = ["gpu_buffer_storage_yuv_image.cc"],
hdrs = ["gpu_buffer_storage_yuv_image.h"],
visibility = ["//visibility:public"],
deps = [
":gpu_buffer_format",
":gpu_buffer_storage",
"//mediapipe/framework/formats:yuv_image",
"//third_party/libyuv",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
],
)
cc_library(
name = "gpu_buffer_storage_ahwb",
srcs = ["gpu_buffer_storage_ahwb.cc"],
@ -1187,3 +1202,17 @@ mediapipe_cc_test(
"//mediapipe/framework/port:gtest_main",
],
)
mediapipe_cc_test(
name = "gpu_buffer_storage_yuv_image_test",
size = "small",
srcs = ["gpu_buffer_storage_yuv_image_test.cc"],
exclude_platforms = [
"ios",
],
deps = [
":gpu_buffer_format",
":gpu_buffer_storage_yuv_image",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -212,6 +212,10 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) {
case GpuBufferFormat::kTwoComponentHalf16:
case GpuBufferFormat::kRGBAHalf64:
case GpuBufferFormat::kRGBAFloat128:
case GpuBufferFormat::kNV12:
case GpuBufferFormat::kNV21:
case GpuBufferFormat::kI420:
case GpuBufferFormat::kYV12:
case GpuBufferFormat::kUnknown:
return ImageFormat::UNKNOWN;
}

View File

@ -52,6 +52,14 @@ enum class GpuBufferFormat : uint32_t {
kRGB24 = 0x00000018, // Note: prefer BGRA32 whenever possible.
kRGBAHalf64 = MEDIAPIPE_FOURCC('R', 'G', 'h', 'A'),
kRGBAFloat128 = MEDIAPIPE_FOURCC('R', 'G', 'f', 'A'),
// 8-bit Y plane + interleaved 8-bit U/V plane with 2x2 subsampling.
kNV12 = MEDIAPIPE_FOURCC('N', 'V', '1', '2'),
// 8-bit Y plane + interleaved 8-bit V/U plane with 2x2 subsampling.
kNV21 = MEDIAPIPE_FOURCC('N', 'V', '2', '1'),
// 8-bit Y plane + non-interleaved 8-bit U/V planes with 2x2 subsampling.
kI420 = MEDIAPIPE_FOURCC('I', '4', '2', '0'),
// 8-bit Y plane + non-interleaved 8-bit V/U planes with 2x2 subsampling.
kYV12 = MEDIAPIPE_FOURCC('Y', 'V', '1', '2'),
};
#if !MEDIAPIPE_DISABLE_GPU
@ -111,6 +119,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) {
return kCVPixelFormatType_64RGBAHalf;
case GpuBufferFormat::kRGBAFloat128:
return kCVPixelFormatType_128RGBAFloat;
case GpuBufferFormat::kNV12:
case GpuBufferFormat::kNV21:
case GpuBufferFormat::kI420:
case GpuBufferFormat::kYV12:
case GpuBufferFormat::kUnknown:
return -1;
}

View File

@ -158,7 +158,7 @@ public class GraphTextureFrame implements TextureFrame {
@Override
protected void finalize() throws Throwable {
if (refCount >= 0 || nativeBufferHandle != 0) {
if (refCount > 0 || nativeBufferHandle != 0) {
logger.atWarning().log("release was not called before finalize");
}
if (!activeConsumerContextHandleSet.isEmpty()) {

View File

@ -199,6 +199,28 @@ public final class PacketGetter {
return nativeGetImageData(packet.getNativeHandle(), buffer);
}
/** Returns the size of the Image list. This helps to determine the size of the allocated ByteBuffer array. */
public static int getImageListSize(final Packet packet) {
return nativeGetImageListSize(packet.getNativeHandle());
}
/**
* Assigns the native image buffers into the given ByteBuffer array. It assumes the given ByteBuffer
* array has the same size as the image list packet, and assumes the output buffer stores pixels
* contiguously. It returns false if this assumption does not hold.
*
* <p>If deepCopy is true, it assumes the given buffersArray has allocated the required size of
* ByteBuffer to copy image data to. If false, the ByteBuffer will wrap the memory address of
* the MediaPipe ImageFrame graph output, and the ByteBuffer data is available only while the
* MediaPipe graph is alive.
*
* <p>Note: this function does not assume the pixel format.
*/
public static boolean getImageList(
final Packet packet, ByteBuffer[] buffersArray, boolean deepCopy) {
return nativeGetImageList(packet.getNativeHandle(), buffersArray, deepCopy);
}
/**
* Converts an RGB mediapipe image frame packet to an RGBA Byte buffer.
*
@ -316,7 +338,8 @@ public final class PacketGetter {
public static GraphTextureFrame getTextureFrameDeferredSync(final Packet packet) {
return new GraphTextureFrame(
nativeGetGpuBuffer(packet.getNativeHandle(), /* waitOnCpu= */ false),
packet.getTimestamp(), /* deferredSync= */true);
packet.getTimestamp(),
/* deferredSync= */ true);
}
private static native long nativeGetPacketFromReference(long nativePacketHandle);
@ -363,6 +386,11 @@ public final class PacketGetter {
private static native boolean nativeGetImageData(long nativePacketHandle, ByteBuffer buffer);
private static native int nativeGetImageListSize(long nativePacketHandle);
private static native boolean nativeGetImageList(
long nativePacketHandle, ByteBuffer[] bufferArray, boolean deepCopy);
private static native boolean nativeGetRgbaFromRgb(long nativePacketHandle, ByteBuffer buffer);
// Retrieves the values that are in the VideoHeader.
private static native int nativeGetVideoHeaderWidth(long nativepackethandle);

View File

@ -50,7 +50,10 @@ public class ByteBufferExtractor {
switch (container.getImageProperties().getStorageType()) {
case MPImage.STORAGE_TYPE_BYTEBUFFER:
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
return byteBufferImageContainer
.getByteBuffer()
.asReadOnlyBuffer()
.order(ByteOrder.nativeOrder());
default:
throw new IllegalArgumentException(
"Extract ByteBuffer from a MPImage created by objects other than Bytebuffer is not"
@ -74,7 +77,7 @@ public class ByteBufferExtractor {
* @throws IllegalArgumentException when the extraction requires unsupported format or data type
* conversions.
*/
static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
public static ByteBuffer extract(MPImage image, @MPImageFormat int targetFormat) {
MPImageContainer container;
MPImageProperties byteBufferProperties =
MPImageProperties.builder()
@ -83,12 +86,16 @@ public class ByteBufferExtractor {
.build();
if ((container = image.getContainer(byteBufferProperties)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer();
return byteBufferImageContainer
.getByteBuffer()
.asReadOnlyBuffer()
.order(ByteOrder.nativeOrder());
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BYTEBUFFER)) != null) {
ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container;
@MPImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat();
return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat)
.asReadOnlyBuffer();
.asReadOnlyBuffer()
.order(ByteOrder.nativeOrder());
} else if ((container = image.getContainer(MPImage.STORAGE_TYPE_BITMAP)) != null) {
BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container;
ByteBuffer byteBuffer =

View File

@ -67,6 +67,8 @@ public class MPImage implements Closeable {
IMAGE_FORMAT_YUV_420_888,
IMAGE_FORMAT_ALPHA,
IMAGE_FORMAT_JPEG,
IMAGE_FORMAT_VEC32F1,
IMAGE_FORMAT_VEC32F2,
})
@Retention(RetentionPolicy.SOURCE)
public @interface MPImageFormat {}
@ -81,6 +83,8 @@ public class MPImage implements Closeable {
public static final int IMAGE_FORMAT_YUV_420_888 = 7;
public static final int IMAGE_FORMAT_ALPHA = 8;
public static final int IMAGE_FORMAT_JPEG = 9;
public static final int IMAGE_FORMAT_VEC32F1 = 10;
public static final int IMAGE_FORMAT_VEC32F2 = 11;
/** Specifies the image container type. Would be useful for choosing extractors. */
@IntDef({

View File

@ -14,6 +14,7 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/formats/image.h"
@ -39,6 +40,52 @@ template <typename T>
const T& GetFromNativeHandle(int64_t packet_handle) {
return mediapipe::android::Graph::GetPacketFromHandle(packet_handle).Get<T>();
}
bool CopyImageDataToByteBuffer(JNIEnv* env, const mediapipe::ImageFrame& image,
jobject byte_buffer) {
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
ThrowIfError(env, absl::InvalidArgumentError(
"input buffer does not support direct access"));
return false;
}
// Assume byte buffer stores pixel data contiguously.
const int expected_buffer_size = image.Width() * image.Height() *
image.ByteDepth() * image.NumberOfChannels();
if (buffer_size != expected_buffer_size) {
ThrowIfError(
env, absl::InvalidArgumentError(absl::StrCat(
"Expected buffer size ", expected_buffer_size,
" got: ", buffer_size, ", width ", image.Width(), ", height ",
image.Height(), ", channels ", image.NumberOfChannels())));
return false;
}
switch (image.ByteDepth()) {
case 1: {
uint8* data = static_cast<uint8*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
case 2: {
uint16* data = static_cast<uint16*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
case 4: {
float* data = static_cast<float*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
default: {
return false;
}
}
return true;
}
} // namespace
JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetPacketFromReference)(
@ -298,45 +345,50 @@ JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
.GetImageFrameSharedPtr()
.get()
: GetFromNativeHandle<mediapipe::ImageFrame>(packet);
int64_t buffer_size = env->GetDirectBufferCapacity(byte_buffer);
void* buffer_data = env->GetDirectBufferAddress(byte_buffer);
if (buffer_data == nullptr || buffer_size < 0) {
ThrowIfError(env, absl::InvalidArgumentError(
"input buffer does not support direct access"));
return false;
return CopyImageDataToByteBuffer(env, image, byte_buffer);
}
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)(
JNIEnv* env, jobject thiz, jlong packet) {
const auto& image_list =
GetFromNativeHandle<std::vector<mediapipe::Image>>(packet);
return image_list.size();
}
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array,
jboolean deep_copy) {
const auto& image_list =
GetFromNativeHandle<std::vector<mediapipe::Image>>(packet);
if (env->GetArrayLength(byte_buffer_array) != image_list.size()) {
ThrowIfError(env, absl::InvalidArgumentError(absl::StrCat(
"Expected ByteBuffer array size: ", image_list.size(),
" but get ByteBuffer array size: ",
env->GetArrayLength(byte_buffer_array))));
return false;
}
for (int i = 0; i < image_list.size(); ++i) {
auto& image = *image_list[i].GetImageFrameSharedPtr().get();
if (!image.IsContiguous()) {
ThrowIfError(
env, absl::InternalError("ImageFrame must store data contiguously to "
"be allocated as ByteBuffer."));
return false;
}
if (deep_copy) {
jobject byte_buffer = reinterpret_cast<jobject>(
env->GetObjectArrayElement(byte_buffer_array, i));
if (!CopyImageDataToByteBuffer(env, image, byte_buffer)) {
return false;
}
} else {
// Assume byte buffer stores pixel data contiguously.
const int expected_buffer_size = image.Width() * image.Height() *
image.ByteDepth() * image.NumberOfChannels();
if (buffer_size != expected_buffer_size) {
ThrowIfError(
env, absl::InvalidArgumentError(absl::StrCat(
"Expected buffer size ", expected_buffer_size,
" got: ", buffer_size, ", width ", image.Width(), ", height ",
image.Height(), ", channels ", image.NumberOfChannels())));
return false;
}
switch (image.ByteDepth()) {
case 1: {
uint8* data = static_cast<uint8*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
case 2: {
uint16* data = static_cast<uint16*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
case 4: {
float* data = static_cast<float*>(buffer_data);
image.CopyToBuffer(data, expected_buffer_size);
break;
}
        default: {
          return false;
        }
      }
      const int expected_buffer_size = image.Width() * image.Height() *
                                       image.ByteDepth() *
                                       image.NumberOfChannels();
      jobject image_data_byte_buffer = env->NewDirectByteBuffer(
          image.MutablePixelData(), expected_buffer_size);
      env->SetObjectArrayElement(byte_buffer_array, i, image_data_byte_buffer);
    }
  }
return true;
@ -415,7 +467,8 @@ JNIEXPORT jbyteArray JNICALL PACKET_GETTER_METHOD(nativeGetAudioData)(
int16 value =
static_cast<int16>(audio_mat(channel, sample) * kMultiplier);
// The java and native has the same byte order, by default is little
// Endian, we can safely copy data directly, we have tests to cover this.
// Endian, we can safely copy data directly, we have tests to cover
// this.
env->SetByteArrayRegion(byte_data, offset, 2,
reinterpret_cast<const jbyte*>(&value));
offset += 2;

View File

@ -106,6 +106,17 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageHeight)(JNIEnv* env,
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageData)(
JNIEnv* env, jobject thiz, jlong packet, jobject byte_buffer);
// Return the vector size of std::vector<Image>.
JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetImageListSize)(
JNIEnv* env, jobject thiz, jlong packet);
// Fill ByteBuffer[] from the Packet of std::vector<Image>.
// Before calling this, the byte_buffer_array needs to have the correct
// allocated size.
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetImageList)(
JNIEnv* env, jobject thiz, jlong packet, jobjectArray byte_buffer_array,
jboolean deep_copy);
// Before calling this, the byte_buffer needs to have the correct allocated
// size.
JNIEXPORT jboolean JNICALL PACKET_GETTER_METHOD(nativeGetRgbaFromRgb)(

View File

@ -61,6 +61,7 @@ py_test(
name = "file_util_test",
srcs = ["file_util_test.py"],
data = ["//mediapipe/model_maker/python/core/utils/testdata"],
tags = ["requires-net:external"],
deps = [":file_util"],
)

View File

@ -13,11 +13,93 @@
# limitations under the License.
"""Utilities for files."""
import dataclasses
import os
import pathlib
import shutil
import tarfile
import tempfile
import requests
# resources dependency
_TEMPDIR_FOLDER = 'model_maker'
@dataclasses.dataclass
class DownloadedFiles:
"""File(s) that are downloaded from a url into a local directory.
If `is_folder` is True:
1. `path` should be a folder
2. `url` should point to a .tar.gz file which contains a single folder at
the root level.
Attributes:
path: Relative path in local directory.
url: GCS url to download the file(s).
is_folder: Whether the path and url represent a folder.
"""
path: str
url: str
is_folder: bool = False
def get_path(self) -> str:
"""Gets the path of files saved in a local directory.
If the path doesn't exist, this method will download the file(s) from the
provided url. The path is not cleaned up so it can be reused for subsequent
calls to the same path.
Folders are expected to be zipped in a .tar.gz file which will be extracted
into self.path in the local directory.
Raises:
RuntimeError: If the extracted folder does not have a singular root
directory.
Returns:
The absolute path to the downloaded file(s)
"""
tmpdir = tempfile.gettempdir()
absolute_path = pathlib.Path(
os.path.join(tmpdir, _TEMPDIR_FOLDER, self.path)
)
if not absolute_path.exists():
print(f'Downloading {self.url} to {absolute_path}')
r = requests.get(self.url, allow_redirects=True)
if self.is_folder:
# Use tempf to store the downloaded .tar.gz file
tempf = tempfile.NamedTemporaryFile(suffix='.tar.gz', mode='wb')
tempf.write(r.content)
tarf = tarfile.open(tempf.name)
# Use tmpdir to store the extracted contents of the .tar.gz file
with tempfile.TemporaryDirectory() as tmpdir:
tarf.extractall(tmpdir)
tarf.close()
tempf.close()
subdirs = os.listdir(tmpdir)
# Make sure tmpdir only has one subdirectory
if len(subdirs) > 1 or not os.path.isdir(
os.path.join(tmpdir, subdirs[0])
):
raise RuntimeError(
f"Extracted folder from {self.url} doesn't contain a "
f'single root directory: {subdirs}'
)
# Create the parent dir of absolute_path and copy the contents of the
# top level dir in the .tar.gz file into absolute_path.
pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
shutil.copytree(os.path.join(tmpdir, subdirs[0]), absolute_path)
else:
pathlib.Path.mkdir(absolute_path.parent, parents=True, exist_ok=True)
with open(absolute_path, 'wb') as f:
f.write(r.content)
return str(absolute_path)
# TODO Remove after text_classifier supports downloading on demand.
def get_absolute_path(file_path: str) -> str:
"""Gets the absolute path of a file in the model_maker directory.

View File

@ -12,13 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from unittest import mock as unittest_mock
from absl.testing import absltest
import requests
from mediapipe.model_maker.python.core.utils import file_util
class FileUtilTest(absltest.TestCase):
def setUp(self):
super().setUp()
mock_gettempdir = unittest_mock.patch.object(
tempfile,
'gettempdir',
return_value=self.create_tempdir(),
autospec=True,
)
self.mock_gettempdir = mock_gettempdir.start()
self.addCleanup(mock_gettempdir.stop)
def test_get_path(self):
path = 'gesture_recognizer/hand_landmark_full.tflite'
url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite'
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False)
model_path = downloaded_files.get_path()
self.assertTrue(os.path.exists(model_path))
self.assertGreater(os.path.getsize(model_path), 0)
def test_get_path_folder(self):
folder_contents = [
'keras_metadata.pb',
'saved_model.pb',
'assets/vocab.txt',
'variables/variables.data-00000-of-00001',
'variables/variables.index',
]
path = 'text_classifier/mobilebert_tiny'
url = (
'https://storage.googleapis.com/mediapipe-assets/mobilebert_tiny.tar.gz'
)
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=True)
model_path = downloaded_files.get_path()
self.assertTrue(os.path.exists(model_path))
for file_name in folder_contents:
file_path = os.path.join(model_path, file_name)
self.assertTrue(os.path.exists(file_path))
self.assertGreater(os.path.getsize(file_path), 0)
@unittest_mock.patch.object(requests, 'get', wraps=requests.get)
def test_get_path_multiple_calls(self, mock_get):
path = 'gesture_recognizer/hand_landmark_full.tflite'
url = 'https://storage.googleapis.com/mediapipe-assets/hand_landmark_full.tflite'
downloaded_files = file_util.DownloadedFiles(path, url, is_folder=False)
model_path = downloaded_files.get_path()
self.assertTrue(os.path.exists(model_path))
self.assertGreater(os.path.getsize(model_path), 0)
model_path_2 = downloaded_files.get_path()
self.assertEqual(model_path, model_path_2)
self.assertEqual(mock_get.call_count, 1)
def test_get_absolute_path(self):
test_file = 'mediapipe/model_maker/python/core/utils/testdata/test.txt'
absolute_path = file_util.get_absolute_path(test_file)

View File

@ -48,7 +48,6 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
@ -90,7 +89,6 @@ cc_library(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedder_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:embedding_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"@com_google_absl//absl/status",

View File

@ -40,7 +40,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "mediapipe/tasks/metadata/metadata_schema_generated.h"
@ -68,7 +67,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit;
using ::tflite::TensorMetadata;
using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@ -455,12 +454,13 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
}
// If output tensors are quantized, they must be dequantized first.
TensorsSource dequantized_tensors(&tensors_in);
TensorsSource dequantized_tensors = tensors_in;
if (options.has_quantized_outputs()) {
GenericNode* tensors_dequantization_node =
&graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node->In(kTensorsTag);
dequantized_tensors = {tensors_dequantization_node, kTensorsTag};
dequantized_tensors = tensors_dequantization_node->Out(kTensorsTag)
.Cast<std::vector<Tensor>>();
}
// If there are multiple classification heads, the output tensors need to be
@ -477,7 +477,8 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
auto* range = split_tensor_vector_options.add_ranges();
range->set_begin(i);
range->set_end(i + 1);
split_tensors.emplace_back(split_tensor_vector_node, i);
split_tensors.push_back(
split_tensor_vector_node->Out(i).Cast<std::vector<Tensor>>());
}
dequantized_tensors >> split_tensor_vector_node->In(0);
} else {
@ -494,8 +495,9 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>()
.CopyFrom(options.score_calibration_options().at(i));
split_tensors[i] >> score_calibration_node->In(kScoresTag);
calibrated_tensors.emplace_back(score_calibration_node,
kCalibratedScoresTag);
calibrated_tensors.push_back(
score_calibration_node->Out(kCalibratedScoresTag)
.Cast<std::vector<Tensor>>());
} else {
calibrated_tensors.emplace_back(split_tensors[i]);
}

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedder_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
#include "tensorflow/lite/schema/schema_generated.h"
@ -51,8 +50,6 @@ using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::EmbeddingResult;
using ::mediapipe::tasks::core::ModelResources;
using TensorsSource =
::mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>;
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kEmbeddingsTag[] = "EMBEDDINGS";
@ -229,12 +226,13 @@ class EmbeddingPostprocessingGraph : public mediapipe::Subgraph {
Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
// If output tensors are quantized, they must be dequantized first.
TensorsSource dequantized_tensors(&tensors_in);
Source<std::vector<Tensor>> dequantized_tensors = tensors_in;
if (options.has_quantized_outputs()) {
GenericNode& tensors_dequantization_node =
graph.AddNode("TensorsDequantizationCalculator");
tensors_in >> tensors_dequantization_node.In(kTensorsTag);
dequantized_tensors = {&tensors_dequantization_node, kTensorsTag};
dequantized_tensors = tensors_dequantization_node.Out(kTensorsTag)
.Cast<std::vector<Tensor>>();
}
// Adds TensorsToEmbeddingsCalculator.

View File

@ -14,12 +14,6 @@
package(default_visibility = ["//mediapipe/tasks:internal"])
cc_library(
name = "source_or_node_output",
hdrs = ["source_or_node_output.h"],
deps = ["//mediapipe/framework/api2:builder"],
)
cc_library(
name = "cosine_similarity",
srcs = ["cosine_similarity.cc"],

View File

@ -1,66 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_
#include "mediapipe/framework/api2/builder.h"
namespace mediapipe {
namespace tasks {
// Helper class representing either a Source object or a GenericNode output.
//
// Source and MultiSource (the output of a GenericNode) are widely incompatible,
// but being able to represent either of these in temporary variables and
// connect them later on facilitates graph building.
template <typename T>
class SourceOrNodeOutput {
public:
SourceOrNodeOutput() = delete;
// The caller is responsible for ensuring 'source' outlives this object.
explicit SourceOrNodeOutput(mediapipe::api2::builder::Source<T>* source)
: source_(source) {}
// The caller is responsible for ensuring 'node' outlives this object.
SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node,
std::string tag)
: node_(node), tag_(tag) {}
// The caller is responsible for ensuring 'node' outlives this object.
SourceOrNodeOutput(mediapipe::api2::builder::GenericNode* node, int index)
: node_(node), index_(index) {}
// Connects the source or node output to the provided destination.
template <typename U>
void operator>>(const U& dest) {
if (source_ != nullptr) {
*source_ >> dest;
} else {
if (index_ < 0) {
node_->Out(tag_) >> dest;
} else {
node_->Out(index_) >> dest;
}
}
}
private:
mediapipe::api2::builder::Source<T>* source_ = nullptr;
mediapipe::api2::builder::GenericNode* node_ = nullptr;
std::string tag_ = "";
int index_ = -1;
};
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_SOURCE_OR_NODE_OUTPUT_H_

View File

@ -139,5 +139,32 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kMobileBert);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("When you go to this restaurant, they hold the "
"pancake upside-down before they hand it "
"to you. It's a great gimmick."));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result1,
text_embedder->Embed(
"Let's make a plan to steal the declaration of independence."));
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
// TODO: The similarity should likely be lower
EXPECT_NEAR(similarity, 0.98088, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace
} // namespace mediapipe::tasks::text::text_embedder

View File

@ -140,6 +140,7 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
@ -68,6 +69,9 @@ constexpr char kHandednessTag[] = "HANDEDNESS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kHandGesturesTag[] = "HAND_GESTURES";
constexpr char kHandTrackingIdsTag[] = "HAND_TRACKING_IDS";
constexpr char kRectNextFrameTag[] = "HAND_RECT_NEXT_FRAME";
constexpr char kPalmRectsTag[] = "PALM_RECTS";
constexpr char kPalmDetectionsTag[] = "PALM_DETECTIONS";
constexpr char kHandLandmarkerBundleAssetName[] = "hand_landmarker.task";
constexpr char kHandGestureRecognizerBundleAssetName[] =
"hand_gesture_recognizer.task";
@ -77,6 +81,9 @@ struct GestureRecognizerOutputs {
Source<std::vector<ClassificationList>> handedness;
Source<std::vector<NormalizedLandmarkList>> hand_landmarks;
Source<std::vector<LandmarkList>> hand_world_landmarks;
Source<std::vector<NormalizedRect>> hand_rects_next_frame;
Source<std::vector<NormalizedRect>> palm_rects;
Source<std::vector<Detection>> palm_detections;
Source<Image> image;
};
@ -135,9 +142,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// Inputs:
// IMAGE - Image
// Image to perform hand gesture recognition on.
// NORM_RECT - NormalizedRect
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform landmarks
// detection on.
// detection on. If not provided, whole image is used for gesture
// recognition.
//
// Outputs:
// HAND_GESTURES - std::vector<ClassificationList>
@ -208,11 +216,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_gesture_recognition_output,
ASSIGN_OR_RETURN(
auto hand_gesture_recognition_output,
BuildGestureRecognizerGraph(
*sc->MutableOptions<GestureRecognizerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_gesture_recognition_output.gesture >>
graph[Output<std::vector<ClassificationList>>(kHandGesturesTag)];
hand_gesture_recognition_output.handedness >>
@ -222,6 +231,12 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
hand_gesture_recognition_output.hand_world_landmarks >>
graph[Output<std::vector<LandmarkList>>(kWorldLandmarksTag)];
hand_gesture_recognition_output.image >> graph[Output<Image>(kImageTag)];
hand_gesture_recognition_output.hand_rects_next_frame >>
graph[Output<std::vector<NormalizedRect>>(kRectNextFrameTag)];
hand_gesture_recognition_output.palm_rects >>
graph[Output<std::vector<NormalizedRect>>(kPalmRectsTag)];
hand_gesture_recognition_output.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
return graph.GetConfig();
}
@ -279,7 +294,17 @@ class GestureRecognizerGraph : public core::ModelTaskGraph {
/*handedness=*/handedness,
/*hand_landmarks=*/hand_landmarks,
/*hand_world_landmarks=*/hand_world_landmarks,
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)]};
/*hand_rects_next_frame =*/
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
kRectNextFrameTag)],
/*palm_rects =*/
hand_landmarker_graph[Output<std::vector<NormalizedRect>>(
kPalmRectsTag)],
/*palm_detections =*/
hand_landmarker_graph[Output<std::vector<Detection>>(
kPalmDetectionsTag)],
/*image=*/hand_landmarker_graph[Output<Image>(kImageTag)],
};
}
};

View File

@ -150,9 +150,9 @@ void ConfigureRectTransformationCalculator(
// Inputs:
// IMAGE - Image
// Image to perform detection on.
// NORM_RECT - NormalizedRect
// Describes image rotation and region of image to perform detection
// on.
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform detection on. If
// not provided, whole image is used for hand detection.
//
// Outputs:
// PALM_DETECTIONS - std::vector<Detection>
@ -197,11 +197,12 @@ class HandDetectorGraph : public core::ModelTaskGraph {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<HandDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto hand_detection_outs,
ASSIGN_OR_RETURN(
auto hand_detection_outs,
BuildHandDetectionSubgraph(
sc->Options<HandDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
sc->Options<HandDetectorGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_detection_outs.palm_detections >>
graph[Output<std::vector<Detection>>(kPalmDetectionsTag)];
hand_detection_outs.hand_rects >>

View File

@ -136,9 +136,10 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources,
// Inputs:
// IMAGE - Image
// Image to perform hand landmarks detection on.
// NORM_RECT - NormalizedRect
// NORM_RECT - NormalizedRect @Optional
// Describes image rotation and region of image to perform landmarks
// detection on.
// detection on. If not provided, whole image is used for hand landmarks
// detection.
//
// Outputs:
// LANDMARKS: - std::vector<NormalizedLandmarkList>
@ -218,11 +219,12 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
.IsAvailable()));
}
ASSIGN_OR_RETURN(auto hand_landmarker_outputs,
ASSIGN_OR_RETURN(
auto hand_landmarker_outputs,
BuildHandLandmarkerGraph(
sc->Options<HandLandmarkerGraphOptions>(),
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kNormRectTag)], graph));
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
hand_landmarker_outputs.landmark_lists >>
graph[Output<std::vector<NormalizedLandmarkList>>(kLandmarksTag)];
hand_landmarker_outputs.world_landmark_lists >>

View File

@ -243,11 +243,12 @@ class SingleHandLandmarksDetectorGraph : public core::ModelTaskGraph {
const auto* model_resources,
CreateModelResources<HandLandmarksDetectorGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(auto hand_landmark_detection_outs,
ASSIGN_OR_RETURN(
auto hand_landmark_detection_outs,
BuildSingleHandLandmarksDetectorGraph(
sc->Options<HandLandmarksDetectorGraphOptions>(),
*model_resources, graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>(kHandRectTag)], graph));
sc->Options<HandLandmarksDetectorGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kHandRectTag)], graph));
hand_landmark_detection_outs.hand_landmarks >>
graph[Output<NormalizedLandmarkList>(kLandmarksTag)];
hand_landmark_detection_outs.world_hand_landmarks >>

View File

@ -257,10 +257,12 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) {
SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity));
}
TEST_F(ImageModeTest, SucceedsWithRotation) {
// TODO: fix this unit test after image segmenter handled post
// processing correctly with rotated image.
TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) {
MP_ASSERT_OK_AND_ASSIGN(
Image image, DecodeImageFromFile(
JoinPath("./", kTestDataDirectory, "cat_rotated.jpg")));
Image image,
DecodeImageFromFile(JoinPath("./", kTestDataDirectory, "cat.jpg")));
auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata);
@ -271,7 +273,8 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
ImageSegmenter::Create(std::move(options)));
ImageProcessingOptions image_processing_options;
image_processing_options.rotation_degrees = -90;
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image));
MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks,
segmenter->Segment(image, image_processing_options));
EXPECT_EQ(confidence_masks.size(), 21);
cv::Mat expected_mask =

View File

@ -74,7 +74,6 @@ cc_library(
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/processors:image_preprocessing_graph",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",

View File

@ -34,7 +34,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/acceleration.pb.h"
@ -69,7 +68,7 @@ using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>;
using ObjectDetectorOptionsProto =
object_detector::proto::ObjectDetectorOptions;
using TensorsSource =
mediapipe::tasks::SourceOrNodeOutput<std::vector<mediapipe::Tensor>>;
mediapipe::api2::builder::Source<std::vector<mediapipe::Tensor>>;
constexpr int kDefaultLocationsIndex = 0;
constexpr int kDefaultCategoriesIndex = 1;
@ -584,7 +583,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
auto post_processing_specs,
BuildPostProcessingSpecs(task_options, metadata_extractor));
// Calculators to perform score calibration, if specified in the metadata.
TensorsSource calibrated_tensors = {&inference, kTensorTag};
TensorsSource calibrated_tensors =
inference.Out(kTensorTag).Cast<std::vector<Tensor>>();
if (post_processing_specs.score_calibration_options.has_value()) {
// Split tensors.
auto* split_tensor_vector_node =
@ -623,7 +623,8 @@ class ObjectDetectorGraph : public core::ModelTaskGraph {
concatenate_tensor_vector_node->In(i);
}
}
calibrated_tensors = {concatenate_tensor_vector_node, 0};
calibrated_tensors =
concatenate_tensor_vector_node->Out(0).Cast<std::vector<Tensor>>();
}
// Calculator to convert output tensors to a detection proto vector.
// Connects TensorsToDetectionsCalculator's input stream to the output

View File

@ -26,16 +26,16 @@ NS_SWIFT_NAME(Embedding)
@interface MPPEmbedding : NSObject
/**
* @brief The Floating-point embedding.
* @brief The embedding represented as an `NSArray` of `Float` values.
* Empty if the embedder was configured to perform scalar quantization.
*/
@property(nonatomic, readonly, nullable) NSArray<NSNumber *> *floatEmbedding;
/**
* @brief The Quantized embedding.
* @brief The embedding represented as an `NSArray` of `UInt8` values.
* Empty if the embedder was not configured to perform scalar quantization.
*/
@property(nonatomic, readonly, nullable) NSData *quantizedEmbedding;
@property(nonatomic, readonly, nullable) NSArray<NSNumber *> *quantizedEmbedding;
/** The index of the embedder head these entries refer to. This is useful for multi-head models. */
@property(nonatomic, readonly) NSInteger headIndex;
@ -56,7 +56,7 @@ NS_SWIFT_NAME(Embedding)
* embedding, head index and head name.
*/
- (instancetype)initWithFloatEmbedding:(nullable NSArray<NSNumber *> *)floatEmbedding
quantizedEmbedding:(nullable NSData *)quantizedEmbedding
quantizedEmbedding:(nullable NSArray<NSNumber *> *)quantizedEmbedding
headIndex:(NSInteger)headIndex
headName:(nullable NSString *)headName NS_DESIGNATED_INITIALIZER;

View File

@ -17,7 +17,7 @@
@implementation MPPEmbedding
- (instancetype)initWithFloatEmbedding:(nullable NSArray<NSNumber *> *)floatEmbedding
quantizedEmbedding:(nullable NSData *)quantizedEmbedding
quantizedEmbedding:(nullable NSArray<NSNumber *> *)quantizedEmbedding
headIndex:(NSInteger)headIndex
headName:(nullable NSString *)headName {
self = [super init];

View File

@ -21,7 +21,7 @@ NS_ASSUME_NONNULL_BEGIN
/**
* Options for setting up a `MPPTextEmbedder`.
*/
NS_SWIFT_NAME(TextEmbedderptions)
NS_SWIFT_NAME(TextEmbedderOptions)
@interface MPPTextEmbedderOptions : MPPTaskOptions <NSCopying>
/**

View File

@ -44,6 +44,7 @@ cc_binary(
"//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph",
"//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph",
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
@ -176,6 +177,30 @@ android_library(
],
)
android_library(
name = "imagesegmenter",
srcs = [
"imagesegmenter/ImageSegmenter.java",
"imagesegmenter/ImageSegmenterResult.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = "imagesegmenter/AndroidManifest.xml",
deps = [
":core",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "imageembedder",
srcs = [

View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.imagesegmenter">
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
</manifest>

View File

@ -0,0 +1,462 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imagesegmenter;
import android.content.Context;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto;
import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs image segmentation on images.
*
* <p>Note that, unlike other vision tasks, the output of ImageSegmenter is provided through a
* user-defined callback function even for the synchronous API. This makes it possible for
* ImageSegmenter to return the output masks without any copy. {@link ResultListener} must be set in
* the {@link ImageSegmenterOptions} for all {@link RunningMode}.
*
* <p>The API expects a TFLite model with <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata</a>.
*
* <ul>
* <li>Input image {@link MPImage}
* <ul>
* <li>The image that image segmenter runs on.
* </ul>
* <li>Output ImageSegmenterResult {@link ImageSegmenterResult}
* <ul>
* <li>An ImageSegmenterResult containing segmented masks.
* </ul>
* </ul>
*/
public final class ImageSegmenter extends BaseVisionTaskApi {
private static final String TAG = ImageSegmenter.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(
Arrays.asList(
"GROUPED_SEGMENTATION:segmented_mask_out",
"IMAGE:image_out",
"SEGMENTATION:0:segmentation"));
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph";
/**
* Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}.
*
* @param context an Android {@link Context}.
* @param segmenterOptions an {@link ImageSegmenterOptions} instance.
* @throws MediaPipeException if there is an error during {@link ImageSegmenter} creation.
*/
public static ImageSegmenter createFromOptions(
Context context, ImageSegmenterOptions segmenterOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ImageSegmenterResult, MPImage>() {
@Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create(
new ArrayList<>(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
}
List<MPImage> segmentedMasks = new ArrayList<>();
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
int imageFormat =
segmenterOptions.outputType() == ImageSegmenterOptions.OutputType.CONFIDENCE_MASK
? MPImage.IMAGE_FORMAT_VEC32F1
: MPImage.IMAGE_FORMAT_ALPHA;
int imageListSize =
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
if (!PacketGetter.getImageList(
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), buffersArray, false)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting segmented masks. It usually results from incorrect"
+ " options of unsupported OutputType of given model.");
}
for (ByteBuffer buffer : buffersArray) {
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
segmentedMasks.add(builder.build());
}
return ImageSegmenterResult.create(
segmentedMasks,
BaseVisionTaskApi.generateResultTimestampMs(
segmenterOptions.runningMode(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
}
@Override
public MPImage convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
handler.setResultListener(segmenterOptions.resultListener());
segmenterOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<ImageSegmenterOptions>builder()
.setTaskName(ImageSegmenter.class.getSimpleName())
.setTaskRunningModeName(segmenterOptions.runningMode().name())
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(segmenterOptions)
.setEnableFlowLimiting(segmenterOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
return new ImageSegmenter(runner, segmenterOptions.runningMode());
}
/**
* Constructor to initialize an {@link ImageSegmenter} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ImageSegmenter(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME);
}
/**
* Performs image segmentation on the provided single image with default image processing options,
* i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
* {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO update java
* doc for input image format.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public void segment(MPImage image) {
segment(image, ImageProcessingOptions.builder().build());
}
/**
* Performs image segmentation on the provided single image, and the results will be available via
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
* when the {@link ImageSegmenter} is created with {@link RunningMode.IMAGE}. TODO
* update java doc for input image format.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void segment(MPImage image, ImageProcessingOptions imageProcessingOptions) {
validateImageProcessingOptions(imageProcessingOptions);
ImageSegmenterResult unused =
(ImageSegmenterResult) processImageData(image, imageProcessingOptions);
}
/**
* Performs image segmentation on the provided video frame with default image processing options,
* i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
* {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void segmentForVideo(MPImage image, long timestampMs) {
segmentForVideo(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Performs image segmentation on the provided video frame, and the results will be available via
* the {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method
* when the {@link ImageSegmenter} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void segmentForVideo(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
ImageSegmenterResult unused =
(ImageSegmenterResult) processVideoData(image, imageProcessingOptions, timestampMs);
}
/**
* Sends live image data to perform image segmentation with default image processing
* options, i.e. without any rotation applied, and the results will be available via the {@link
* ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when the
* {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the image segmenter. The input timestamps must be monotonically increasing.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void segmentAsync(MPImage image, long timestampMs) {
segmentAsync(image, ImageProcessingOptions.builder().build(), timestampMs);
}
/**
* Sends live image data to perform image segmentation, and the results will be available via the
* {@link ResultListener} provided in the {@link ImageSegmenterOptions}. Only use this method when
* the {@link ImageSegmenter} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the image segmenter. The input timestamps must be monotonically increasing.
*
* <p>{@link ImageSegmenter} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
* input image before running inference. Note that region-of-interest is <b>not</b> supported
* by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in
* this method throwing an IllegalArgumentException.
* @param timestampMs the input timestamp (in milliseconds).
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
* region-of-interest.
* @throws MediaPipeException if there is an internal error.
*/
public void segmentAsync(
MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) {
validateImageProcessingOptions(imageProcessingOptions);
sendLiveStreamData(image, imageProcessingOptions, timestampMs);
}
/** Options for setting up an {@link ImageSegmenter}. */
@AutoValue
public abstract static class ImageSegmenterOptions extends TaskOptions {
/** Builder for {@link ImageSegmenterOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the image segmenter task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
* Sets the running mode for the image segmenter task. Defaults to the image mode. Image
* segmenter has three modes:
*
* <ul>
* <li>IMAGE: The mode for segmenting images on single image inputs.
* <li>VIDEO: The mode for segmenting images on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for segmenting images on a live stream of input data, such
* as from camera. In this mode, {@code setResultListener} must be called to set up a
* listener to receive the segmentation results asynchronously.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode value);
/**
* The locale to use for display names specified through the TFLite Model Metadata, if any.
* Defaults to English.
*/
public abstract Builder setDisplayNamesLocale(String value);
/** The output type from image segmenter. */
public abstract Builder setOutputType(OutputType value);
/**
* Sets the {@link ResultListener} to receive the segmentation results when the graph pipeline
* is done processing an image.
*/
public abstract Builder setResultListener(
ResultListener<ImageSegmenterResult, MPImage> value);
/** Sets an optional {@link ErrorListener}. */
public abstract Builder setErrorListener(ErrorListener value);
abstract ImageSegmenterOptions autoBuild();
/**
* Validates and builds the {@link ImageSegmenterOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. Unlike other vision tasks, the image segmenter requires a result
* listener to be set for all running modes.
*/
public final ImageSegmenterOptions build() {
ImageSegmenterOptions options = autoBuild();
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract String displayNamesLocale();
abstract OutputType outputType();
abstract ResultListener<ImageSegmenterResult, MPImage> resultListener();
abstract Optional<ErrorListener> errorListener();
/** The output type of segmentation results. */
public enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK
}
public static Builder builder() {
return new AutoValue_ImageSegmenter_ImageSegmenterOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setDisplayNamesLocale("en")
.setOutputType(OutputType.CATEGORY_MASK)
.setResultListener((result, image) -> {});
}
/**
* Converts an {@link ImageSegmenterOptions} to a {@link CalculatorOptions} protobuf message.
*/
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder =
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder()
.setBaseOptions(
BaseOptionsProto.BaseOptions.newBuilder()
.setUseStreamMode(runningMode() != RunningMode.IMAGE)
.mergeFrom(convertBaseOptionsToProto(baseOptions()))
.build())
.setDisplayNamesLocale(displayNamesLocale());
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
SegmenterOptionsProto.SegmenterOptions.newBuilder();
if (outputType() == OutputType.CONFIDENCE_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
} else if (outputType() == OutputType.CATEGORY_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
// TODO: remove this once activation is handled in metadata and graph level.
segmenterOptionsBuilder.setActivation(
SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX);
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder()
.setExtension(
ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext,
taskOptionsBuilder.build())
.build();
}
}
/**
* Validates that the provided {@link ImageProcessingOptions} doesn't contain a
* region-of-interest.
*/
private static void validateImageProcessingOptions(
ImageProcessingOptions imageProcessingOptions) {
if (imageProcessingOptions.regionOfInterest().isPresent()) {
throw new IllegalArgumentException("ImageSegmenter doesn't support region-of-interest.");
}
}
}
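For orientation, here is a minimal usage sketch of the Java ImageSegmenter API introduced above. It is not part of the commit: the model asset name, the Android context, and the logging are illustrative assumptions, and a real app should also close the segmenter when done (it supports try-with-resources, as the instrumentation tests later in this commit show).

```java
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;

final class ImageSegmenterUsageSketch {
  /** Builds an ImageSegmenter that reports confidence masks through its result listener. */
  static ImageSegmenter createSegmenter(Context context) {
    ImageSegmenterOptions options =
        ImageSegmenterOptions.builder()
            .setBaseOptions(
                // "deeplabv3.tflite" is just an example asset name.
                BaseOptions.builder().setModelAssetPath("deeplabv3.tflite").build())
            .setRunningMode(RunningMode.IMAGE)
            .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
            // Unlike other vision tasks, the result listener is used for every running mode.
            .setResultListener(
                (result, inputImage) ->
                    Log.i("ImageSegmenter", "Received " + result.segmentations().size() + " masks"))
            .build();
    return ImageSegmenter.createFromOptions(context, options);
  }

  /** Wraps a Bitmap and runs synchronous segmentation; results arrive via the listener above. */
  static void segment(ImageSegmenter segmenter, Bitmap bitmap) {
    MPImage image = new BitmapImageBuilder(bitmap).build();
    segmenter.segment(image);
  }
}
```

In VIDEO and LIVE_STREAM modes the same listener is reused; the caller switches to segmentForVideo or segmentAsync and supplies monotonically increasing timestamps.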

View File

@ -0,0 +1,45 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imagesegmenter;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.Collections;
import java.util.List;
/** Represents the segmentation results generated by {@link ImageSegmenter}. */
@AutoValue
public abstract class ImageSegmenterResult implements TaskResult {
/**
* Creates an {@link ImageSegmenterResult} instance from a list of segmentation {@link MPImage}s.
*
* @param segmentations a {@link List} of {@link MPImage} representing the segmented masks. If
* OutputType is CATEGORY_MASK, the masks will be in IMAGE_FORMAT_ALPHA format. If OutputType is
* CONFIDENCE_MASK, the masks will be in IMAGE_FORMAT_VEC32F1 format.
* @param timestampMs a timestamp for this result.
*/
// TODO: consolidate output formats across platforms.
static ImageSegmenterResult create(List<MPImage> segmentations, long timestampMs) {
return new AutoValue_ImageSegmenterResult(
Collections.unmodifiableList(segmentations), timestampMs);
}
public abstract List<MPImage> segmentations();
@Override
public abstract long timestampMs();
}
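With OutputType.CONFIDENCE_MASK, each mask in segmentations() wraps one 32-bit float per pixel, so a caller can read individual confidence values roughly as below. This is a sketch, not committed code; the helper class name is hypothetical, and the extraction calls mirror the instrumentation test later in this commit.

```java
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage;
import java.nio.FloatBuffer;

final class ConfidenceMaskReader {
  /** Returns the model confidence for pixel (x, y) of a single confidence mask. */
  static float confidenceAt(MPImage mask, int x, int y) {
    // Confidence masks use IMAGE_FORMAT_VEC32F1, i.e. one float per pixel in row-major order.
    FloatBuffer values = ByteBufferExtractor.extract(mask).asFloatBuffer();
    return values.get(y * mask.getWidth() + x);
  }
}
```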

View File

@ -95,4 +95,24 @@ public class TextEmbedderTest {
result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.999937);
}
@Test
public void embed_succeedsWithBertAndDifferentThemes() throws Exception {
TextEmbedder textEmbedder =
TextEmbedder.createFromFile(ApplicationProvider.getApplicationContext(), BERT_MODEL_FILE);
TextEmbedderResult result0 =
textEmbedder.embed(
"When you go to this restaurant, they hold the pancake upside-down before they hand "
+ "it to you. It's a great gimmick.");
TextEmbedderResult result1 =
textEmbedder.embed("Let\'s make a plan to steal the declaration of independence.'");
// Check cosine similarity.
double similarity =
TextEmbedder.cosineSimilarity(
result0.embeddingResult().embeddings().get(0),
result1.embeddingResult().embeddings().get(0));
assertThat(similarity).isWithin(DOUBLE_DIFF_TOLERANCE).of(0.3477488707202946);
}
}

View File

@ -0,0 +1,24 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.google.mediapipe.tasks.vision.imagesegmentertest"
android:versionCode="1"
android:versionName="1.0" >
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE"/>
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE"/>
<uses-sdk android:minSdkVersion="24"
android:targetSdkVersion="30" />
<application
android:label="imagesegmentertest"
android:name="android.support.multidex.MultiDexApplication"
android:taskAffinity="">
<uses-library android:name="android.test.runner" />
</application>
<instrumentation
android:name="com.google.android.apps.common.testing.testrunner.GoogleInstrumentationTestRunner"
android:targetPackage="com.google.mediapipe.tasks.vision.imagesegmentertest" />
</manifest>

View File

@ -0,0 +1,19 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,427 @@
// Copyright 2023 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.imagesegmenter;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Color;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegmenterOptions;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.util.List;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link ImageSegmenter}. */
@RunWith(Suite.class)
@SuiteClasses({ImageSegmenterTest.General.class, ImageSegmenterTest.RunningModeTest.class})
public class ImageSegmenterTest {
private static final String DEEPLAB_MODEL_FILE = "deeplabv3.tflite";
private static final String SELFIE_128x128_MODEL_FILE = "selfie_segm_128_128_3.tflite";
private static final String SELFIE_144x256_MODEL_FILE = "selfie_segm_144_256_3.tflite";
private static final String CAT_IMAGE = "cat.jpg";
private static final float GOLDEN_MASK_SIMILARITY = 0.96f;
private static final int MAGNIFICATION_FACTOR = 10;
@RunWith(AndroidJUnit4.class)
public static final class General extends ImageSegmenterTest {
@Test
public void segment_successWithCategoryMask() throws Exception {
final String inputImageName = "segmentation_input_rotation0.jpg";
final String goldenImageName = "segmentation_golden_rotation0.png";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CATEGORY_MASK)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(1);
MPImage actualMaskBuffer = actualResult.segmentations().get(0);
verifyCategoryMask(
actualMaskBuffer,
expectedMaskBuffer,
GOLDEN_MASK_SIMILARITY,
MAGNIFICATION_FACTOR);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
}
@Test
public void segment_successWithConfidenceMask() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
}
@Test
public void segment_successWith128x128Segmentation() throws Exception {
final String inputImageName = "mozart_square.jpg";
final String goldenImageName = "selfie_segm_128_128_3_expected_mask.jpg";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(
BaseOptions.builder().setModelAssetPath(SELFIE_128x128_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(2);
// Selfie category index 1.
MPImage actualMaskBuffer = actualResult.segmentations().get(1);
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
}
// TODO: enable this unit test once activation option is supported in metadata.
// @Test
// public void segment_successWith144x256Segmentation() throws Exception {
// final String inputImageName = "mozart_square.jpg";
// final String goldenImageName = "selfie_segm_144_256_3_expected_mask.jpg";
// MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
// ImageSegmenterOptions options =
// ImageSegmenterOptions.builder()
// .setBaseOptions(
// BaseOptions.builder().setModelAssetPath(SELFIE_144x256_MODEL_FILE).build())
// .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
// .setActivation(ImageSegmenterOptions.Activation.NONE)
// .setResultListener(
// (actualResult, inputImage) -> {
// List<MPImage> segmentations = actualResult.segmentations();
// assertThat(segmentations.size()).isEqualTo(1);
// MPImage actualMaskBuffer = actualResult.segmentations().get(0);
// verifyConfidenceMask(
// actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
// })
// .build();
// ImageSegmenter imageSegmenter =
// ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(),
// options);
// imageSegmenter.segment(getImageFromAsset(inputImageName));
// }
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends ImageSegmenterTest {
@Test
public void segment_failsWithCallingWrongApiInImageMode() throws Exception {
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
imageSegmenter.segmentForVideo(
getImageFromAsset(CAT_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void segment_failsWithCallingWrongApiInVideoMode() throws Exception {
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
imageSegmenter.segmentAsync(getImageFromAsset(CAT_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void segment_failsWithCallingWrongApiInLiveSteamMode() throws Exception {
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((result, inputImage) -> {})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class, () -> imageSegmenter.segment(getImageFromAsset(CAT_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() ->
imageSegmenter.segmentForVideo(
getImageFromAsset(CAT_IMAGE), /* timestampMs= */ 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void segment_successWithImageMode() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.IMAGE)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
imageSegmenter.segment(getImageFromAsset(inputImageName));
}
@Test
public void segment_successWithVideoMode() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.VIDEO)
.setResultListener(
(actualResult, inputImage) -> {
List<MPImage> segmentations = actualResult.segmentations();
assertThat(segmentations.size()).isEqualTo(21);
// Cat category index 8.
MPImage actualMaskBuffer = actualResult.segmentations().get(8);
verifyConfidenceMask(
actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY);
})
.build();
ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
imageSegmenter.segmentForVideo(getImageFromAsset(inputImageName), /* timestampMs= */ i);
}
}
@Test
public void segment_successWithLiveStreamMode() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage image = getImageFromAsset(inputImageName);
MPImage expectedResult = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(segmenterResult, inputImage) -> {
verifyConfidenceMask(
segmenterResult.segmentations().get(8),
expectedResult,
GOLDEN_MASK_SIMILARITY);
})
.build();
try (ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
imageSegmenter.segmentAsync(image, /* timestampMs= */ i);
}
}
}
@Test
public void segment_failsWithOutOfOrderInputTimestamps() throws Exception {
final String inputImageName = "cat.jpg";
final String goldenImageName = "cat_mask.jpg";
MPImage image = getImageFromAsset(inputImageName);
MPImage expectedResult = getImageFromAsset(goldenImageName);
ImageSegmenterOptions options =
ImageSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK)
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(segmenterResult, inputImage) -> {
verifyConfidenceMask(
segmenterResult.segmentations().get(8),
expectedResult,
GOLDEN_MASK_SIMILARITY);
})
.build();
try (ImageSegmenter imageSegmenter =
ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
imageSegmenter.segmentAsync(image, /* timestampMs= */ 1);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> imageSegmenter.segmentAsync(image, /* timestampMs= */ 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
}
private static void verifyCategoryMask(
MPImage actualMask, MPImage goldenMask, float similarityThreshold, int magnificationFactor) {
assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth());
assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight());
ByteBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask);
Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask);
int consistentPixels = 0;
final int numPixels = actualMask.getWidth() * actualMask.getHeight();
actualMaskBuffer.rewind();
for (int y = 0; y < actualMask.getHeight(); y++) {
for (int x = 0; x < actualMask.getWidth(); x++) {
// RGB values are the same in the golden mask image.
consistentPixels +=
actualMaskBuffer.get() * magnificationFactor
== Color.red(goldenMaskBitmap.getPixel(x, y))
? 1
: 0;
}
}
assertThat((float) consistentPixels / numPixels).isGreaterThan(similarityThreshold);
}
private static void verifyConfidenceMask(
MPImage actualMask, MPImage goldenMask, float similarityThreshold) {
assertThat(actualMask.getWidth()).isEqualTo(goldenMask.getWidth());
assertThat(actualMask.getHeight()).isEqualTo(goldenMask.getHeight());
FloatBuffer actualMaskBuffer = ByteBufferExtractor.extract(actualMask).asFloatBuffer();
Bitmap goldenMaskBitmap = BitmapExtractor.extract(goldenMask);
FloatBuffer goldenMaskBuffer = getByteBufferFromBitmap(goldenMaskBitmap).asFloatBuffer();
assertThat(
calculateSoftIOU(
actualMaskBuffer, goldenMaskBuffer, actualMask.getWidth() * actualMask.getHeight()))
.isGreaterThan((double) similarityThreshold);
}
private static MPImage getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
private static ByteBuffer getByteBufferFromBitmap(Bitmap bitmap) {
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(bitmap.getWidth() * bitmap.getHeight() * 4);
for (int y = 0; y < bitmap.getHeight(); y++) {
for (int x = 0; x < bitmap.getWidth(); x++) {
byteBuffer.putFloat((float) Color.red(bitmap.getPixel(x, y)) / 255.f);
}
}
byteBuffer.rewind();
return byteBuffer;
}
private static double calculateSum(FloatBuffer m) {
m.rewind();
double sum = 0;
while (m.hasRemaining()) {
sum += m.get();
}
m.rewind();
return sum;
}
private static FloatBuffer multiply(FloatBuffer m1, FloatBuffer m2, int bufferSize) {
m1.rewind();
m2.rewind();
FloatBuffer buffer = FloatBuffer.allocate(bufferSize);
while (m1.hasRemaining()) {
buffer.put(m1.get() * m2.get());
}
m1.rewind();
m2.rewind();
buffer.rewind();
return buffer;
}
private static double calculateSoftIOU(FloatBuffer m1, FloatBuffer m2, int bufferSize) {
double intersectionSum = calculateSum(multiply(m1, m2, bufferSize));
double m1m1 = calculateSum(multiply(m1, m1.duplicate(), bufferSize));
double m2m2 = calculateSum(multiply(m2, m2.duplicate(), bufferSize));
double unionSum = m1m1 + m2m2 - intersectionSum;
return unionSum > 0.0 ? intersectionSum / unionSum : 0.0;
}
}
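For reference, calculateSoftIOU above implements the standard soft intersection-over-union between two confidence masks m1 and m2:

    softIoU(m1, m2) = sum(m1 * m2) / (sum(m1 * m1) + sum(m2 * m2) - sum(m1 * m2))

It reduces to the ordinary IoU when both masks are binary, and the confidence-mask assertions pass when this value exceeds GOLDEN_MASK_SIMILARITY.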

View File

@ -192,6 +192,36 @@ class TextEmbedderTest(parameterized.TestCase):
self._check_embedding_value(result1, expected_result1_value)
self._check_cosine_similarity(result0, result1, expected_similarity)
def test_embed_with_mobile_bert_and_different_themes(self):
# Creates embedder.
model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _BERT_MODEL_FILE)
)
base_options = _BaseOptions(model_asset_path=model_path)
options = _TextEmbedderOptions(base_options=base_options)
embedder = _TextEmbedder.create_from_options(options)
# Extracts both embeddings.
text0 = (
'When you go to this restaurant, they hold the pancake upside-down '
"before they hand it to you. It's a great gimmick."
)
result0 = embedder.embed(text0)
text1 = "Let's make a plan to steal the declaration of independence."
result1 = embedder.embed(text1)
similarity = _TextEmbedder.cosine_similarity(
result0.embeddings[0], result1.embeddings[0]
)
# TODO: The similarity should likely be lower
self.assertAlmostEqual(similarity, 0.980880, delta=_SIMILARITY_TOLERANCE)
# Closes the embedder explicitly when the embedder is not used in
    # a "with" context manager.
embedder.close()
if __name__ == '__main__':
absltest.main()

View File

@ -33,6 +33,7 @@ export declare interface Embedding {
* perform scalar quantization.
*/
quantizedEmbedding?: Uint8Array;
/**
 * The index of the embedder head that produced this embedding. This is
* useful for multi-head models.

View File

@ -70,12 +70,12 @@ describe('computeCosineSimilarity', () => {
it('succeeds with quantized embeddings', () => {
const u: Embedding = {
quantizedEmbedding: new Uint8Array([255, 128, 128, 128]),
quantizedEmbedding: new Uint8Array([127, 0, 0, 0]),
headIndex: 0,
headName: ''
};
const v: Embedding = {
quantizedEmbedding: new Uint8Array([0, 128, 128, 128]),
quantizedEmbedding: new Uint8Array([128, 0, 0, 0]),
headIndex: 0,
headName: ''
};

View File

@ -38,7 +38,7 @@ export function computeCosineSimilarity(u: Embedding, v: Embedding): number {
}
function convertToBytes(data: Uint8Array): number[] {
return Array.from(data, v => v - 128);
return Array.from(data, v => v > 127 ? v - 256 : v);
}
function compute(u: number[], v: number[]) {
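
The fix above changes how quantized embedding bytes are decoded: instead of treating them as offset-binary (value − 128), the `Uint8Array` values are now interpreted as two's-complement int8, so bytes above 127 wrap to negative numbers. A standalone sketch of the conversion (illustrative only; `toSignedBytes` is a hypothetical helper, not the library function):

```
// Reinterpret unsigned bytes (0..255) as signed int8 values (-128..127).
function toSignedBytes(data: Uint8Array): number[] {
  return Array.from(data, v => (v > 127 ? v - 256 : v));
}

console.log(toSignedBytes(new Uint8Array([127, 128, 255, 0])));  // [127, -128, -1, 0]
```

Under this interpretation, the updated test vectors `[127, 0, 0, 0]` and `[128, 0, 0, 0]` decode to the same signed values ([127, 0, 0, 0] and [-128, 0, 0, 0]) that the old offset-encoded vectors represented. Equivalently, `new Int8Array(data.buffer, data.byteOffset, data.length)` views the same bytes as signed without copying.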

View File

@ -23,6 +23,7 @@ VISION_LIBS = [
"//mediapipe/tasks/web/vision/hand_landmarker",
"//mediapipe/tasks/web/vision/image_classifier",
"//mediapipe/tasks/web/vision/image_embedder",
"//mediapipe/tasks/web/vision/image_segmenter",
"//mediapipe/tasks/web/vision/object_detector",
]

View File

@ -39,6 +39,23 @@ const classifications = imageClassifier.classify(image);
For more information, refer to the [Image Classification](https://developers.google.com/mediapipe/solutions/vision/image_classifier/web_js) documentation.
## Image Segmentation
The MediaPipe Image Segmenter lets you segment an image into categories.
```
const vision = await FilesetResolver.forVisionTasks(
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const imageSegmenter = await ImageSegmenter.createFromModelPath(vision,
"model.tflite"
);
const image = document.getElementById("image") as HTMLImageElement;
imageSegmenter.segment(image, (masks, width, height) => {
...
});
```
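
With the default `CATEGORY_MASK` output type, the callback receives a single `Uint8Array` of per-pixel class indices that is only valid for the duration of the callback. A possible continuation of the snippet above (copying the mask and counting categories; purely illustrative):

```
imageSegmenter.segment(image, (masks, width, height) => {
  // Copy the mask if it is needed after the callback returns.
  const categoryMask = Uint8Array.from(masks[0] as Uint8Array);
  const counts = new Map<number, number>();
  for (const label of categoryMask) {
    counts.set(label, (counts.get(label) ?? 0) + 1);
  }
  console.log(`Segmented ${width}x${height} image`, counts);
});
```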
## Gesture Recognition
The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real

View File

@ -0,0 +1,58 @@
# This contains the MediaPipe Image Segmenter Task.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library")
load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_ts_library(
name = "image_segmenter",
srcs = ["image_segmenter.ts"],
deps = [
":image_segmenter_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto",
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
"//mediapipe/web/graph_runner:graph_runner_ts",
],
)
mediapipe_ts_declaration(
name = "image_segmenter_types",
srcs = ["image_segmenter_options.d.ts"],
deps = [
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:vision_task_options",
],
)
mediapipe_ts_library(
name = "image_segmenter_test_lib",
testonly = True,
srcs = [
"image_segmenter_test.ts",
],
deps = [
":image_segmenter",
":image_segmenter_types",
"//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:task_runner_test_utils",
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
],
)
jasmine_node_test(
name = "image_segmenter_test",
tags = ["nomsan"],
deps = [":image_segmenter_test_lib"],
)

View File

@ -0,0 +1,300 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {CalculatorOptions} from '../../../../framework/calculator_options_pb';
import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb';
import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb';
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
import {ImageSegmenterOptions} from './image_segmenter_options';
export * from './image_segmenter_options';
export {ImageSource}; // Used in the public API
/**
* The ImageSegmenter returns the segmentation result as a Uint8Array (when
* the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
* for future usage.
*/
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
/**
* A callback that receives the computed masks from the image segmenter. The
* callback either receives a single element array with a category mask (as a
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type SegmentationMaskCallback =
(masks: SegmentationMask[], width: number, height: number) => void;
const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
const IMAGE_SEGMENTER_GRAPH =
    'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
// The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern
/** Performs image segmentation on images. */
export class ImageSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {};
private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto;
/**
* Initializes the Wasm runtime and creates a new image segmenter from the
* provided options.
* @param wasmFileset A configuration object that provides the location of
* the Wasm binary and its loader.
* @param imageSegmenterOptions The options for the Image Segmenter. Note
* that either a path to the model asset or a model buffer needs to be
* provided (via `baseOptions`).
*/
static createFromOptions(
wasmFileset: WasmFileset,
imageSegmenterOptions: ImageSegmenterOptions): Promise<ImageSegmenter> {
return VisionTaskRunner.createInstance(
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
imageSegmenterOptions);
}
/**
* Initializes the Wasm runtime and creates a new image segmenter based on
* the provided model asset buffer.
* @param wasmFileset A configuration object that provides the location of
* the Wasm binary and its loader.
* @param modelAssetBuffer A binary representation of the model.
*/
static createFromModelBuffer(
wasmFileset: WasmFileset,
modelAssetBuffer: Uint8Array): Promise<ImageSegmenter> {
return VisionTaskRunner.createInstance(
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetBuffer}});
}
/**
* Initializes the Wasm runtime and creates a new image segmenter based on
* the path to the model asset.
* @param wasmFileset A configuration object that provides the location of
* the Wasm binary and its loader.
* @param modelAssetPath The path to the model asset.
*/
static createFromModelPath(
wasmFileset: WasmFileset,
modelAssetPath: string): Promise<ImageSegmenter> {
return VisionTaskRunner.createInstance(
ImageSegmenter, /* initializeCanvas= */ true, wasmFileset,
{baseOptions: {modelAssetPath}});
}
/** @hideconstructor */
constructor(
wasmModule: WasmModule,
glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
super(
new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM,
NORM_RECT_STREAM, /* roiAllowed= */ false);
this.options = new ImageSegmenterGraphOptionsProto();
this.segmenterOptions = new SegmenterOptionsProto();
this.options.setSegmenterOptions(this.segmenterOptions);
this.options.setBaseOptions(new BaseOptionsProto());
}
protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!;
}
protected override set baseOptions(proto: BaseOptionsProto) {
this.options.setBaseOptions(proto);
}
/**
* Sets new options for the image segmenter.
*
* Calling `setOptions()` with a subset of options only affects those
* options. You can reset an option back to its default value by
* explicitly setting it to `undefined`.
*
* @param options The options for the image segmenter.
*/
override setOptions(options: ImageSegmenterOptions): Promise<void> {
// Note that we have to support both JSPB and ProtobufJS, hence we
    // have to explicitly clear the values instead of setting them to
// `undefined`.
if (options.displayNamesLocale !== undefined) {
this.options.setDisplayNamesLocale(options.displayNamesLocale);
} else if ('displayNamesLocale' in options) { // Check for undefined
this.options.clearDisplayNamesLocale();
}
if (options.outputType === 'CONFIDENCE_MASK') {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
} else {
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
}
return super.applyOptions(options);
}
/**
* Performs image segmentation on the provided single image and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `image`.
*
* @param image An image to process.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segment(image: ImageSource, callback: SegmentationMaskCallback): void;
/**
* Performs image segmentation on the provided single image and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `image`.
*
* @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segment(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void;
segment(
image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback :
{};
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback :
callback!;
this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {};
}
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, timestamp: number,
callback: SegmentationMaskCallback): void;
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: SegmentationMaskCallback): void;
segmentForVideo(
videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions :
{};
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
this.userCallback = () => {};
}
/** Updates the MediaPipe graph configuration. */
protected override refreshGraph(): void {
const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM);
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension(
ImageSegmenterGraphOptionsProto.ext, this.options);
const segmenterNode = new CalculatorGraphConfig.Node();
    segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
segmenterNode.addOutputStream(
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
segmenterNode.setOptions(calculatorOptions);
graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageVectorListener(
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
if (masks.length === 0) {
this.userCallback([], 0, 0);
} else {
this.userCallback(
masks.map(m => m.data), masks[0].width, masks[0].height);
}
this.setLatestOutputTimestamp(timestamp);
});
const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
}
}
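
As the doc comments above stress, the masks passed to the callback are only valid while the callback runs, so asynchronous consumers should copy them first. A hedged usage sketch combining `setOptions()` and `segment()` (assumes `vision` is a `WasmFileset` and `image` an image element, as in the README snippet earlier; the model path and downstream handling are placeholders):

```
const segmenter = await ImageSegmenter.createFromModelPath(vision, 'model.tflite');
await segmenter.setOptions({outputType: 'CONFIDENCE_MASK'});

segmenter.segment(image, (masks, width, height) => {
  // Copy each per-class confidence mask before the callback returns.
  const confidenceMasks = masks.map(m => Float32Array.from(m as Float32Array));
  setTimeout(() => {
    // Safe to use the copies asynchronously.
    console.log(`Got ${confidenceMasks.length} masks of size ${width}x${height}`);
  });
});
```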

View File

@ -0,0 +1,41 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options';
/** Options to configure the MediaPipe Image Segmenter Task */
export interface ImageSegmenterOptions extends VisionTaskOptions {
/**
* The locale to use for display names specified through the TFLite Model
* Metadata, if any. Defaults to English.
*/
displayNamesLocale?: string|undefined;
/**
* The output type of segmentation results.
*
* The two supported modes are:
* - Category Mask: Gives a single output mask where each pixel represents
 *     the class to which the pixel in the original image was
 *     predicted to belong.
* - Confidence Mask: Gives a list of output masks (one for each class). For
* each mask, the pixel represents the prediction
 *               confidence, usually in the [0.0, 1.0] range.
*
* Defaults to `CATEGORY_MASK`.
*/
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
}
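
Per the `setOptions()` contract in image_segmenter.ts, setting an option to `undefined` restores its default, so switching output types back and forth looks like this (sketch only):

```
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});  // per-class Float32Array masks
await imageSegmenter.setOptions({outputType: undefined});          // back to the default CATEGORY_MASK
```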

View File

@ -0,0 +1,215 @@
/**
* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import 'jasmine';
// Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {ImageSegmenter} from './image_segmenter';
import {ImageSegmenterOptions} from './image_segmenter_options';
class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
calculatorName = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
attachListenerSpies: jasmine.Spy[] = [];
graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule;
imageVectorListener:
((images: WasmImage[], timestamp: number) => void)|undefined;
constructor() {
super(createSpyWasmModule(), /* glCanvas= */ null);
this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] =
spyOn(this.graphRunner, 'attachImageVectorListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('segmented_masks');
this.imageVectorListener = listener;
});
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
});
spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
}
}
describe('ImageSegmenter', () => {
let imageSegmenter: ImageSegmenterFake;
beforeEach(async () => {
addJasmineCustomFloatEqualityTester();
imageSegmenter = new ImageSegmenterFake();
await imageSegmenter.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
});
it('initializes graph', async () => {
verifyGraph(imageSegmenter);
verifyListenersRegistered(imageSegmenter);
});
it('reloads graph when settings are changed', async () => {
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
verifyListenersRegistered(imageSegmenter);
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
verifyListenersRegistered(imageSegmenter);
});
it('can use custom models', async () => {
const newModel = new Uint8Array([0, 1, 2, 3, 4]);
const newModelBase64 = Buffer.from(newModel).toString('base64');
await imageSegmenter.setOptions({
baseOptions: {
modelAssetBuffer: newModel,
}
});
verifyGraph(
imageSegmenter,
/* expectedCalculatorOptions= */ undefined,
/* expectedBaseOptions= */
[
'modelAsset', {
fileContent: newModelBase64,
fileName: undefined,
fileDescriptorMeta: undefined,
filePointerMeta: undefined
}
]);
});
it('merges options', async () => {
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
});
describe('setOptions()', () => {
interface TestCase {
optionName: keyof ImageSegmenterOptions;
fieldPath: string[];
userValue: unknown;
graphValue: unknown;
defaultValue: unknown;
}
const testCases: TestCase[] = [
{
optionName: 'displayNamesLocale',
fieldPath: ['displayNamesLocale'],
userValue: 'en',
graphValue: 'en',
defaultValue: 'en'
},
{
optionName: 'outputType',
fieldPath: ['segmenterOptions', 'outputType'],
userValue: 'CONFIDENCE_MASK',
graphValue: 2,
defaultValue: 1
},
];
for (const testCase of testCases) {
it(`can set ${testCase.optionName}`, async () => {
await imageSegmenter.setOptions(
{[testCase.optionName]: testCase.userValue});
verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
});
it(`can clear ${testCase.optionName}`, async () => {
await imageSegmenter.setOptions(
{[testCase.optionName]: testCase.userValue});
verifyGraph(imageSegmenter, [testCase.fieldPath, testCase.graphValue]);
await imageSegmenter.setOptions({[testCase.optionName]: undefined});
verifyGraph(
imageSegmenter, [testCase.fieldPath, testCase.defaultValue]);
});
}
});
it('doesn\'t support region of interest', () => {
expect(() => {
imageSegmenter.segment(
{} as HTMLImageElement,
{regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {});
}).toThrowError('This task doesn\'t support region-of-interest.');
});
it('supports category masks', (done) => {
const mask = new Uint8Array([1, 2, 3, 4]);
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageSegmenter);
imageSegmenter.imageVectorListener!(
[
{data: mask, width: 2, height: 2},
],
/* timestamp= */ 1337);
});
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(masks).toHaveSize(1);
expect(masks[0]).toEqual(mask);
expect(width).toEqual(2);
expect(height).toEqual(2);
done();
});
});
it('supports confidence masks', async () => {
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageSegmenter);
imageSegmenter.imageVectorListener!(
[
{data: mask1, width: 2, height: 2},
{data: mask2, width: 2, height: 2},
],
1337);
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(masks).toHaveSize(2);
expect(masks[0]).toEqual(mask1);
expect(masks[1]).toEqual(mask2);
expect(width).toEqual(2);
expect(height).toEqual(2);
resolve();
});
});
});
});

View File

@ -19,6 +19,7 @@ import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vis
import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
import {ImageSegmenter as ImageSegmenterImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
// Declare the variables locally so that Rollup in OSS includes them explicitly
@ -28,6 +29,7 @@ const GestureRecognizer = GestureRecognizerImpl;
const HandLandmarker = HandLandmarkerImpl;
const ImageClassifier = ImageClassifierImpl;
const ImageEmbedder = ImageEmbedderImpl;
const ImageSegmenter = ImageSegmenterImpl;
const ObjectDetector = ObjectDetectorImpl;
export {
@ -36,5 +38,6 @@ export {
HandLandmarker,
ImageClassifier,
ImageEmbedder,
ImageSegmenter,
ObjectDetector
};

View File

@ -19,4 +19,5 @@ export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
export * from '../../../tasks/web/vision/image_classifier/image_classifier';
export * from '../../../tasks/web/vision/image_embedder/image_embedder';
export * from '../../../tasks/web/vision/image_segmenter/image_segmenter';
export * from '../../../tasks/web/vision/object_detector/object_detector';

View File

@ -1,14 +0,0 @@
diff --git a/absl/time/internal/cctz/BUILD.bazel b/absl/time/internal/cctz/BUILD.bazel
index 9fceffe..e7f9d01 100644
--- a/absl/time/internal/cctz/BUILD.bazel
+++ b/absl/time/internal/cctz/BUILD.bazel
@@ -69,8 +69,5 @@ cc_library(
"include/cctz/zone_info_source.h",
],
linkopts = select({
- ":osx": [
- "-framework Foundation",
- ],
":ios": [
"-framework Foundation",
],

View File

@ -0,0 +1,13 @@
diff --git a/absl/types/compare.h b/absl/types/compare.h
index 19b076e..0201004 100644
--- a/absl/types/compare.h
+++ b/absl/types/compare.h
@@ -84,7 +84,7 @@ enum class ncmp : value_type { unordered = -127 };
// based on whether the feature is supported. Note: we can't use
// ABSL_INTERNAL_INLINE_CONSTEXPR here because the variables here are of
// incomplete types so they need to be defined after the types are complete.
-#ifdef __cpp_inline_variables
+#if defined(__cpp_inline_variables) && !(defined(_MSC_VER) && _MSC_VER <= 1916)
// A no-op expansion that can be followed by a semicolon at class level.
#define ABSL_COMPARE_INLINE_BASECLASS_DECL(name) static_assert(true, "")