Merge branch 'google:master' into pose-landmarker-python

This commit is contained in:
Kinar R 2023-04-19 10:21:23 +05:30 committed by GitHub
commit 39742b6641
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
98 changed files with 5096 additions and 974 deletions

View File

@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
} else if (packet_options.has_string_value()) { } else if (packet_options.has_string_value()) {
packet.Set<std::string>(); packet.Set<std::string>();
} else if (packet_options.has_uint64_value()) { } else if (packet_options.has_uint64_value()) {
packet.Set<uint64>(); packet.Set<uint64_t>();
} else if (packet_options.has_classification_list_value()) { } else if (packet_options.has_classification_list_value()) {
packet.Set<ClassificationList>(); packet.Set<ClassificationList>();
} else if (packet_options.has_landmark_list_value()) { } else if (packet_options.has_landmark_list_value()) {
@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
} else if (packet_options.has_string_value()) { } else if (packet_options.has_string_value()) {
packet.Set(MakePacket<std::string>(packet_options.string_value())); packet.Set(MakePacket<std::string>(packet_options.string_value()));
} else if (packet_options.has_uint64_value()) { } else if (packet_options.has_uint64_value()) {
packet.Set(MakePacket<uint64>(packet_options.uint64_value())); packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
} else if (packet_options.has_classification_list_value()) { } else if (packet_options.has_classification_list_value()) {
packet.Set(MakePacket<ClassificationList>( packet.Set(MakePacket<ClassificationList>(
packet_options.classification_list_value())); packet_options.classification_list_value()));

View File

@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
} }
// Use this when ALLOW/DISALLOW input is provided as a side packet. // Use this when ALLOW/DISALLOW input is provided as a side packet.
void RunTimeStep(int64 timestamp, bool stream_payload) { void RunTimeStep(int64_t timestamp, bool stream_payload) {
runner_->MutableInputs()->Get("", 0).packets.push_back( runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(stream_payload).At(Timestamp(timestamp))); MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed."; MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
} }
// Use this when ALLOW/DISALLOW input is provided as an input stream. // Use this when ALLOW/DISALLOW input is provided as an input stream.
void RunTimeStep(int64 timestamp, const std::string& control_tag, void RunTimeStep(int64_t timestamp, const std::string& control_tag,
bool control) { bool control) {
runner_->MutableInputs()->Get("", 0).packets.push_back( runner_->MutableInputs()->Get("", 0).packets.push_back(
MakePacket<bool>(true).At(Timestamp(timestamp))); MakePacket<bool>(true).At(Timestamp(timestamp)));
@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
} }
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
} }
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
output_stream: "test_output" output_stream: "test_output"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
)"); )");
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true)); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
)"); )");
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false)); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
)"); )");
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false)); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
)"); )");
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true)); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, true); RunTimeStep(kTimestampValue0, true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, false); RunTimeStep(kTimestampValue1, false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
output_stream: "test_output" output_stream: "test_output"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true); RunTimeStep(kTimestampValue0, "ALLOW", true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "ALLOW", false); RunTimeStep(kTimestampValue1, "ALLOW", false);
constexpr int64 kTimestampValue2 = 44; constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "ALLOW", true); RunTimeStep(kTimestampValue2, "ALLOW", true);
constexpr int64 kTimestampValue3 = 45; constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "ALLOW", false); RunTimeStep(kTimestampValue3, "ALLOW", false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
output_stream: "test_output" output_stream: "test_output"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", true); RunTimeStep(kTimestampValue0, "DISALLOW", true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "DISALLOW", false); RunTimeStep(kTimestampValue1, "DISALLOW", false);
constexpr int64 kTimestampValue2 = 44; constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "DISALLOW", true); RunTimeStep(kTimestampValue2, "DISALLOW", true);
constexpr int64 kTimestampValue3 = 45; constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "DISALLOW", false); RunTimeStep(kTimestampValue3, "DISALLOW", false);
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
output_stream: "STATE_CHANGE:state_changed" output_stream: "STATE_CHANGE:state_changed"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", false); RunTimeStep(kTimestampValue0, "ALLOW", false);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "ALLOW", true); RunTimeStep(kTimestampValue1, "ALLOW", true);
constexpr int64 kTimestampValue2 = 44; constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "ALLOW", true); RunTimeStep(kTimestampValue2, "ALLOW", true);
constexpr int64 kTimestampValue3 = 45; constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "ALLOW", false); RunTimeStep(kTimestampValue3, "ALLOW", false);
const std::vector<Packet>& output = const std::vector<Packet>& output =
@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
output_stream: "STATE_CHANGE:state_changed" output_stream: "STATE_CHANGE:state_changed"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", true); RunTimeStep(kTimestampValue0, "DISALLOW", true);
constexpr int64 kTimestampValue1 = 43; constexpr int64_t kTimestampValue1 = 43;
RunTimeStep(kTimestampValue1, "DISALLOW", false); RunTimeStep(kTimestampValue1, "DISALLOW", false);
constexpr int64 kTimestampValue2 = 44; constexpr int64_t kTimestampValue2 = 44;
RunTimeStep(kTimestampValue2, "DISALLOW", false); RunTimeStep(kTimestampValue2, "DISALLOW", false);
constexpr int64 kTimestampValue3 = 45; constexpr int64_t kTimestampValue3 = 45;
RunTimeStep(kTimestampValue3, "DISALLOW", true); RunTimeStep(kTimestampValue3, "DISALLOW", true);
const std::vector<Packet>& output = const std::vector<Packet>& output =
@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
output_stream: "STATE_CHANGE:state_changed" output_stream: "STATE_CHANGE:state_changed"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "DISALLOW", false); RunTimeStep(kTimestampValue0, "DISALLOW", false);
const std::vector<Packet>& output = const std::vector<Packet>& output =
@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
output_stream: "STATE_CHANGE:state_changed" output_stream: "STATE_CHANGE:state_changed"
)"); )");
constexpr int64 kTimestampValue0 = 42; constexpr int64_t kTimestampValue0 = 42;
RunTimeStep(kTimestampValue0, "ALLOW", true); RunTimeStep(kTimestampValue0, "ALLOW", true);
const std::vector<Packet>& output = const std::vector<Packet>& output =

View File

@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>; using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
REGISTER_CALCULATOR(StringToUintCalculator); REGISTER_CALCULATOR(StringToUintCalculator);
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>; using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
REGISTER_CALCULATOR(StringToInt32Calculator); REGISTER_CALCULATOR(StringToInt32Calculator);
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>; using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
REGISTER_CALCULATOR(StringToUint32Calculator); REGISTER_CALCULATOR(StringToUint32Calculator);
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>; using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
REGISTER_CALCULATOR(StringToInt64Calculator); REGISTER_CALCULATOR(StringToInt64Calculator);
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>; using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
REGISTER_CALCULATOR(StringToUint64Calculator); REGISTER_CALCULATOR(StringToUint64Calculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(), const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
frame_ptr->Height(), frame_ptr->WidthStep(), frame_ptr->Height(), frame_ptr->WidthStep(),
const_cast<uint8_t*>(frame_ptr->PixelData()), const_cast<uint8_t*>(frame_ptr->PixelData()),
[](uint8* data){}); [](uint8_t* data){});
ASSIGN_OR_RETURN(auto result, ASSIGN_OR_RETURN(auto result,
runner->Run(image_frame, matrix, size, border_mode)); runner->Run(image_frame, matrix, size, border_mode));
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result))); return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));

View File

@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) { ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
// Record the most recent first kept timestamp on any stream. // Record the most recent first kept timestamp on any stream.
for (const auto& stream : input_stream_managers_) { for (const auto& stream : input_stream_managers_) {
int32 queue_size = (stream->QueueSize() >= trigger_queue_size_) int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
? target_queue_size_ ? target_queue_size_
: trigger_queue_size_ - 1; : trigger_queue_size_ - 1;
if (stream->QueueSize() > queue_size) { if (stream->QueueSize() > queue_size) {
kept_timestamp_ = std::max( kept_timestamp_ = std::max(
kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1) kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
} }
private: private:
int32 trigger_queue_size_; int32_t trigger_queue_size_;
int32 target_queue_size_; int32_t target_queue_size_;
bool fixed_min_size_; bool fixed_min_size_;
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and // Indicates that GetNodeReadiness has returned kReadyForProcess once, and
// the corresponding call to FillInputSet has not yet completed. // the corresponding call to FillInputSet has not yet completed.

View File

@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
// TODO: Investigate this option in more detail, esp. on Safari. // TODO: Investigate this option in more detail, esp. on Safari.
attrs.preserveDrawingBuffer = 0; attrs.preserveDrawingBuffer = 0;
// Since the Emscripten canvas target finding function is visible from here, // Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
// we hijack findCanvasEventTarget directly for enforcing old Module.canvas // looks for our #canvas target in Module.canvas, where we expect it to be.
// behavior if the user desires, falling back to the new DOM element CSS // -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
// selector behavior next if that is specified, and finally just allowing the // event target behavior, but it was never supposed to be tapping into our
// lookup to proceed on a null target. // canvas anyways. See b/278155946 for more background.
// TODO: Ensure this works with all options (in particular, EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
// multithreading options, like the special-case combination of USE_PTHREADS
// and OFFSCREEN_FRAMEBUFFER)
// clang-format off
EM_ASM(
let init_once = true;
if (init_once) {
const cachedFindCanvasEventTarget = findCanvasEventTarget;
if (typeof cachedFindCanvasEventTarget !== 'function') {
if (typeof console !== 'undefined') {
console.error('Expected Emscripten global function '
+ '"findCanvasEventTarget" not found. WebGL context creation '
+ 'may fail.');
}
return;
}
findCanvasEventTarget = function(target) {
if (target == 0) {
if (Module && Module.canvas) {
return Module.canvas;
} else if (Module && Module.canvasCssSelector) {
return cachedFindCanvasEventTarget(Module.canvasCssSelector);
}
if (typeof console !== 'undefined') {
console.warn('Module properties canvas and canvasCssSelector not ' +
'found during WebGL context creation.');
}
}
// We still go through with the find attempt, although for most use
// cases it will not succeed, just in case the user does want to fall-
// back.
return cachedFindCanvasEventTarget(target);
}; // NOLINT: Necessary semicolon.
init_once = false;
}
);
// clang-format on
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle = EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
emscripten_webgl_create_context(nullptr, &attrs); emscripten_webgl_create_context("#canvas", &attrs);
// Check for failure // Check for failure
if (context_handle <= 0) { if (context_handle <= 0) {

View File

@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
int actual_ws = image_frame.WidthStep(); int actual_ws = image_frame.WidthStep();
int alignment = 0; int alignment = 0;
std::unique_ptr<ImageFrame> temp; std::unique_ptr<ImageFrame> temp;
const uint8* data = image_frame.PixelData(); const uint8_t* data = image_frame.PixelData();
// Let's see if the pixel data is tightly aligned to one of the alignments // Let's see if the pixel data is tightly aligned to one of the alignments
// supported by OpenGL, preferring 4 if possible since it's the default. // supported by OpenGL, preferring 4 if possible since it's the default.

View File

@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
GpuBufferFormat format) { GpuBufferFormat format) {
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format); libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment); int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
auto y_data = std::make_unique<uint8[]>(y_stride * height); auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
switch (fourcc) { switch (fourcc) {
case libyuv::FOURCC_NV12: case libyuv::FOURCC_NV12:
case libyuv::FOURCC_NV21: { case libyuv::FOURCC_NV21: {
@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
int uv_width = 2 * std::ceil(0.5f * width); int uv_width = 2 * std::ceil(0.5f * width);
int uv_height = std::ceil(0.5f * height); int uv_height = std::ceil(0.5f * height);
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height); auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
yuv_image_ = std::make_shared<YUVImage>( yuv_image_ = std::make_shared<YUVImage>(
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride, fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
nullptr, 0, width, height); nullptr, 0, width, height);
@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
int uv_width = std::ceil(0.5f * width); int uv_width = std::ceil(0.5f * width);
int uv_height = std::ceil(0.5f * height); int uv_height = std::ceil(0.5f * height);
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height); auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height); auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
yuv_image_ = std::make_shared<YUVImage>( yuv_image_ = std::make_shared<YUVImage>(
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride, fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
std::move(v_data), uv_stride, width, height); std::move(v_data), uv_stride, width, height);

View File

@ -16,6 +16,7 @@ import csv
import filecmp import filecmp
import os import os
import tempfile import tempfile
import unittest
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
import tensorflow as tf import tensorflow as tf
@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.test import test_utils
@unittest.skip('b/275624089')
class TextClassifierTest(tf.test.TestCase): class TextClassifierTest(tf.test.TestCase):
_AVERAGE_WORD_EMBEDDING_JSON_FILE = ( _AVERAGE_WORD_EMBEDDING_JSON_FILE = (

View File

@ -175,11 +175,7 @@ py_test(
data = [":testdata"], data = [":testdata"],
tags = ["requires-net:external"], tags = ["requires-net:external"],
deps = [ deps = [
":dataset", ":object_detector_import",
":hyperparameters",
":model_spec",
":object_detector",
":object_detector_options",
"//mediapipe/tasks/python/test:test_utils", "//mediapipe/tasks/python/test:test_utils",
], ],
) )

View File

@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from mediapipe.model_maker.python.vision.object_detector import dataset from mediapipe.model_maker.python.vision import object_detector
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
from mediapipe.model_maker.python.vision.object_detector import object_detector
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
from mediapipe.tasks.python.test import test_utils as task_test_utils from mediapipe.tasks.python.test import test_utils as task_test_utils
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
super().setUp() super().setUp()
dataset_folder = task_test_utils.get_test_data_path('coco_data') dataset_folder = task_test_utils.get_test_data_path('coco_data')
cache_dir = self.create_tempdir() cache_dir = self.create_tempdir()
self.data = dataset.Dataset.from_coco_folder( self.data = object_detector.Dataset.from_coco_folder(
dataset_folder, cache_dir=cache_dir dataset_folder, cache_dir=cache_dir
) )
# Mock tempfile.gettempdir() to be unique for each test to avoid race # Mock tempfile.gettempdir() to be unique for each test to avoid race
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.addCleanup(mock_gettempdir.stop) self.addCleanup(mock_gettempdir.stop)
def test_object_detector(self): def test_object_detector(self):
hparams = hyperparameters.HParams( hparams = object_detector.HParams(
epochs=1, epochs=1,
batch_size=2, batch_size=2,
learning_rate=0.9, learning_rate=0.9,
shuffle=False, shuffle=False,
export_dir=self.create_tempdir(), export_dir=self.create_tempdir(),
) )
options = object_detector_options.ObjectDetectorOptions( options = object_detector.ObjectDetectorOptions(
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams supported_model=object_detector.SupportedModels.MOBILENET_V2,
hparams=hparams,
) )
# Test `create`` # Test `create``
model = object_detector.ObjectDetector.create( model = object_detector.ObjectDetector.create(
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
self.assertGreater(os.path.getsize(output_metadata_file), 0) self.assertGreater(os.path.getsize(output_metadata_file), 0)
# Test `quantization_aware_training` # Test `quantization_aware_training`
qat_hparams = hyperparameters.QATHParams( qat_hparams = object_detector.QATHParams(
learning_rate=0.9, learning_rate=0.9,
batch_size=2, batch_size=2,
epochs=1, epochs=1,

View File

@ -24,8 +24,8 @@ namespace mediapipe {
void FrameAnnotationTracker::AddDetectionResult( void FrameAnnotationTracker::AddDetectionResult(
const FrameAnnotation& frame_annotation) { const FrameAnnotation& frame_annotation) {
const int64 time_us = const int64_t time_us =
static_cast<int64>(std::round(frame_annotation.timestamp())); static_cast<int64_t>(std::round(frame_annotation.timestamp()));
for (const auto& object_annotation : frame_annotation.annotations()) { for (const auto& object_annotation : frame_annotation.annotations()) {
detected_objects_[time_us + object_annotation.object_id()] = detected_objects_[time_us + object_annotation.object_id()] =
object_annotation; object_annotation;
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
absl::flat_hash_set<int>* cancel_object_ids) { absl::flat_hash_set<int>* cancel_object_ids) {
CHECK(cancel_object_ids != nullptr); CHECK(cancel_object_ids != nullptr);
FrameAnnotation frame_annotation; FrameAnnotation frame_annotation;
std::vector<int64> keys_to_be_deleted; std::vector<int64_t> keys_to_be_deleted;
for (const auto& detected_obj : detected_objects_) { for (const auto& detected_obj : detected_objects_) {
const int object_id = detected_obj.second.object_id(); const int object_id = detected_obj.second.object_id();
if (cancel_object_ids->contains(object_id)) { if (cancel_object_ids->contains(object_id)) {

View File

@ -78,6 +78,7 @@ cc_library(
hdrs = ["mediapipe_builtin_op_resolver.h"], hdrs = ["mediapipe_builtin_op_resolver.h"],
deps = [ deps = [
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite", "//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h" #include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
AddCustom("KmeansEmbeddingLookup", AddCustom("KmeansEmbeddingLookup",
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
// For the UniversalSentenceEncoder model. // For the UniversalSentenceEncoder model.
AddCustom("TFSentencepieceTokenizeOp",
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
AddCustom("RaggedTensorToTensor", AddCustom("RaggedTensorToTensor",
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR()); mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
} }

View File

@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
filegroup( filegroup(
name = "config_fbs", name = "config_fbs",
srcs = ["config.fbs"], srcs = ["config.fbs"],
@ -80,3 +87,86 @@ cc_test(
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
], ],
) )
cc_library(
name = "sentencepiece_constants",
hdrs = ["sentencepiece_constants.h"],
)
cc_library(
name = "model_converter",
srcs = [
"model_converter.cc",
],
hdrs = [
"model_converter.h",
],
deps = [
":config",
":double_array_trie_builder",
":encoder_config",
":sentencepiece_constants",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
],
)
cc_library(
name = "optimized_encoder",
srcs = [
"optimized_encoder.cc",
],
hdrs = [
"optimized_encoder.h",
],
deps = [
":double_array_trie",
":encoder_config",
":utils",
],
)
cc_library(
name = "sentencepiece_tokenizer_tflite",
srcs = ["sentencepiece_tokenizer_tflite.cc"],
hdrs = ["sentencepiece_tokenizer_tflite.h"],
visibility = [
"//visibility:public",
],
deps =
[
":optimized_encoder",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
],
)
cc_test(
name = "optimized_encoder_test",
srcs = [
"optimized_encoder_test.cc",
],
data = [
":testdata",
],
deps = [
":double_array_trie_builder",
":encoder_config",
":model_converter",
":optimized_encoder",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:gtest_main",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
"@com_google_sentencepiece//src:sentencepiece_processor",
"@org_tensorflow//tensorflow/core:lib",
],
)

View File

@ -0,0 +1,131 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
#include "src/sentencepiece_model.pb.h"
namespace mediapipe::tflite_operations::sentencepiece {
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
DecodePrecompiledCharsmap(
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
// This function "undoes" encoding done by
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
const uint32_t trie_size =
*reinterpret_cast<const uint32_t*>(precompiled_map);
const uint32_t* trie_ptr =
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
precompiled_map + sizeof(uint32_t) + trie_size);
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
sizeof(uint32_t) - trie_size;
return std::make_tuple(
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
}
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset) {
::sentencepiece::ModelProto model_config;
if (!model_config.ParseFromString(model_config_str)) {
return absl::InvalidArgumentError(
"Invalid configuration, can't parse SentencePiece model config " +
model_config.InitializationErrorString());
}
// Convert sentencepieces.
std::vector<std::string> pieces;
pieces.reserve(model_config.pieces_size());
std::vector<float> scores;
scores.reserve(model_config.pieces_size());
std::vector<int> ids;
ids.reserve(model_config.pieces_size());
float min_score = 0.0;
int index = 0;
for (const auto& piece : model_config.pieces()) {
switch (piece.type()) {
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
pieces.push_back(piece.piece());
ids.push_back(index);
if (piece.score() < min_score) {
min_score = piece.score();
}
break;
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
// Ignore unknown and control codes.
break;
default:
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
piece.piece());
}
scores.push_back(piece.score());
++index;
}
flatbuffers::FlatBufferBuilder builder(1024);
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
const auto pieces_score_vector = builder.CreateVector(scores);
TrieBuilder pieces_trie_builder(builder);
pieces_trie_builder.add_nodes(pieces_trie_vector);
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
// Converting normalization.
const auto normalization =
DecodePrecompiledCharsmap(model_config.normalizer_spec());
const auto normalization_trie = std::get<0>(normalization);
const auto normalization_strings = std::get<1>(normalization);
const auto normalization_trie_vector =
builder.CreateVector(normalization_trie);
TrieBuilder normalization_trie_builder(builder);
normalization_trie_builder.add_nodes(normalization_trie_vector);
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
const auto normalization_strings_fbs =
builder.CreateVector(normalization_strings);
EncoderConfigBuilder ecb(builder);
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
ecb.add_start_code(model_config.trainer_spec().bos_id());
ecb.add_end_code(model_config.trainer_spec().eos_id());
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
ecb.add_unknown_penalty(min_score - kUnkPenalty);
ecb.add_encoding_offset(encoding_offset);
ecb.add_pieces(pieces_trie_fbs);
ecb.add_pieces_scores(pieces_score_vector);
ecb.add_remove_extra_whitespaces(
model_config.normalizer_spec().remove_extra_whitespaces());
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
ecb.add_escape_whitespaces(
model_config.normalizer_spec().escape_whitespaces());
ecb.add_normalized_prefixes(normalization_trie_fbs);
ecb.add_normalized_replacements(normalization_strings_fbs);
FinishEncoderConfigBuffer(builder, ecb.Finish());
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize());
}
std::string ConvertSentencepieceModel(const std::string& model_string) {
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
assert(result.status().ok());
return result.value();
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,33 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
#include <string>
#include "absl/status/statusor.h"
namespace mediapipe::tflite_operations::sentencepiece {
// Converts Sentencepiece configuration to flatbuffer format.
// encoding_offset is used by some encoders that combine different encodings.
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
const std::string& model_config_str, int encoding_offset = 0);
std::string ConvertSentencepieceModel(const std::string& model_string);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_

View File

@ -0,0 +1,236 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <algorithm>
#include <tuple>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace {
const char kSpaceSymbol[] = "\xe2\x96\x81";
template <typename processing_callback>
std::tuple<std::string, std::vector<int>> process_string(
const std::string& input, const std::vector<int>& offsets,
const processing_callback& pc) {
std::string result_string;
result_string.reserve(input.size());
std::vector<int> result_offsets;
result_offsets.reserve(offsets.size());
for (int i = 0, j = 0; i < input.size();) {
auto result = pc(input.data() + i, input.size() - i);
auto consumed = std::get<0>(result);
auto new_string = std::get<1>(result);
if (consumed == 0) {
// Skip the current byte and move forward.
result_string.push_back(input[i]);
result_offsets.push_back(offsets[j]);
i++;
j++;
continue;
}
result_string.append(new_string.data(), new_string.length());
for (int i = 0; i < new_string.length(); ++i) {
result_offsets.push_back(offsets[j]);
}
j += consumed;
i += consumed;
}
return std::make_tuple(result_string, result_offsets);
}
inline char is_whitespace(char c) {
return c == ' ' || c == '\t' || c == '\r' || c == '\n';
}
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
int len) {
if (len == 0 || !is_whitespace(*data)) {
return std::make_tuple(0, utils::string_view(nullptr, 0));
}
int num_consumed = 1;
for (; num_consumed < len && is_whitespace(data[num_consumed]);
++num_consumed) {
}
return num_consumed > 1
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
: std::make_tuple(0, utils::string_view(nullptr, 0));
}
std::tuple<int, utils::string_view> find_replacement(
const char* data, int len, const DoubleArrayTrie& dat,
const flatbuffers::Vector<int8_t>& replacements) {
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
if (!max_match.empty()) {
// Because flatbuffer byte is signed char which is not the same as char,
// there is the reinterpret_cast here.
const char* replaced_string_ptr =
reinterpret_cast<const char*>(replacements.data() + max_match.id);
return std::make_tuple(max_match.match_length,
utils::string_view(replaced_string_ptr));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
}
} // namespace
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config) {
std::vector<int> output_offsets;
std::string result = in_string;
output_offsets.reserve(in_string.length());
for (int i = 0; i < in_string.length(); ++i) {
output_offsets.push_back(i);
}
if (in_string.empty()) {
return std::make_tuple(result, output_offsets);
}
if (config.add_dummy_prefix()) {
result.insert(result.begin(), ' ');
output_offsets.insert(output_offsets.begin(), 0);
}
// Greedely replace normalized_prefixes with normalized_replacements
if (config.normalized_prefixes() != nullptr &&
config.normalized_replacements() != nullptr) {
const DoubleArrayTrie normalized_prefixes_matcher(
config.normalized_prefixes()->nodes());
const auto norm_replace = [&config, &normalized_prefixes_matcher](
const char* data, int len) {
return find_replacement(data, len, normalized_prefixes_matcher,
*config.normalized_replacements());
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, norm_replace);
}
if (config.remove_extra_whitespaces()) {
std::tie(result, output_offsets) =
process_string(result, output_offsets, remove_extra_whitespaces);
if (!result.empty() && is_whitespace(result.back())) {
result.pop_back();
output_offsets.pop_back();
}
}
if (config.escape_whitespaces()) {
const auto replace_whitespaces = [](const char* data, int len) {
if (len > 0 && is_whitespace(*data)) {
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
}
return std::make_tuple(0, utils::string_view(nullptr, 0));
};
std::tie(result, output_offsets) =
process_string(result, output_offsets, replace_whitespaces);
}
return std::make_tuple(result, output_offsets);
}
EncoderResult EncodeNormalizedString(const std::string& str,
const std::vector<int>& offsets,
const EncoderConfig& config, bool add_bos,
bool add_eos, bool reverse) {
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
const int unknown_code = config.unknown_code();
const float unknown_penalty = config.unknown_penalty();
struct LatticeElement {
float score = 0;
int code = -1;
int prev_position = -1;
LatticeElement(float score_, int code_, int prev_position_)
: score(score_), code(code_), prev_position(prev_position_) {}
LatticeElement() {}
};
const int length = str.length();
std::vector<LatticeElement> lattice(length + 1);
for (int i = 0; i < length; ++i) {
if (i > 0 && lattice[i].prev_position < 0) {
// This state is unreachable.
continue;
}
if (unknown_code >= 0) {
// Put unknown code.
const float penalized_score = lattice[i].score + unknown_penalty;
const int pos = i + 1;
LatticeElement& current_element = lattice[pos];
if (current_element.prev_position < 0 ||
current_element.score < penalized_score) {
current_element = LatticeElement(
penalized_score, unknown_code,
// If the current state is already reached by unknown code, merge
// states.
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
}
}
auto lattice_update = [&lattice, i,
piece_scores](const DoubleArrayTrie::Match& m) {
LatticeElement& target_element = lattice[i + m.match_length];
const float score = lattice[i].score + (*piece_scores)[m.id];
if (target_element.prev_position < 0 || target_element.score < score) {
target_element = LatticeElement(score, m.id, i);
}
};
piece_matcher.IteratePrefixMatches(
utils::string_view(str.data() + i, length - i), lattice_update);
}
EncoderResult result;
if (add_eos) {
result.codes.push_back(config.end_code());
result.offsets.push_back(length);
}
if (lattice[length].prev_position >= 0) {
for (int pos = length; pos > 0;) {
auto code = lattice[pos].code;
if (code != config.unknown_code()) {
code += config.encoding_offset();
}
result.codes.push_back(code);
pos = lattice[pos].prev_position;
result.offsets.push_back(offsets[pos]);
}
}
if (add_bos) {
result.codes.push_back(config.start_code());
result.offsets.push_back(0);
}
if (!reverse) {
std::reverse(result.codes.begin(), result.codes.end());
std::reverse(result.offsets.begin(), result.offsets.end());
}
return result;
}
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse) {
// Get the config from the buffer.
const EncoderConfig* config = GetEncoderConfig(config_buffer);
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
EncoderResult result;
result.type = EncoderResultType::WRONG_CONFIG;
return result;
}
std::string normalized_string;
std::vector<int> offsets;
std::tie(normalized_string, offsets) = NormalizeString(string, *config);
return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
add_eos, reverse);
}
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,46 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
// Sentencepiece encoder optimized with memmapped model.
#include <string>
#include <tuple>
#include <vector>
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
namespace mediapipe::tflite_operations::sentencepiece {
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
struct EncoderResult {
EncoderResultType type = EncoderResultType::SUCCESS;
std::vector<int> codes;
std::vector<int> offsets;
};
std::tuple<std::string, std::vector<int>> NormalizeString(
const std::string& in_string, const EncoderConfig& config);
// Encodes one string and returns ids and offsets. Takes the configuration as a
// type-erased buffer.
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
bool add_bos, bool add_eos, bool reverse);
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_

View File

@ -0,0 +1,171 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include <fstream>
#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "src/sentencepiece.pb.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/core/platform/env.h"
namespace mediapipe::tflite_operations::sentencepiece {
namespace internal {
tensorflow::Status TFReadFileToString(const std::string& filepath,
std::string* data) {
return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
data);
}
absl::Status StdReadFileToString(const std::string& filepath,
std::string* data) {
std::ifstream infile(filepath);
if (!infile.is_open()) {
return absl::NotFoundError(
absl::StrFormat("Error when opening %s", filepath));
}
std::string contents((std::istreambuf_iterator<char>(infile)),
(std::istreambuf_iterator<char>()));
data->append(contents);
infile.close();
return absl::OkStatus();
}
} // namespace internal
namespace {
using ::mediapipe::file::JoinPath;
static char kConfigFilePath[] =
"/mediapipe/tasks/cc/text/custom_ops/"
"sentencepiece/testdata/sentencepiece.model";
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
flatbuffers::FlatBufferBuilder builder(1024);
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(true);
ecb.add_add_dummy_prefix(true);
ecb.add_escape_whitespaces(true);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("x y", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
}
{
const auto result = NormalizeString("\tx y\n", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
}
}
TEST(OptimizedEncoder, NormalizeStringReplacement) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
const char norm_replacements[] = "A1\0A2\0A3\0A4";
const auto trie_vector =
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
const auto norm_r = builder.CreateVector<int8_t>(
reinterpret_cast<const signed char*>(norm_replacements),
sizeof(norm_replacements));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto norm_p = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(false);
ecb.add_normalized_prefixes(norm_p);
ecb.add_normalized_replacements(norm_r);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("ABAABAAABAAAA", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, "A1BA2BA3BA4");
EXPECT_THAT(offsets,
::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
}
}
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
flatbuffers::FlatBufferBuilder builder(1024);
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
"X"};
const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
const auto trie_vector =
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
const auto norm_r = builder.CreateVector<int8_t>(
reinterpret_cast<const signed char*>(norm_replacements),
sizeof(norm_replacements));
TrieBuilder trie_builder(builder);
trie_builder.add_nodes(trie_vector);
const auto norm_p = trie_builder.Finish();
EncoderConfigBuilder ecb(builder);
ecb.add_remove_extra_whitespaces(true);
ecb.add_normalized_prefixes(norm_p);
ecb.add_normalized_replacements(norm_r);
FinishEncoderConfigBuffer(builder, ecb.Finish());
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
{
const auto result = NormalizeString("XXABAABAAABAAAA", *config);
const auto res_string = std::get<0>(result);
const auto offsets = std::get<1>(result);
EXPECT_EQ(res_string, " A1BA2BA3BA4");
EXPECT_THAT(offsets,
::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
}
}
TEST(OptimizedEncoder, ConfigConverter) {
std::string config;
auto status =
internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
ASSERT_TRUE(status.ok());
::sentencepiece::SentencePieceProcessor processor;
ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
const auto converted_model = ConvertSentencepieceModel(config);
const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
const auto encoded =
EncodeString(test_string, converted_model.data(), false, false, false);
ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
::sentencepiece::SentencePieceText reference_encoded;
ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
for (int i = 0; i < encoded.codes.size(); ++i) {
EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
}
}
} // namespace
} // namespace mediapipe::tflite_operations::sentencepiece

View File

@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
namespace mediapipe::tflite_operations::sentencepiece {
// The constant is copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
constexpr float kUnkPenalty = 10.0;
// These constants are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
//
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
} // namespace mediapipe::tflite_operations::sentencepiece
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_

View File

@ -0,0 +1,129 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace mediapipe::tflite_operations {
namespace sentencepiece::tokenizer {
namespace {
using ::tflite::SetTensorToDynamic;
constexpr int kSPModelIndex = 0;
constexpr int kInputIndex = 1;
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;
constexpr int kOutputValuesInd = 0;
constexpr int kOutputSplitsInd = 1;
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
int index = 0;
for (const int size : sizes) {
array_size->data[index++] = size;
}
return array_size;
}
} // namespace
// Initializes text encoder object from serialized parameters.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
size_t /*length*/) {
return nullptr;
}
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// TODO: Add checks for input and output tensors.
TfLiteTensor& output_values =
context->tensors[node->outputs->data[kOutputValuesInd]];
SetTensorToDynamic(&output_values);
TfLiteTensor& output_splits =
context->tensors[node->outputs->data[kOutputSplitsInd]];
SetTensorToDynamic(&output_splits);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor& model_tensor =
context->tensors[node->inputs->data[kSPModelIndex]];
const auto model_buffer_data = model_tensor.data.data;
const TfLiteTensor& input_text =
context->tensors[node->inputs->data[kInputIndex]];
const TfLiteTensor add_bos_tensor =
context->tensors[node->inputs->data[kAddBOSInput]];
const bool add_bos = add_bos_tensor.data.b[0];
const TfLiteTensor add_eos_tensor =
context->tensors[node->inputs->data[kAddEOSInput]];
const bool add_eos = add_eos_tensor.data.b[0];
const TfLiteTensor reverse_tensor =
context->tensors[node->inputs->data[kReverseInput]];
const bool reverse = reverse_tensor.data.b[0];
std::vector<int32> encoded;
std::vector<int32> splits;
const int num_strings = tflite::GetStringCount(&input_text);
for (int i = 0; i < num_strings; ++i) {
const auto strref = tflite::GetString(&input_text, i);
const auto res = EncodeString(std::string(strref.str, strref.len),
model_buffer_data, add_bos, add_eos, reverse);
TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
"Sentencepiece conversion failed");
std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
splits.emplace_back(encoded.size());
}
TfLiteTensor& output_values =
context->tensors[node->outputs->data[kOutputValuesInd]];
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(
context, &output_values,
CreateSizeArray({static_cast<int>(encoded.size())})));
int32_t* output_values_flat = output_values.data.i32;
std::copy(encoded.begin(), encoded.end(), output_values_flat);
TfLiteTensor& output_splits =
context->tensors[node->outputs->data[kOutputSplitsInd]];
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(
context, &output_splits,
CreateSizeArray({static_cast<int>(splits.size() + 1)})));
int32_t* output_splits_flat = output_splits.data.i32;
*output_splits_flat = 0;
std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
return kTfLiteOk;
}
} // namespace sentencepiece::tokenizer
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
static TfLiteRegistration r = {
sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
return &r;
}
} // namespace mediapipe::tflite_operations

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe::tflite_operations {
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
} // namespace mediapipe::tflite_operations
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_

View File

@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
// Embedding model with regex preprocessing. // Embedding model with regex preprocessing.
constexpr char kRegexOneEmbeddingModel[] = constexpr char kRegexOneEmbeddingModel[] =
"regex_one_embedding_with_metadata.tflite"; "regex_one_embedding_with_metadata.tflite";
constexpr char kUniversalSentenceEncoderModel[] =
"universal_sentence_encoder_qa_with_metadata.tflite";
// Tolerance for embedding vector coordinate values. // Tolerance for embedding vector coordinate values.
constexpr float kEpsilon = 1e-4; constexpr float kEpsilon = 1e-4;
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
MP_ASSERT_OK(text_embedder->Close()); MP_ASSERT_OK(text_embedder->Close());
} }
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
auto result0,
text_embedder->Embed("it's a charming and often affecting journey"));
ASSERT_EQ(result0.embeddings.size(), 1);
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
MP_ASSERT_OK_AND_ASSIGN(
auto result1, text_embedder->Embed("what a great and fantastic trip"));
ASSERT_EQ(result1.embeddings.size(), 1);
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>(); auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
MP_ASSERT_OK(text_embedder->Close()); MP_ASSERT_OK(text_embedder->Close());
} }
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
auto options = std::make_unique<TextEmbedderOptions>();
options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::Create(std::move(options)));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result0,
text_embedder->Embed("When you go to this restaurant, they hold the "
"pancake upside-down before they hand it "
"to you. It's a great gimmick."));
MP_ASSERT_OK_AND_ASSIGN(
TextEmbedderResult result1,
text_embedder->Embed(
"Let's make a plan to steal the declaration of independence."));
// Check cosine similarity.
MP_ASSERT_OK_AND_ASSIGN(
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
result1.embeddings[0]));
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
MP_ASSERT_OK(text_embedder->Close());
}
} // namespace } // namespace
} // namespace mediapipe::tasks::text::text_embedder } // namespace mediapipe::tasks::text::text_embedder

View File

@ -23,18 +23,12 @@ cc_library(
srcs = ["face_stylizer_graph.cc"], srcs = ["face_stylizer_graph.cc"],
deps = [ deps = [
"//mediapipe/calculators/core:split_vector_calculator_cc_proto", "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/image:image_cropping_calculator", "//mediapipe/calculators/image:image_clone_calculator_cc_proto",
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
"//mediapipe/calculators/image:warp_affine_calculator",
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
"//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/tensor:inference_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:face_to_rect_calculator", "//mediapipe/calculators/util:face_to_rect_calculator",
"//mediapipe/calculators/util:from_image_calculator",
"//mediapipe/calculators/util:inverse_matrix_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto", "//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
"//mediapipe/calculators/util:to_image_calculator",
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
@ -53,7 +47,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",

View File

@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The input image can be of any size with format RGB or RGBA. // The input image can be of any size with format RGB or RGBA.
// When no face is detected on the input image, the method returns a // When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible // std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the stylized // face. The stylized output image size is the same as the model output size.
// output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
absl::StatusOr<std::optional<mediapipe::Image>> Stylize( absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
mediapipe::Image image, mediapipe::Image image,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// must be monotonically increasing. // must be monotonically increasing.
// When no face is detected on the input image, the method returns a // When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible // std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the stylized // face. The stylized output image size is the same as the model output size.
// output image size is the smaller of the model output size and the size of
// the 'region_of_interest' specified in 'image_processing_options'.
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo( absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
mediapipe::Image image, int64_t timestamp_ms, mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> image_processing_options = std::optional<core::ImageProcessingOptions> image_processing_options =
@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
// The "result_callback" provides: // The "result_callback" provides:
// - When no face is detected on the input image, the method returns a // - When no face is detected on the input image, the method returns a
// std::nullopt. Otherwise, returns the stylized image of the most visible // std::nullopt. Otherwise, returns the stylized image of the most visible
// face. To ensure that the output image has reasonable quality, the // face. The stylized output image size is the same as the model output
// stylized output image size is the smaller of the model output size and // size.
// the size of the 'region_of_interest' specified in
// 'image_processing_options'.
// - The input timestamp in milliseconds. // - The input timestamp in milliseconds.
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms, absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
std::optional<core::ImageProcessingOptions> std::optional<core::ImageProcessingOptions>

View File

@ -19,8 +19,7 @@ limitations under the License.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/calculators/core/split_vector_calculator.pb.h" #include "mediapipe/calculators/core/split_vector_calculator.pb.h"
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h" #include "mediapipe/calculators/image/image_clone_calculator.pb.h"
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h" #include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/api2/builder.h"
@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
image_in >> preprocessing.In(kImageTag); image_in >> preprocessing.In(kImageTag);
face_rect >> preprocessing.In(kNormRectTag); face_rect >> preprocessing.In(kNormRectTag);
auto preprocessed_tensors = preprocessing.Out(kTensorsTag); auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
auto transform_matrix = preprocessing.Out(kMatrixTag);
// Adds inference subgraph and connects its input stream to the output // Adds inference subgraph and connects its input stream to the output
// tensors produced by the ImageToTensorCalculator. // tensors produced by the ImageToTensorCalculator.
@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
model_output_tensors >> tensors_to_image.In(kTensorsTag); model_output_tensors >> tensors_to_image.In(kTensorsTag);
auto tensor_image = tensors_to_image.Out(kImageTag); auto tensor_image = tensors_to_image.Out(kImageTag);
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); auto& image_converter = graph.AddNode("ImageCloneCalculator");
transform_matrix >> inverse_matrix.In(kMatrixTag); image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag); .set_output_on_gpu(false);
tensor_image >> image_converter.In("");
auto& warp_affine = graph.AddNode("WarpAffineCalculator"); return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
auto& warp_affine_options =
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
warp_affine_options.set_border_mode(
WarpAffineCalculatorOptions::BORDER_ZERO);
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
tensor_image >> warp_affine.In(kImageTag);
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
image_size >> warp_affine.In(kOutputSizeTag);
auto image_to_crop = warp_affine.Out(kImageTag);
// The following calculators are for cropping and resizing the output image
// based on the roi and the model output size. As the WarpAffineCalculator
// rotates the image based on the transform matrix, the rotation info in the
// rect proto is stripped to prevent the ImageCroppingCalculator from
// performing extra rotation.
auto& strip_rotation =
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
face_rect >> strip_rotation.In(kNormRectTag);
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
auto& from_image = graph.AddNode("FromImageCalculator");
image_to_crop >> from_image.In(kImageTag);
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
auto& image_cropping_opts =
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
image_cropping_opts.set_output_max_width(
image_to_tensor_options.output_tensor_width());
image_cropping_opts.set_output_max_height(
image_to_tensor_options.output_tensor_height());
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
auto& to_image = graph.AddNode("ToImageCalculator");
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
// graph selects its cpu or gpu path based on the image preprocessing
// backend.
if (use_gpu) {
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
} else {
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
}
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}}; /*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
} }
}; };

View File

@ -100,6 +100,7 @@ cc_library(
"//mediapipe/util:graph_builder_utils", "//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
alwayslink = 1,
) )
cc_library( cc_library(

View File

@ -90,7 +90,7 @@ NS_SWIFT_NAME(ClassificationResult)
* amount of data to process might exceed the maximum size that the model can process: to solve * amount of data to process might exceed the maximum size that the model can process: to solve
* this, the input data is split into multiple chunks starting at different timestamps. * this, the input data is split into multiple chunks starting at different timestamps.
*/ */
@property(nonatomic, readonly) NSInteger timestampMs; @property(nonatomic, readonly) NSInteger timestampInMilliseconds;
/** /**
* Initializes a new `MPPClassificationResult` with the given array of classifications and time * Initializes a new `MPPClassificationResult` with the given array of classifications and time
@ -98,14 +98,15 @@ NS_SWIFT_NAME(ClassificationResult)
* *
* @param classifications An Array of `MPPClassifications` objects containing the predicted * @param classifications An Array of `MPPClassifications` objects containing the predicted
* categories for each head of the model. * categories for each head of the model.
* @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data * @param timestampInMilliseconds The timestamp (in milliseconds) of the start of the chunk of data
* corresponding to these results. * corresponding to these results.
* *
* @return An instance of `MPPClassificationResult` initialized with the given array of * @return An instance of `MPPClassificationResult` initialized with the given array of
* classifications and timestampMs. * classifications and timestamp (in milliseconds).
*/ */
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications - (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; timestampInMilliseconds:(NSInteger)timestampInMilliseconds
NS_DESIGNATED_INITIALIZER;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -38,11 +38,11 @@
@implementation MPPClassificationResult @implementation MPPClassificationResult
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications - (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super init]; self = [super init];
if (self) { if (self) {
_classifications = classifications; _classifications = classifications;
_timestampMs = timestampMs; _timestampInMilliseconds = timestampInMilliseconds;
} }
return self; return self;

View File

@ -33,7 +33,7 @@ NS_SWIFT_NAME(EmbeddingResult)
* cases, the amount of data to process might exceed the maximum size that the model can process. To * cases, the amount of data to process might exceed the maximum size that the model can process. To
* solve this, the input data is split into multiple chunks starting at different timestamps. * solve this, the input data is split into multiple chunks starting at different timestamps.
*/ */
@property(nonatomic, readonly) NSInteger timestampMs; @property(nonatomic, readonly) NSInteger timestampInMilliseconds;
/** /**
* Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in * Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in
@ -41,14 +41,14 @@ NS_SWIFT_NAME(EmbeddingResult)
* *
* @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each * @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each
* head of the model. * head of the model.
* @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data * @param timestampInMilliseconds The optional timestamp (in milliseconds) of the start of the chunk
* corresponding to these results. Pass `0` if timestamp is absent. * of data corresponding to these results. Pass `0` if timestamp is absent.
* *
* @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and * @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and
* timestampMs. * timestamp (in milliseconds).
*/ */
- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings - (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -17,11 +17,11 @@
@implementation MPPEmbeddingResult @implementation MPPEmbeddingResult
- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings - (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super init]; self = [super init];
if (self) { if (self) {
_embeddings = embeddings; _embeddings = embeddings;
_timestampMs = timestampMs; _timestampInMilliseconds = timestampInMilliseconds;
} }
return self; return self;

View File

@ -55,13 +55,13 @@ using ClassificationResultProto =
[classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]]; [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]];
} }
NSInteger timestampMs = 0; NSInteger timestampInMilliseconds = 0;
if (classificationResultProto.has_timestamp_ms()) { if (classificationResultProto.has_timestamp_ms()) {
timestampMs = (NSInteger)classificationResultProto.timestamp_ms(); timestampInMilliseconds = (NSInteger)classificationResultProto.timestamp_ms();
} }
return [[MPPClassificationResult alloc] initWithClassifications:classifications return [[MPPClassificationResult alloc] initWithClassifications:classifications
timestampMs:timestampMs]; timestampInMilliseconds:timestampInMilliseconds];
; ;
} }

View File

@ -31,12 +31,13 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::
[embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]]; [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]];
} }
NSInteger timestampMs = 0; NSInteger timestampInMilliseconds = 0;
if (embeddingResultProto.has_timestamp_ms()) { if (embeddingResultProto.has_timestamp_ms()) {
timestampMs = (NSInteger)embeddingResultProto.timestamp_ms(); timestampInMilliseconds = (NSInteger)embeddingResultProto.timestamp_ms();
} }
return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs]; return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings
timestampInMilliseconds:timestampInMilliseconds];
} }
@end @end

View File

@ -26,11 +26,12 @@ NS_SWIFT_NAME(TaskResult)
/** /**
* Timestamp that is associated with the task result object. * Timestamp that is associated with the task result object.
*/ */
@property(nonatomic, assign, readonly) NSInteger timestampMs; @property(nonatomic, assign, readonly) NSInteger timestampInMilliseconds;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; - (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds
NS_DESIGNATED_INITIALIZER;
@end @end

View File

@ -16,16 +16,16 @@
@implementation MPPTaskResult @implementation MPPTaskResult
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { - (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super init]; self = [super init];
if (self) { if (self) {
_timestampMs = timestampMs; _timestampInMilliseconds = timestampInMilliseconds;
} }
return self; return self;
} }
- (id)copyWithZone:(NSZone *)zone { - (id)copyWithZone:(NSZone *)zone {
return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; return [[MPPTaskResult alloc] initWithTimestampInMilliseconds:self.timestampInMilliseconds];
} }
@end @end

View File

@ -487,7 +487,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSError *liveStreamApiCallError; NSError *liveStreamApiCallError;
XCTAssertFalse([imageClassifier classifyAsyncImage:image XCTAssertFalse([imageClassifier classifyAsyncImage:image
timestampMs:0 timestampInMilliseconds:0
error:&liveStreamApiCallError]); error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError = NSError *expectedLiveStreamApiCallError =
@ -501,7 +501,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError); AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
NSError *videoApiCallError; NSError *videoApiCallError;
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]); XCTAssertFalse([imageClassifier classifyVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError = NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain [NSError errorWithDomain:kExpectedErrorDomain
@ -524,7 +526,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
NSError *liveStreamApiCallError; NSError *liveStreamApiCallError;
XCTAssertFalse([imageClassifier classifyAsyncImage:image XCTAssertFalse([imageClassifier classifyAsyncImage:image
timestampMs:0 timestampInMilliseconds:0
error:&liveStreamApiCallError]); error:&liveStreamApiCallError]);
NSError *expectedLiveStreamApiCallError = NSError *expectedLiveStreamApiCallError =
@ -575,7 +577,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
AssertEqualErrors(imageApiCallError, expectedImageApiCallError); AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
NSError *videoApiCallError; NSError *videoApiCallError;
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]); XCTAssertFalse([imageClassifier classifyVideoFrame:image
timestampInMilliseconds:0
error:&videoApiCallError]);
NSError *expectedVideoApiCallError = NSError *expectedVideoApiCallError =
[NSError errorWithDomain:kExpectedErrorDomain [NSError errorWithDomain:kExpectedErrorDomain
@ -601,7 +605,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image
timestampMs:i timestampInMilliseconds:i
error:nil]; error:nil];
[self assertImageClassifierResult:imageClassifierResult [self assertImageClassifierResult:imageClassifierResult
hasExpectedCategoriesCount:maxResults hasExpectedCategoriesCount:maxResults
@ -630,10 +634,10 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImage *image = [self imageWithFileInfo:kBurgerImage]; MPPImage *image = [self imageWithFileInfo:kBurgerImage];
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:1 error:nil]); XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]);
NSError *error; NSError *error;
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampMs:0 error:&error]); XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampInMilliseconds:0 error:&error]);
NSError *expectedError = NSError *expectedError =
[NSError errorWithDomain:kExpectedErrorDomain [NSError errorWithDomain:kExpectedErrorDomain
@ -668,7 +672,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
MPPImage *image = [self imageWithFileInfo:kBurgerImage]; MPPImage *image = [self imageWithFileInfo:kBurgerImage];
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]); XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]);
} }
} }

View File

@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextClassifierResult)
* *
* @param classificationResult The `MPPClassificationResult` instance containing one set of results * @param classificationResult The `MPPClassificationResult` instance containing one set of results
* per classifier head. * per classifier head.
* @param timestampMs The timestamp for this result. * @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
* *
* @return An instance of `MPPTextClassifierResult` initialized with the given * @return An instance of `MPPTextClassifierResult` initialized with the given
* `MPPClassificationResult` and timestamp (in milliseconds). * `MPPClassificationResult` and timestamp (in milliseconds).
*/ */
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs; timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end @end

View File

@ -17,8 +17,8 @@
@implementation MPPTextClassifierResult @implementation MPPTextClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampMs:timestampMs]; self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) { if (self) {
_classificationResult = classificationResult; _classificationResult = classificationResult;
} }

View File

@ -35,7 +35,7 @@ using ::mediapipe::Packet;
return [[MPPTextClassifierResult alloc] return [[MPPTextClassifierResult alloc]
initWithClassificationResult:classificationResult initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)]; kMicroSecondsPerMilliSecond)];
} }

View File

@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextEmbedderResult)
* *
* @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results * @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results
* per classifier head. * per classifier head.
* @param timestampMs The timestamp for this result. * @param timestampInMilliseconds The timestamp (in millisecondss) for this result.
* *
* @return An instance of `MPPTextEmbedderResult` initialized with the given * @return An instance of `MPPTextEmbedderResult` initialized with the given
* `MPPEmbeddingResult` and timestamp (in milliseconds). * `MPPEmbeddingResult` and timestamp (in milliseconds).
*/ */
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult - (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
timestampMs:(NSInteger)timestampMs; timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -17,8 +17,8 @@
@implementation MPPTextEmbedderResult @implementation MPPTextEmbedderResult
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult - (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampMs:timestampMs]; self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) { if (self) {
_embeddingResult = embeddingResult; _embeddingResult = embeddingResult;
} }

View File

@ -34,7 +34,7 @@ using ::mediapipe::Packet;
return [[MPPTextEmbedderResult alloc] return [[MPPTextEmbedderResult alloc]
initWithEmbeddingResult:embeddingResult initWithEmbeddingResult:embeddingResult
timestampMs:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)]; kMicroSecondsPerMilliSecond)];
} }

View File

@ -41,7 +41,7 @@
* timestamp. * timestamp.
* *
* @param image The image to send to the MediaPipe graph. * @param image The image to send to the MediaPipe graph.
* @param timestampMs The timestamp (in milliseconds) to assign to the packet. * @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
* error will be saved. * error will be saved.
* *
@ -49,7 +49,7 @@
* occurred during the conversion. * occurred during the conversion.
*/ */
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image + (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error; error:(NSError **)error;
/** /**
@ -66,11 +66,11 @@
* specified timestamp. * specified timestamp.
* *
* @param image The `NormalizedRect` to send to the MediaPipe graph. * @param image The `NormalizedRect` to send to the MediaPipe graph.
* @param timestampMs The timestamp (in milliseconds) to assign to the packet. * @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
* *
* @return The MediaPipe packet containing the normalized rect. * @return The MediaPipe packet containing the normalized rect.
*/ */
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect + (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect
timestampMs:(NSInteger)timestampMs; timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end @end

View File

@ -42,7 +42,7 @@ using ::mediapipe::Timestamp;
} }
+ (Packet)createPacketWithMPPImage:(MPPImage *)image + (Packet)createPacketWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error { error:(NSError **)error {
std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error]; std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error];
@ -51,7 +51,7 @@ using ::mediapipe::Timestamp;
} }
return MakePacket<Image>(std::move(imageFrame)) return MakePacket<Image>(std::move(imageFrame))
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); .At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
} }
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect { + (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect {
@ -59,9 +59,9 @@ using ::mediapipe::Timestamp;
} }
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect + (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
return MakePacket<NormalizedRect>(std::move(normalizedRect)) return MakePacket<NormalizedRect>(std::move(normalizedRect))
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); .At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
} }
@end @end

View File

@ -21,7 +21,7 @@
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
gestures:(NSArray<NSArray<MPPCategory *> *> *)gestures gestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
timestampInMilliseconds:(NSInteger)timestampInMilliseconds { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampMs:timestampInMilliseconds]; self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) { if (self) {
_landmarks = landmarks; _landmarks = landmarks;
_worldLandmarks = worldLandmarks; _worldLandmarks = worldLandmarks;

View File

@ -122,17 +122,17 @@ NS_SWIFT_NAME(ImageClassifier)
* `MPPRunningModeVideo`. * `MPPRunningModeVideo`.
* *
* @param image The `MPPImage` on which image classification is to be performed. * @param image The `MPPImage` on which image classification is to be performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* monotonically increasing. * timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing image * @param error An optional error parameter populated when there is an error in performing image
* classification on the input video frame. * classification on the input video frame.
* *
* @return An `MPPImageClassifierResult` object that contains a list of image classifications. * @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/ */
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(classify(videoFrame:timestampMs:)); NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:));
/** /**
* Performs image classification on the provided video frame of type `MPPImage` cropped to the * Performs image classification on the provided video frame of type `MPPImage` cropped to the
@ -145,8 +145,8 @@ NS_SWIFT_NAME(ImageClassifier)
* *
* @param image A live stream image data of type `MPPImage` on which image classification is to be * @param image A live stream image data of type `MPPImage` on which image classification is to be
* performed. * performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* monotonically increasing. * timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the video frame of type * @param roi A `CGRect` specifying the region of interest within the video frame of type
* `MPPImage`, on which image classification should be performed. * `MPPImage`, on which image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image * @param error An optional error parameter populated when there is an error in performing image
@ -155,10 +155,10 @@ NS_SWIFT_NAME(ImageClassifier)
* @return An `MPPImageClassifierResult` object that contains a list of image classifications. * @return An `MPPImageClassifierResult` object that contains a list of image classifications.
*/ */
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:)); NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:regionOfInterest:));
/** /**
* Sends live stream image data of type `MPPImage` to perform image classification using the whole * Sends live stream image data of type `MPPImage` to perform image classification using the whole
@ -172,16 +172,17 @@ NS_SWIFT_NAME(ImageClassifier)
* *
* @param image A live stream image data of type `MPPImage` on which image classification is to be * @param image A live stream image data of type `MPPImage` on which image classification is to be
* performed. * performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* to the image classifier. The input timestamps must be monotonically increasing. * image is sent to the image classifier. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing image * @param error An optional error parameter populated when there is an error in performing image
* classification on the input live stream image data. * classification on the input live stream image data.
* *
* @return `YES` if the image was sent to the task successfully, otherwise `NO`. * @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/ */
- (BOOL)classifyAsyncImage:(MPPImage *)image - (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:)); error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:));
/** /**
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
@ -195,8 +196,8 @@ NS_SWIFT_NAME(ImageClassifier)
* *
* @param image A live stream image data of type `MPPImage` on which image classification is to be * @param image A live stream image data of type `MPPImage` on which image classification is to be
* performed. * performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* to the image classifier. The input timestamps must be monotonically increasing. * image is sent to the image classifier. The input timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given live stream image data * @param roi A `CGRect` specifying the region of interest within the given live stream image data
* of type `MPPImage`, on which image classification should be performed. * of type `MPPImage`, on which image classification should be performed.
* @param error An optional error parameter populated when there is an error in performing image * @param error An optional error parameter populated when there is an error in performing image
@ -205,10 +206,10 @@ NS_SWIFT_NAME(ImageClassifier)
* @return `YES` if the image was sent to the task successfully, otherwise `NO`. * @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/ */
- (BOOL)classifyAsyncImage:(MPPImage *)image - (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:)); NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:regionOfInterest:));
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -149,7 +149,7 @@ static NSString *const kTaskGraphName =
} }
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image - (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<NormalizedRect> rect = std::optional<NormalizedRect> rect =
@ -162,14 +162,15 @@ static NSString *const kTaskGraphName =
} }
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
error:error]; error:error];
if (imagePacket.IsEmpty()) { if (imagePacket.IsEmpty()) {
return std::nullopt; return std::nullopt;
} }
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() Packet normalizedRectPacket =
timestampMs:timestampMs]; [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampInMilliseconds:timestampInMilliseconds];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap; return inputPacketMap;
@ -180,11 +181,11 @@ static NSString *const kTaskGraphName =
} }
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi regionOfInterest:roi
error:error]; error:error];
if (!inputPacketMap.has_value()) { if (!inputPacketMap.has_value()) {
@ -204,20 +205,20 @@ static NSString *const kTaskGraphName =
} }
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error { error:(NSError **)error {
return [self classifyVideoFrame:image return [self classifyVideoFrame:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero regionOfInterest:CGRectZero
error:error]; error:error];
} }
- (BOOL)classifyAsyncImage:(MPPImage *)image - (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi regionOfInterest:roi
error:error]; error:error];
if (!inputPacketMap.has_value()) { if (!inputPacketMap.has_value()) {
@ -228,10 +229,10 @@ static NSString *const kTaskGraphName =
} }
- (BOOL)classifyAsyncImage:(MPPImage *)image - (BOOL)classifyAsyncImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error { error:(NSError **)error {
return [self classifyAsyncImage:image return [self classifyAsyncImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero regionOfInterest:CGRectZero
error:error]; error:error];
} }

View File

@ -31,13 +31,13 @@ NS_SWIFT_NAME(ImageClassifierResult)
* *
* @param classificationResult The `MPPClassificationResult` instance containing one set of results * @param classificationResult The `MPPClassificationResult` instance containing one set of results
* per classifier head. * per classifier head.
* @param timestampMs The timestamp for this result. * @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
* *
* @return An instance of `MPPImageClassifierResult` initialized with the given * @return An instance of `MPPImageClassifierResult` initialized with the given
* `MPPClassificationResult` and timestamp (in milliseconds). * `MPPClassificationResult` and timestamp (in milliseconds).
*/ */
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs; timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end @end

View File

@ -17,8 +17,8 @@
@implementation MPPImageClassifierResult @implementation MPPImageClassifierResult
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampMs:timestampMs]; self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) { if (self) {
_classificationResult = classificationResult; _classificationResult = classificationResult;
} }

View File

@ -34,7 +34,7 @@ using ::mediapipe::Packet;
return [[MPPImageClassifierResult alloc] return [[MPPImageClassifierResult alloc]
initWithClassificationResult:classificationResult initWithClassificationResult:classificationResult
timestampMs:(NSInteger)(packet.Timestamp().Value() / timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)]; kMicroSecondsPerMilliSecond)];
} }

View File

@ -36,13 +36,13 @@ NS_SWIFT_NAME(ObjectDetectionResult)
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is * @param detections An array of `MPPDetection` objects each of which has a bounding box that is
* expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width) * expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width)
* x [0,image_height)`, which are the dimensions of the underlying image data. * x [0,image_height)`, which are the dimensions of the underlying image data.
* @param timestampMs The timestamp for this result. * @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
* *
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
* and timestamp (in milliseconds). * and timestamp (in milliseconds).
*/ */
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampMs:(NSInteger)timestampMs; timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
@end @end

View File

@ -17,8 +17,8 @@
@implementation MPPObjectDetectionResult @implementation MPPObjectDetectionResult
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections - (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
timestampMs:(NSInteger)timestampMs { timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
self = [super initWithTimestampMs:timestampMs]; self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
if (self) { if (self) {
_detections = detections; _detections = detections;
} }

View File

@ -138,8 +138,8 @@ NS_SWIFT_NAME(ObjectDetector)
* `MPPRunningModeVideo`. * `MPPRunningModeVideo`.
* *
* @param image The `MPPImage` on which object detection is to be performed. * @param image The `MPPImage` on which object detection is to be performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* monotonically increasing. * timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing object * @param error An optional error parameter populated when there is an error in performing object
* detection on the input image. * detection on the input image.
* *
@ -149,9 +149,9 @@ NS_SWIFT_NAME(ObjectDetector)
* image data. * image data.
*/ */
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampMs:)); NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
/** /**
* Performs object detection on the provided video frame of type `MPPImage` cropped to the * Performs object detection on the provided video frame of type `MPPImage` cropped to the
@ -164,8 +164,8 @@ NS_SWIFT_NAME(ObjectDetector)
* *
* @param image A live stream image data of type `MPPImage` on which object detection is to be * @param image A live stream image data of type `MPPImage` on which object detection is to be
* performed. * performed.
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
* monotonically increasing. * timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which * @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
* object detection should be performed. * object detection should be performed.
* *
@ -178,10 +178,10 @@ NS_SWIFT_NAME(ObjectDetector)
* image data. * image data.
*/ */
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(detect(videoFrame:timestampMs:regionOfInterest:)); NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:regionOfInterest:));
/** /**
* Sends live stream image data of type `MPPImage` to perform object detection using the whole * Sends live stream image data of type `MPPImage` to perform object detection using the whole
@ -195,16 +195,17 @@ NS_SWIFT_NAME(ObjectDetector)
* *
* @param image A live stream image data of type `MPPImage` on which object detection is to be * @param image A live stream image data of type `MPPImage` on which object detection is to be
* performed. * performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* to the object detector. The input timestamps must be monotonically increasing. * image is sent to the object detector. The input timestamps must be monotonically increasing.
* @param error An optional error parameter populated when there is an error in performing object * @param error An optional error parameter populated when there is an error in performing object
* detection on the input live stream image data. * detection on the input live stream image data.
* *
* @return `YES` if the image was sent to the task successfully, otherwise `NO`. * @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/ */
- (BOOL)detectAsyncInImage:(MPPImage *)image - (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error NS_SWIFT_NAME(detectAsync(image:timestampMs:)); error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
/** /**
* Sends live stream image data of type `MPPImage` to perform object detection, cropped to the * Sends live stream image data of type `MPPImage` to perform object detection, cropped to the
@ -218,8 +219,8 @@ NS_SWIFT_NAME(ObjectDetector)
* *
* @param image A live stream image data of type `MPPImage` on which object detection is to be * @param image A live stream image data of type `MPPImage` on which object detection is to be
* performed. * performed.
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
* to the object detector. The input timestamps must be monotonically increasing. * image is sent to the object detector. The input timestamps must be monotonically increasing.
* @param roi A `CGRect` specifying the region of interest within the given live stream image data * @param roi A `CGRect` specifying the region of interest within the given live stream image data
* of type `MPPImage`, on which iobject detection should be performed. * of type `MPPImage`, on which iobject detection should be performed.
* @param error An optional error parameter populated when there is an error in performing object * @param error An optional error parameter populated when there is an error in performing object
@ -228,10 +229,10 @@ NS_SWIFT_NAME(ObjectDetector)
* @return `YES` if the image was sent to the task successfully, otherwise `NO`. * @return `YES` if the image was sent to the task successfully, otherwise `NO`.
*/ */
- (BOOL)detectAsyncInImage:(MPPImage *)image - (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error error:(NSError **)error
NS_SWIFT_NAME(detectAsync(image:timestampMs:regionOfInterest:)); NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:regionOfInterest:));
- (instancetype)init NS_UNAVAILABLE; - (instancetype)init NS_UNAVAILABLE;

View File

@ -157,7 +157,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
} }
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image - (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<NormalizedRect> rect = std::optional<NormalizedRect> rect =
@ -170,14 +170,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
} }
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
error:error]; error:error];
if (imagePacket.IsEmpty()) { if (imagePacket.IsEmpty()) {
return std::nullopt; return std::nullopt;
} }
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() Packet normalizedRectPacket =
timestampMs:timestampMs]; [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
timestampInMilliseconds:timestampInMilliseconds];
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
return inputPacketMap; return inputPacketMap;
@ -188,11 +189,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
} }
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi regionOfInterest:roi
error:error]; error:error];
if (!inputPacketMap.has_value()) { if (!inputPacketMap.has_value()) {
@ -212,20 +213,20 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
} }
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error { error:(NSError **)error {
return [self detectInVideoFrame:image return [self detectInVideoFrame:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero regionOfInterest:CGRectZero
error:error]; error:error];
} }
- (BOOL)detectAsyncInImage:(MPPImage *)image - (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
regionOfInterest:(CGRect)roi regionOfInterest:(CGRect)roi
error:(NSError **)error { error:(NSError **)error {
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:roi regionOfInterest:roi
error:error]; error:error];
if (!inputPacketMap.has_value()) { if (!inputPacketMap.has_value()) {
@ -236,10 +237,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
} }
- (BOOL)detectAsyncInImage:(MPPImage *)image - (BOOL)detectAsyncInImage:(MPPImage *)image
timestampMs:(NSInteger)timestampMs timestampInMilliseconds:(NSInteger)timestampInMilliseconds
error:(NSError **)error { error:(NSError **)error {
return [self detectAsyncInImage:image return [self detectAsyncInImage:image
timestampMs:timestampMs timestampInMilliseconds:timestampInMilliseconds
regionOfInterest:CGRectZero regionOfInterest:CGRectZero
error:error]; error:error];
} }

View File

@ -38,8 +38,9 @@ using ::mediapipe::Packet;
} }
return [[MPPObjectDetectionResult alloc] return [[MPPObjectDetectionResult alloc]
initWithDetections:detections initWithDetections:detections
timestampMs:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
kMicroSecondsPerMilliSecond)];
} }
@end @end

View File

@ -198,9 +198,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the * <p>The input image can be of any size. The output image is the stylized image with the most
* size of the stylized output is based the model output size and can be smaller than the input * visible face. The stylized output image size is the same as the model output size. When no face
* image. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created
@ -220,9 +220,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size. To ensure that the output image has reasonable quality, * <p>The input image can be of any size. The output image is the stylized image with the most
* the stylized output image size is the smaller of the model output size and the size of the * visible face. The stylized output image size is the same as the model output size. When no face
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
@ -256,9 +256,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the * <p>The input image can be of any size. The output image is the stylized image with the most
* size of the stylized output is based the model output size and can be smaller than the input * visible face. The stylized output image size is the same as the model output size. When no face
* image. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
@ -281,9 +281,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size. To ensure that the output image has reasonable quality, * <p>The input image can be of any size. The output image is the stylized image with the most
* the stylized output image size is the smaller of the model output size and the size of the * visible face. The stylized output image size is the same as the model output size. When no face
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
@ -320,9 +320,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the * <p>The input image can be of any size. The output image is the stylized image with the most
* size of the stylized output is based the model output size and can be smaller than the input * visible face. The stylized output image size is the same as the model output size. When no face
* image. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
@ -346,9 +346,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size. To ensure that the output image has reasonable quality, * <p>The input image can be of any size. The output image is the stylized image with the most
* the stylized output image size is the smaller of the model output size and the size of the * visible face. The stylized output image size is the same as the model output size. When no face
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * is detected on the input image, returns {@code Optional.empty()}. *
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
@ -387,9 +387,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the * <p>The input image can be of any size. The output image is the stylized image with the most
* size of the stylized output is based the model output size and can be smaller than the input * visible face. The stylized output image size is the same as the model output size. When no face
* image. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
@ -414,9 +414,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size. To ensure that the output image has reasonable quality, * <p>The input image can be of any size. The output image is the stylized image with the most
* the stylized output image size is the smaller of the model output size and the size of the * visible face. The stylized output image size is the same as the model output size. When no face
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param timestampMs the input timestamp (in milliseconds). * @param timestampMs the input timestamp (in milliseconds).
@ -445,9 +445,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* *
* <p>{@link FaceStylizer} supports the following color space types: * <p>{@link FaceStylizer} supports the following color space types:
* *
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the * <p>The input image can be of any size. The output image is the stylized image with the most
* size of the stylized output is based the model output * size and can be smaller than the input * visible face. The stylized output image size is the same as the model output size. When no face
* image. * is detected on the input image, returns {@code Optional.empty()}.
* *
* <ul> * <ul>
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
@ -475,9 +475,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888} * <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
* </ul> * </ul>
* *
* <p>The input image can be of any size. To ensure that the output image has reasonable quality, * <p>The input image can be of any size. The output image is the stylized image with the most
* the stylized output image size is the smaller of the model output size and the size of the * visible face. The stylized output image size is the same as the model output size. When no face
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. * is detected on the input image, returns {@code Optional.empty()}.
* *
* @param image a MediaPipe {@link MPImage} object for processing. * @param image a MediaPipe {@link MPImage} object for processing.
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the

View File

@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
"IMAGE:" + IMAGE_IN_STREAM_NAME, "IMAGE:" + IMAGE_IN_STREAM_NAME,
"ROI:" + ROI_IN_STREAM_NAME, "ROI:" + ROI_IN_STREAM_NAME,
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS = private static final int IMAGE_OUT_STREAM_INDEX = 0;
Collections.unmodifiableList(
Arrays.asList(
"GROUPED_SEGMENTATION:segmented_mask_out",
"IMAGE:image_out",
"SEGMENTATION:0:segmentation"));
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
private static final String TASK_GRAPH_NAME = private static final String TASK_GRAPH_NAME =
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
*/ */
public static InteractiveSegmenter createFromOptions( public static InteractiveSegmenter createFromOptions(
Context context, InteractiveSegmenterOptions segmenterOptions) { Context context, InteractiveSegmenterOptions segmenterOptions) {
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
throw new IllegalArgumentException(
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
}
List<String> outputStreams = new ArrayList<>();
outputStreams.add("IMAGE:image_out");
if (segmenterOptions.outputConfidenceMasks()) {
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
}
final int confidenceMasksOutStreamIndex = outputStreams.size() - 1;
if (segmenterOptions.outputCategoryMask()) {
outputStreams.add("CATEGORY_MASK:category_mask");
}
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
// TODO: Consolidate OutputHandler and TaskRunner. // TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>(); OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
handler.setOutputPacketConverter( handler.setOutputPacketConverter(
@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
@Override @Override
public ImageSegmenterResult convertToTaskResult(List<Packet> packets) public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
throws MediaPipeException { throws MediaPipeException {
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.empty(), Optional.empty(),
Optional.empty(), Optional.empty(),
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
} }
List<MPImage> segmentedMasks = new ArrayList<>(); // If resultListener is not provided, the resulted MPImage is deep copied from
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); // memory.
int imageFormat = boolean copyImage = !segmenterOptions.resultListener().isPresent();
segmenterOptions.outputType() Optional<List<MPImage>> confidenceMasks = Optional.empty();
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK if (segmenterOptions.outputConfidenceMasks()) {
? MPImage.IMAGE_FORMAT_VEC32F1 confidenceMasks = Optional.of(new ArrayList<>());
: MPImage.IMAGE_FORMAT_ALPHA; int width =
int imageListSize = PacketGetter.getImageWidthFromImageList(
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); packets.get(confidenceMasksOutStreamIndex));
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; int height =
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe PacketGetter.getImageHeightFromImageList(
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory. packets.get(confidenceMasksOutStreamIndex));
if (!segmenterOptions.resultListener().isPresent()) { int imageListSize =
for (int i = 0; i < imageListSize; i++) { PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
buffersArray[i] = ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
ByteBuffer.allocateDirect( // confidence masks are float type image.
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); final int numBytes = 4;
if (copyImage) {
for (int i = 0; i < imageListSize; i++) {
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
}
}
if (!PacketGetter.getImageList(
packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
throw new MediaPipeException(
MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting confidence masks.");
}
for (ByteBuffer buffer : buffersArray) {
ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
confidenceMasks.get().add(builder.build());
} }
} }
if (!PacketGetter.getImageList( Optional<MPImage> categoryMask = Optional.empty();
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), if (segmenterOptions.outputCategoryMask()) {
buffersArray, int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
!segmenterOptions.resultListener().isPresent())) { int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
throw new MediaPipeException( ByteBuffer buffer;
MediaPipeException.StatusCode.INTERNAL.ordinal(), if (copyImage) {
"There is an error getting segmented masks. It usually results from incorrect" buffer = ByteBuffer.allocateDirect(width * height);
+ " options of unsupported OutputType of given model."); if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
} throw new MediaPipeException(
for (ByteBuffer buffer : buffersArray) { MediaPipeException.StatusCode.INTERNAL.ordinal(),
"There is an error getting category mask.");
}
} else {
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
}
ByteBufferImageBuilder builder = ByteBufferImageBuilder builder =
new ByteBufferImageBuilder(buffer, width, height, imageFormat); new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
segmentedMasks.add(builder.build()); categoryMask = Optional.of(builder.build());
} }
return ImageSegmenterResult.create( return ImageSegmenterResult.create(
Optional.of(segmentedMasks), confidenceMasks,
Optional.empty(), categoryMask,
BaseVisionTaskApi.generateResultTimestampMs( BaseVisionTaskApi.generateResultTimestampMs(
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
} }
@Override @Override
@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
.setTaskRunningModeName(RunningMode.IMAGE.name()) .setTaskRunningModeName(RunningMode.IMAGE.name())
.setTaskGraphName(TASK_GRAPH_NAME) .setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS) .setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS) .setOutputStreams(outputStreams)
.setTaskOptions(segmenterOptions) .setTaskOptions(segmenterOptions)
.setEnableFlowLimiting(false) .setEnableFlowLimiting(false)
.build(), .build(),
@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
/** Sets the base options for the image segmenter task. */ /** Sets the base options for the image segmenter task. */
public abstract Builder setBaseOptions(BaseOptions value); public abstract Builder setBaseOptions(BaseOptions value);
/** The output type from image segmenter. */ /** Sets whether to output confidence masks. Default to true. */
public abstract Builder setOutputType(OutputType value); public abstract Builder setOutputConfidenceMasks(boolean value);
/** Sets whether to output category mask. Default to false. */
public abstract Builder setOutputCategoryMask(boolean value);
/** /**
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph * Sets an optional {@link ResultListener} to receive the segmentation results when the graph
@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
abstract BaseOptions baseOptions(); abstract BaseOptions baseOptions();
abstract OutputType outputType(); abstract boolean outputConfidenceMasks();
abstract boolean outputCategoryMask();
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener(); abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
abstract Optional<ErrorListener> errorListener(); abstract Optional<ErrorListener> errorListener();
/** The output type of segmentation results. */
public enum OutputType {
// Gives a single output mask where each pixel represents the class which
// the pixel in the original image was predicted to belong to.
CATEGORY_MASK,
// Gives a list of output masks where, for each mask, each pixel represents
// the prediction confidence, usually in the [0, 1] range.
CONFIDENCE_MASK
}
public static Builder builder() { public static Builder builder() {
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder() return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
.setOutputType(OutputType.CATEGORY_MASK); .setOutputConfidenceMasks(true)
.setOutputCategoryMask(false);
} }
/** /**
@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
SegmenterOptionsProto.SegmenterOptions.newBuilder(); SegmenterOptionsProto.SegmenterOptions.newBuilder();
if (outputType() == OutputType.CONFIDENCE_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
} else if (outputType() == OutputType.CATEGORY_MASK) {
segmenterOptionsBuilder.setOutputType(
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
}
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
return CalculatorOptions.newBuilder() return CalculatorOptions.newBuilder()
.setExtension( .setExtension(

View File

@ -234,8 +234,8 @@ public class FaceStylizerTest {
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage); FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
MPImage stylizedImage = actualResult.stylizedImage().get(); MPImage stylizedImage = actualResult.stylizedImage().get();
assertThat(stylizedImage).isNotNull(); assertThat(stylizedImage).isNotNull();
assertThat(stylizedImage.getWidth()).isEqualTo(83); assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
assertThat(stylizedImage.getHeight()).isEqualTo(83); assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
} }
@Test @Test

View File

@ -53,18 +53,15 @@ public class InteractiveSegmenterTest {
InteractiveSegmenterOptions options = InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder() InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK) .setOutputConfidenceMasks(false)
.setOutputCategoryMask(true)
.build(); .build();
InteractiveSegmenter imageSegmenter = InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions( InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options); ApplicationProvider.getApplicationContext(), options);
MPImage image = getImageFromAsset(inputImageName); MPImage image = getImageFromAsset(inputImageName);
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
// TODO update to correct category mask output. assertThat(actualResult.categoryMask().isPresent()).isTrue();
// After InteractiveSegmenter updated according to (b/276519300), update this to use
// categoryMask field instead of confidenceMasks.
List<MPImage> segmentations = actualResult.confidenceMasks().get();
assertThat(segmentations.size()).isEqualTo(1);
} }
@Test @Test
@ -75,15 +72,17 @@ public class InteractiveSegmenterTest {
InteractiveSegmenterOptions options = InteractiveSegmenterOptions options =
InteractiveSegmenterOptions.builder() InteractiveSegmenterOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK) .setOutputConfidenceMasks(true)
.setOutputCategoryMask(false)
.build(); .build();
InteractiveSegmenter imageSegmenter = InteractiveSegmenter imageSegmenter =
InteractiveSegmenter.createFromOptions( InteractiveSegmenter.createFromOptions(
ApplicationProvider.getApplicationContext(), options); ApplicationProvider.getApplicationContext(), options);
ImageSegmenterResult actualResult = ImageSegmenterResult actualResult =
imageSegmenter.segment(getImageFromAsset(inputImageName), roi); imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
List<MPImage> segmentations = actualResult.confidenceMasks().get(); assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
assertThat(segmentations.size()).isEqualTo(2); List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
assertThat(confidenceMasks.size()).isEqualTo(2);
} }
} }

View File

@ -204,6 +204,11 @@ This can be useful for resetting a stateful task graph to process new data.
Raises: Raises:
RuntimeError: The underlying medipaipe graph fails to reset and restart. RuntimeError: The underlying medipaipe graph fails to reset and restart.
)doc"); )doc");
task_runner.def(
"get_graph_config",
[](TaskRunner* self) { return self->GetGraphConfig(); },
R"doc(Returns the canonicalized CalculatorGraphConfig of the underlying graph.)doc");
} }
} // namespace python } // namespace python

View File

@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' _BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' _REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
# Tolerance for embedding vector coordinate values. # Tolerance for embedding vector coordinate values.
_EPSILON = 1e-4 _EPSILON = 1e-4
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
16, 16,
(0.549632, 0.552879), (0.549632, 0.552879),
), ),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
) )
def test_embed(self, l2_normalize, quantize, model_name, model_file_type, def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
expected_similarity, expected_size, expected_first_values): expected_similarity, expected_size, expected_first_values):
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
16, 16,
(0.549632, 0.552879), (0.549632, 0.552879),
), ),
(
False,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_NAME,
0.851961,
100,
(1.422951, 1.404664),
),
(
True,
False,
_USE_MODEL_FILE,
ModelFileType.FILE_CONTENT,
0.851961,
100,
(0.127049, 0.125416),
),
) )
def test_embed_in_context(self, l2_normalize, quantize, model_name, def test_embed_in_context(self, l2_normalize, quantize, model_name,
model_file_type, expected_similarity, expected_size, model_file_type, expected_similarity, expected_size,
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters(
# TODO: The similarity should likely be lower # TODO: The similarity should likely be lower
(_BERT_MODEL_FILE, 0.980880), (_BERT_MODEL_FILE, 0.980880),
(_USE_MODEL_FILE, 0.780334),
) )
def test_embed_with_different_themes(self, model_file, expected_similarity): def test_embed_with_different_themes(self, model_file, expected_similarity):
# Creates embedder. # Creates embedder.

View File

@ -15,7 +15,6 @@
import enum import enum
import os import os
from typing import List
from unittest import mock from unittest import mock
from absl.testing import absltest from absl.testing import absltest
@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision import image_segmenter
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
_Activation = image_segmenter.ImageSegmenterOptions.Activation
_ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenter = image_segmenter.ImageSegmenter
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
@ -42,11 +40,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
_MODEL_FILE = 'deeplabv3.tflite' _MODEL_FILE = 'deeplabv3.tflite'
_IMAGE_FILE = 'segmentation_input_rotation0.jpg' _IMAGE_FILE = 'segmentation_input_rotation0.jpg'
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
_CAT_IMAGE = 'cat.jpg'
_CAT_MASK = 'cat_mask.jpg'
_MASK_MAGNIFICATION_FACTOR = 10 _MASK_MAGNIFICATION_FACTOR = 10
_MASK_SIMILARITY_THRESHOLD = 0.98 _MASK_SIMILARITY_THRESHOLD = 0.98
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
def _calculate_soft_iou(m1, m2):
intersection_sum = np.sum(m1 * m2)
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
if union_sum > 0:
return intersection_sum / union_sum
else:
return 0
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
actual_mask = actual_mask.numpy_view()
expected_mask = expected_mask.numpy_view() / 255.0
return (
actual_mask.shape == expected_mask.shape
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
)
def _similar_to_uint8_mask(actual_mask, expected_mask): def _similar_to_uint8_mask(actual_mask, expected_mask):
actual_mask_pixels = actual_mask.numpy_view().flatten() actual_mask_pixels = actual_mask.numpy_view().flatten()
expected_mask_pixels = expected_mask.numpy_view().flatten() expected_mask_pixels = expected_mask.numpy_view().flatten()
@ -56,8 +76,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask):
for index in range(num_pixels): for index in range(num_pixels):
consistent_pixels += ( consistent_pixels += (
actual_mask_pixels[index] * actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR
_MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index]) == expected_mask_pixels[index]
)
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
@ -73,16 +94,27 @@ class ImageSegmenterTest(parameterized.TestCase):
super().setUp() super().setUp()
# Load the test input image. # Load the test input image.
self.test_image = _Image.create_from_file( self.test_image = _Image.create_from_file(
test_utils.get_test_data_path( test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))) )
# Loads ground truth segmentation file. # Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread( gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path( test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)), os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)
cv2.IMREAD_GRAYSCALE) ),
cv2.IMREAD_GRAYSCALE,
)
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data) self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
self.model_path = test_utils.get_test_data_path( self.model_path = test_utils.get_test_data_path(
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
)
def _load_segmentation_mask(self, file_path: str):
# Loads ground truth segmentation file.
gt_segmentation_data = cv2.imread(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
cv2.IMREAD_GRAYSCALE,
)
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
def test_create_from_file_succeeds_with_valid_model_path(self): def test_create_from_file_succeeds_with_valid_model_path(self):
# Creates with default option and valid model file successfully. # Creates with default option and valid model file successfully.
@ -98,9 +130,11 @@ class ImageSegmenterTest(parameterized.TestCase):
def test_create_from_options_fails_with_invalid_model_path(self): def test_create_from_options_fails_with_invalid_model_path(self):
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
):
base_options = _BaseOptions( base_options = _BaseOptions(
model_asset_path='/path/to/invalid/model.tflite') model_asset_path='/path/to/invalid/model.tflite'
)
options = _ImageSegmenterOptions(base_options=base_options) options = _ImageSegmenterOptions(base_options=base_options)
_ImageSegmenter.create_from_options(options) _ImageSegmenter.create_from_options(options)
@ -112,8 +146,9 @@ class ImageSegmenterTest(parameterized.TestCase):
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
self.assertIsInstance(segmenter, _ImageSegmenter) self.assertIsInstance(segmenter, _ImageSegmenter)
@parameterized.parameters((ModelFileType.FILE_NAME,), @parameterized.parameters(
(ModelFileType.FILE_CONTENT,)) (ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,)
)
def test_segment_succeeds_with_category_mask(self, model_file_type): def test_segment_succeeds_with_category_mask(self, model_file_type):
# Creates segmenter. # Creates segmenter.
if model_file_type is ModelFileType.FILE_NAME: if model_file_type is ModelFileType.FILE_NAME:
@ -127,22 +162,27 @@ class ImageSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK) base_options=base_options,
output_category_mask=True,
output_confidence_masks=False,
)
segmenter = _ImageSegmenter.create_from_options(options) segmenter = _ImageSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
category_masks = segmenter.segment(self.test_image) segmentation_result = segmenter.segment(self.test_image)
self.assertLen(category_masks, 1) category_mask = segmentation_result.category_mask
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten() result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct. # Check if data type of `category_mask` is correct.
self.assertEqual(result_pixels.dtype, np.uint8) self.assertEqual(result_pixels.dtype, np.uint8)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' (
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') 'Number of pixels in the candidate mask differing from that of the'
f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
# Closes the segmenter explicitly when the segmenter is not used in # Closes the segmenter explicitly when the segmenter is not used in
# a context. # a context.
@ -152,74 +192,46 @@ class ImageSegmenterTest(parameterized.TestCase):
# Creates segmenter. # Creates segmenter.
base_options = _BaseOptions(model_asset_path=self.model_path) base_options = _BaseOptions(model_asset_path=self.model_path)
# Run segmentation on the model in CATEGORY_MASK mode. # Load the cat image.
options = _ImageSegmenterOptions( test_image = _Image.create_from_file(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK) test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
segmenter = _ImageSegmenter.create_from_options(options) )
category_masks = segmenter.segment(self.test_image)
category_mask = category_masks[0].numpy_view()
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=base_options, base_options=base_options,
output_type=_OutputType.CONFIDENCE_MASK, output_category_mask=False,
activation=_Activation.SOFTMAX) output_confidence_masks=True,
segmenter = _ImageSegmenter.create_from_options(options) )
confidence_masks = segmenter.segment(self.test_image)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks, 21,
'Number of confidence masks must match with number of categories.')
# Gather the confidence masks in a single array `confidence_mask_array`.
confidence_mask_array = np.array(
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
# Check if data type of `confidence_masks` are correct.
self.assertEqual(confidence_mask_array.dtype, np.float32)
# Compute the category mask from the created confidence mask.
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
self.assertListEqual(
calculated_category_mask.tolist(), category_mask.tolist(),
'Confidence mask does not match with the category mask.')
# Closes the segmenter explicitly when the segmenter is not used in
# a context.
segmenter.close()
@parameterized.parameters((ModelFileType.FILE_NAME),
(ModelFileType.FILE_CONTENT))
def test_segment_in_context(self, model_file_type):
if model_file_type is ModelFileType.FILE_NAME:
base_options = _BaseOptions(model_asset_path=self.model_path)
elif model_file_type is ModelFileType.FILE_CONTENT:
with open(self.model_path, 'rb') as f:
model_contents = f.read()
base_options = _BaseOptions(model_asset_buffer=model_contents)
else:
# Should never happen
raise ValueError('model_file_type is invalid.')
options = _ImageSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
# Performs image segmentation on the input. segmentation_result = segmenter.segment(test_image)
category_masks = segmenter.segment(self.test_image) confidence_masks = segmentation_result.confidence_masks
self.assertLen(category_masks, 1)
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_float_mask(
f'Number of pixels in the candidate mask differing from that of the ' confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') )
)
def test_missing_result_callback(self): def test_missing_result_callback(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM) running_mode=_RUNNING_MODE.LIVE_STREAM,
with self.assertRaisesRegex(ValueError, )
r'result callback must be provided'): with self.assertRaisesRegex(
ValueError, r'result callback must be provided'
):
with _ImageSegmenter.create_from_options(options) as unused_segmenter: with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass pass
@ -228,130 +240,236 @@ class ImageSegmenterTest(parameterized.TestCase):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=running_mode, running_mode=running_mode,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
with self.assertRaisesRegex(ValueError, )
r'result callback should not be provided'): with self.assertRaisesRegex(
ValueError, r'result callback should not be provided'
):
with _ImageSegmenter.create_from_options(options) as unused_segmenter: with _ImageSegmenter.create_from_options(options) as unused_segmenter:
pass pass
def test_calling_segment_for_video_in_image_mode(self): def test_calling_segment_for_video_in_image_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the video mode'): ValueError, r'not initialized with the video mode'
):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_calling_segment_async_in_image_mode(self): def test_calling_segment_async_in_image_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.IMAGE) running_mode=_RUNNING_MODE.IMAGE,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the live stream mode'): ValueError, r'not initialized with the live stream mode'
):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_calling_segment_in_video_mode(self): def test_calling_segment_in_video_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the image mode'): ValueError, r'not initialized with the image mode'
):
segmenter.segment(self.test_image) segmenter.segment(self.test_image)
def test_calling_segment_async_in_video_mode(self): def test_calling_segment_async_in_video_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the live stream mode'): ValueError, r'not initialized with the live stream mode'
):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_segment_for_video_with_out_of_order_timestamp(self): def test_segment_for_video_with_out_of_order_timestamp(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO) running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
unused_result = segmenter.segment_for_video(self.test_image, 1) unused_result = segmenter.segment_for_video(self.test_image, 1)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'
):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_segment_for_video(self): def test_segment_for_video_in_category_mask_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK, output_category_mask=True,
running_mode=_RUNNING_MODE.VIDEO) output_confidence_masks=False,
running_mode=_RUNNING_MODE.VIDEO,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
category_masks = segmenter.segment_for_video(self.test_image, timestamp) segmentation_result = segmenter.segment_for_video(
self.assertLen(category_masks, 1) self.test_image, timestamp
)
category_mask = segmentation_result.category_mask
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_masks[0], self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' (
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') 'Number of pixels in the candidate mask differing from that of'
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
def test_segment_for_video_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.VIDEO,
output_category_mask=False,
output_confidence_masks=True,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmentation_result = segmenter.segment_for_video(test_image, timestamp)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
def test_calling_segment_in_live_stream_mode(self): def test_calling_segment_in_live_stream_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the image mode'): ValueError, r'not initialized with the image mode'
):
segmenter.segment(self.test_image) segmenter.segment(self.test_image)
def test_calling_segment_for_video_in_live_stream_mode(self): def test_calling_segment_for_video_in_live_stream_mode(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
with self.assertRaisesRegex(ValueError, with self.assertRaisesRegex(
r'not initialized with the video mode'): ValueError, r'not initialized with the video mode'
):
segmenter.segment_for_video(self.test_image, 0) segmenter.segment_for_video(self.test_image, 0)
def test_segment_async_calls_with_illegal_timestamp(self): def test_segment_async_calls_with_illegal_timestamp(self):
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=mock.MagicMock()) result_callback=mock.MagicMock(),
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
segmenter.segment_async(self.test_image, 100) segmenter.segment_async(self.test_image, 100)
with self.assertRaisesRegex( with self.assertRaisesRegex(
ValueError, r'Input timestamp must be monotonically increasing'): ValueError, r'Input timestamp must be monotonically increasing'
):
segmenter.segment_async(self.test_image, 0) segmenter.segment_async(self.test_image, 0)
def test_segment_async_calls(self): def test_segment_async_calls_in_category_mask_mode(self):
observed_timestamp_ms = -1 observed_timestamp_ms = -1
def check_result(result: List[image_module.Image], output_image: _Image, def check_result(
timestamp_ms: int): result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
):
# Get the output category mask. # Get the output category mask.
category_mask = result[0] category_mask = result.category_mask
self.assertEqual(output_image.width, self.test_image.width) self.assertEqual(output_image.width, self.test_image.width)
self.assertEqual(output_image.height, self.test_image.height) self.assertEqual(output_image.height, self.test_image.height)
self.assertEqual(output_image.width, self.test_seg_image.width) self.assertEqual(output_image.width, self.test_seg_image.width)
self.assertEqual(output_image.height, self.test_seg_image.height) self.assertEqual(output_image.height, self.test_seg_image.height)
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask(category_mask, self.test_seg_image), _similar_to_uint8_mask(category_mask, self.test_seg_image),
f'Number of pixels in the candidate mask differing from that of the ' (
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') 'Number of pixels in the candidate mask differing from that of'
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
),
)
self.assertLess(observed_timestamp_ms, timestamp_ms) self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions( options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path), base_options=_BaseOptions(model_asset_path=self.model_path),
output_type=_OutputType.CATEGORY_MASK, output_category_mask=True,
output_confidence_masks=False,
running_mode=_RUNNING_MODE.LIVE_STREAM, running_mode=_RUNNING_MODE.LIVE_STREAM,
result_callback=check_result) result_callback=check_result,
)
with _ImageSegmenter.create_from_options(options) as segmenter: with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30): for timestamp in range(0, 300, 30):
segmenter.segment_async(self.test_image, timestamp) segmenter.segment_async(self.test_image, timestamp)
def test_segment_async_calls_in_confidence_mask_mode(self):
# Load the cat image.
test_image = _Image.create_from_file(
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
)
# Loads ground truth segmentation file.
expected_mask = self._load_segmentation_mask(_CAT_MASK)
observed_timestamp_ms = -1
def check_result(
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
):
# Get the output category mask.
confidence_masks = result.confidence_masks
# Check if confidence mask shape is correct.
self.assertLen(
confidence_masks,
21,
'Number of confidence masks must match with number of categories.',
)
self.assertEqual(output_image.width, test_image.width)
self.assertEqual(output_image.height, test_image.height)
self.assertTrue(
_similar_to_float_mask(
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
)
)
self.assertLess(observed_timestamp_ms, timestamp_ms)
self.observed_timestamp_ms = timestamp_ms
options = _ImageSegmenterOptions(
base_options=_BaseOptions(model_asset_path=self.model_path),
running_mode=_RUNNING_MODE.LIVE_STREAM,
output_category_mask=False,
output_confidence_masks=True,
result_callback=check_result,
)
with _ImageSegmenter.create_from_options(options) as segmenter:
for timestamp in range(0, 300, 30):
segmenter.segment_async(test_image, timestamp)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()

View File

@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils
from mediapipe.tasks.python.vision import interactive_segmenter from mediapipe.tasks.python.vision import interactive_segmenter
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_Image = image_module.Image _Image = image_module.Image
_ImageFormat = image_frame.ImageFormat _ImageFormat = image_frame.ImageFormat
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint _NormalizedKeypoint = keypoint_module.NormalizedKeypoint
_Rect = rect.Rect _Rect = rect.Rect
_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
_RegionOfInterest = interactive_segmenter.RegionOfInterest _RegionOfInterest = interactive_segmenter.RegionOfInterest
@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase):
raise ValueError('model_file_type is invalid.') raise ValueError('model_file_type is invalid.')
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CATEGORY_MASK base_options=base_options,
output_category_mask=True,
output_confidence_masks=False,
) )
segmenter = _InteractiveSegmenter.create_from_options(options) segmenter = _InteractiveSegmenter.create_from_options(options)
# Performs image segmentation on the input. # Performs image segmentation on the input.
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
category_masks = segmenter.segment(self.test_image, roi) segmentation_result = segmenter.segment(self.test_image, roi)
self.assertLen(category_masks, 1) category_mask = segmentation_result.category_mask
category_mask = category_masks[0]
result_pixels = category_mask.numpy_view().flatten() result_pixels = category_mask.numpy_view().flatten()
# Check if data type of `category_mask` is correct. # Check if data type of `category_mask` is correct.
@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
self.assertTrue( self.assertTrue(
_similar_to_uint8_mask( _similar_to_uint8_mask(
category_masks[0], test_seg_image, similarity_threshold category_mask, test_seg_image, similarity_threshold
), ),
( (
'Number of pixels in the candidate mask differing from that of the' 'Number of pixels in the candidate mask differing from that of the'
@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
) )
with _InteractiveSegmenter.create_from_options(options) as segmenter: with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation # Perform segmentation
confidence_masks = segmenter.segment(self.test_image, roi) segmentation_result = segmenter.segment(self.test_image, roi)
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct. # Check if confidence mask shape is correct.
self.assertLen( self.assertLen(
@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
) )
with _InteractiveSegmenter.create_from_options(options) as segmenter: with _InteractiveSegmenter.create_from_options(options) as segmenter:
# Perform segmentation # Perform segmentation
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
confidence_masks = segmenter.segment( segmentation_result = segmenter.segment(
self.test_image, roi, image_processing_options self.test_image, roi, image_processing_options
) )
confidence_masks = segmentation_result.confidence_masks
# Check if confidence mask shape is correct. # Check if confidence mask shape is correct.
self.assertLen( self.assertLen(
@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase):
# Run segmentation on the model in CONFIDENCE_MASK mode. # Run segmentation on the model in CONFIDENCE_MASK mode.
options = _InteractiveSegmenterOptions( options = _InteractiveSegmenterOptions(
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK base_options=base_options,
output_category_mask=False,
output_confidence_masks=True,
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(

View File

@ -32,6 +32,7 @@ FaceDetectorResult = face_detector.FaceDetectorResult
FaceLandmarker = face_landmarker.FaceLandmarker FaceLandmarker = face_landmarker.FaceLandmarker
FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions
FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult
FaceLandmarksConnections = face_landmarker.FaceLandmarksConnections
FaceStylizer = face_stylizer.FaceStylizer FaceStylizer = face_stylizer.FaceStylizer
FaceStylizerOptions = face_stylizer.FaceStylizerOptions FaceStylizerOptions = face_stylizer.FaceStylizerOptions
GestureRecognizer = gesture_recognizer.GestureRecognizer GestureRecognizer = gesture_recognizer.GestureRecognizer

View File

@ -208,6 +208,11 @@ class BaseVisionTaskApi(object):
""" """
self._runner.close() self._runner.close()
def get_graph_config(self) -> calculator_pb2.CalculatorGraphConfig:
"""Returns the canonicalized CalculatorGraphConfig of the underlying graph.
"""
return self._runner.get_graph_config()
def __enter__(self): def __enter__(self):
"""Return `self` upon entering the runtime context.""" """Return `self` upon entering the runtime context."""
return self return self

File diff suppressed because it is too large Load Diff

View File

@ -176,16 +176,13 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
Only use this method when the FaceStylizer is created with the image Only use this method when the FaceStylizer is created with the image
running mode. running mode.
To ensure that the output image has reasonable quality, the stylized output
image size is the smaller of the model output size and the size of the
`region_of_interest` specified in `image_processing_options`.
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
image_processing_options: Options for image processing. image_processing_options: Options for image processing.
Returns: Returns:
The stylized image of the most visible face. None if no face is detected The stylized image of the most visible face. The stylized output image
size is the same as the model output size. None if no face is detected
on the input image. on the input image.
Raises: Raises:
@ -217,17 +214,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
milliseconds) along with the video frame. The input timestamps should be milliseconds) along with the video frame. The input timestamps should be
monotonically increasing for adjacent calls of this method. monotonically increasing for adjacent calls of this method.
To ensure that the output image has reasonable quality, the stylized output
image size is the smaller of the model output size and the size of the
`region_of_interest` specified in `image_processing_options`.
Args: Args:
image: MediaPipe Image. image: MediaPipe Image.
timestamp_ms: The timestamp of the input video frame in milliseconds. timestamp_ms: The timestamp of the input video frame in milliseconds.
image_processing_options: Options for image processing. image_processing_options: Options for image processing.
Returns: Returns:
The stylized image of the most visible face. None if no face is detected The stylized image of the most visible face. The stylized output image
size is the same as the model output size. None if no face is detected
on the input image. on the input image.
Raises: Raises:
@ -266,12 +260,9 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
images if needed. In other words, it's not guaranteed to have output per images if needed. In other words, it's not guaranteed to have output per
input image. input image.
To ensure that the stylized image has reasonable quality, the stylized
output image size is the smaller of the model output size and the size of
the `region_of_interest` specified in `image_processing_options`.
The `result_callback` provides: The `result_callback` provides:
- The stylized image of the most visible face. None if no face is detected - The stylized image of the most visible face. The stylized output image
size is the same as the model output size. None if no face is detected
on the input image. on the input image.
- The input image that the face stylizer runs on. - The input image that the face stylizer runs on.
- The input timestamp in milliseconds. - The input timestamp in milliseconds.

View File

@ -14,7 +14,6 @@
"""MediaPipe image segmenter task.""" """MediaPipe image segmenter task."""
import dataclasses import dataclasses
import enum
from typing import Callable, List, Mapping, Optional from typing import Callable, List, Mapping, Optional
from mediapipe.python import packet_creator from mediapipe.python import packet_creator
@ -31,7 +30,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
from mediapipe.tasks.python.vision.core import vision_task_running_mode from mediapipe.tasks.python.vision.core import vision_task_running_mode
ImageSegmenterResult = List[image_module.Image]
_NormalizedRect = rect.NormalizedRect _NormalizedRect = rect.NormalizedRect
_BaseOptions = base_options_module.BaseOptions _BaseOptions = base_options_module.BaseOptions
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
@ -42,8 +40,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_IMAGE_TAG = 'IMAGE' _IMAGE_TAG = 'IMAGE'
@ -53,6 +53,21 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
_MICRO_SECONDS_PER_MILLISECOND = 1000 _MICRO_SECONDS_PER_MILLISECOND = 1000
@dataclasses.dataclass
class ImageSegmenterResult:
"""Output result of ImageSegmenter.
confidence_masks: multiple masks of float image where, for each mask, each
pixel represents the prediction confidence, usually in the [0, 1] range.
category_mask: a category mask of uint8 image where each pixel represents the
class which the pixel in the original image was predicted to belong to.
"""
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass @dataclasses.dataclass
class ImageSegmenterOptions: class ImageSegmenterOptions:
"""Options for the image segmenter task. """Options for the image segmenter task.
@ -64,28 +79,17 @@ class ImageSegmenterOptions:
objects on single image inputs. 2) The video mode for segmenting objects objects on single image inputs. 2) The video mode for segmenting objects
on the decoded frames of a video. 3) The live stream mode for segmenting on the decoded frames of a video. 3) The live stream mode for segmenting
objects on a live stream of input data, such as from camera. objects on a live stream of input data, such as from camera.
output_type: The output mask type allows specifying the type of output_confidence_masks: Whether to output confidence masks.
post-processing to perform on the raw model results. output_category_mask: Whether to output category mask.
activation: Activation function to apply to input tensor.
result_callback: The user-defined result callback for processing live stream result_callback: The user-defined result callback for processing live stream
data. The result callback should only be specified when the running mode data. The result callback should only be specified when the running mode
is set to the live stream mode. is set to the live stream mode.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
class Activation(enum.Enum):
NONE = 0
SIGMOID = 1
SOFTMAX = 2
base_options: _BaseOptions base_options: _BaseOptions
running_mode: _RunningMode = _RunningMode.IMAGE running_mode: _RunningMode = _RunningMode.IMAGE
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_confidence_masks: bool = True
activation: Optional[Activation] = Activation.NONE output_category_mask: bool = False
result_callback: Optional[ result_callback: Optional[
Callable[[ImageSegmenterResult, image_module.Image, int], None] Callable[[ImageSegmenterResult, image_module.Image, int], None]
] = None ] = None
@ -97,9 +101,7 @@ class ImageSegmenterOptions:
base_options_proto.use_stream_mode = ( base_options_proto.use_stream_mode = (
False if self.running_mode == _RunningMode.IMAGE else True False if self.running_mode == _RunningMode.IMAGE else True
) )
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto()
output_type=self.output_type.value, activation=self.activation.value
)
return _ImageSegmenterGraphOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto, segmenter_options=segmenter_options_proto,
@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
def packets_callback(output_packets: Mapping[str, packet.Packet]): def packets_callback(output_packets: Mapping[str, packet.Packet]):
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
return return
segmentation_result = packet_getter.get_image_list(
output_packets[_SEGMENTATION_OUT_STREAM_NAME] segmentation_result = ImageSegmenterResult()
)
if options.output_confidence_masks:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if options.output_category_mask:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
options.result_callback( options.result_callback(
segmentation_result, segmentation_result,
image, image,
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
) )
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options, task_options=options,
) )
return cls( return cls(
@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = ImageSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result
def segment_for_video( def segment_for_video(
@ -285,9 +317,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = ImageSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result
def segment_async( def segment_async(

View File

@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
_TaskInfo = task_info_module.TaskInfo _TaskInfo = task_info_module.TaskInfo
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' _CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' _CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
_IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_IN_STREAM_NAME = 'image_in'
_IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_OUT_STREAM_NAME = 'image_out'
_ROI_STREAM_NAME = 'roi_in' _ROI_STREAM_NAME = 'roi_in'
@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = (
) )
@dataclasses.dataclass
class InteractiveSegmenterResult:
"""Output result of InteractiveSegmenter.
confidence_masks: multiple masks of float image where, for each mask, each
pixel represents the prediction confidence, usually in the [0, 1] range.
category_mask: a category mask of uint8 image where each pixel represents the
class which the pixel in the original image was predicted to belong to.
"""
confidence_masks: Optional[List[image_module.Image]] = None
category_mask: Optional[image_module.Image] = None
@dataclasses.dataclass @dataclasses.dataclass
class InteractiveSegmenterOptions: class InteractiveSegmenterOptions:
"""Options for the interactive segmenter task. """Options for the interactive segmenter task.
Attributes: Attributes:
base_options: Base options for the interactive segmenter task. base_options: Base options for the interactive segmenter task.
output_type: The output mask type allows specifying the type of output_confidence_masks: Whether to output confidence masks.
post-processing to perform on the raw model results. output_category_mask: Whether to output category mask.
""" """
class OutputType(enum.Enum):
UNSPECIFIED = 0
CATEGORY_MASK = 1
CONFIDENCE_MASK = 2
base_options: _BaseOptions base_options: _BaseOptions
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK output_confidence_masks: bool = True
output_category_mask: bool = False
@doc_controls.do_not_generate_docs @doc_controls.do_not_generate_docs
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
"""Generates an InteractiveSegmenterOptions protobuf object.""" """Generates an InteractiveSegmenterOptions protobuf object."""
base_options_proto = self.base_options.to_pb2() base_options_proto = self.base_options.to_pb2()
base_options_proto.use_stream_mode = False base_options_proto.use_stream_mode = False
segmenter_options_proto = _SegmenterOptionsProto( segmenter_options_proto = _SegmenterOptionsProto()
output_type=self.output_type.value
)
return _ImageSegmenterGraphOptionsProto( return _ImageSegmenterGraphOptionsProto(
base_options=base_options_proto, base_options=base_options_proto,
segmenter_options=segmenter_options_proto, segmenter_options=segmenter_options_proto,
@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
RuntimeError: If other types of error occurred. RuntimeError: If other types of error occurred.
""" """
output_streams = [
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
]
if options.output_confidence_masks:
output_streams.append(
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
)
if options.output_category_mask:
output_streams.append(
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
)
task_info = _TaskInfo( task_info = _TaskInfo(
task_graph=_TASK_GRAPH_NAME, task_graph=_TASK_GRAPH_NAME,
input_streams=[ input_streams=[
@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
':'.join([_ROI_TAG, _ROI_STREAM_NAME]), ':'.join([_ROI_TAG, _ROI_STREAM_NAME]),
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
], ],
output_streams=[ output_streams=output_streams,
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
],
task_options=options, task_options=options,
) )
return cls( return cls(
@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
image: image_module.Image, image: image_module.Image,
roi: RegionOfInterest, roi: RegionOfInterest,
image_processing_options: Optional[_ImageProcessingOptions] = None, image_processing_options: Optional[_ImageProcessingOptions] = None,
) -> List[image_module.Image]: ) -> InteractiveSegmenterResult:
"""Performs the actual segmentation task on the provided MediaPipe Image. """Performs the actual segmentation task on the provided MediaPipe Image.
The image can be of any size with format RGB. The image can be of any size with format RGB.
@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
normalized_rect.to_pb2() normalized_rect.to_pb2()
), ),
}) })
segmentation_result = packet_getter.get_image_list( segmentation_result = InteractiveSegmenterResult()
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
) if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
segmentation_result.confidence_masks = packet_getter.get_image_list(
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
)
if _CATEGORY_MASK_STREAM_NAME in output_packets:
segmentation_result.category_mask = packet_getter.get_image(
output_packets[_CATEGORY_MASK_STREAM_NAME]
)
return segmentation_result return segmentation_result

View File

@ -59,13 +59,12 @@ export function drawCategoryMask(
const isFloatArray = image instanceof Float32Array; const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) { for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex]; let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
// When we're given a confidence mask by accident, we just log and return.
// TODO: We should fix this.
if (!color) { if (!color) {
// TODO: We should fix this.
console.warn('No color for ', colorIndex); console.warn('No color for ', colorIndex);
return; color = COLOR_MAP[colorIndex % COLOR_MAP.length];
} }
rgbaArray[4 * i] = color[0]; rgbaArray[4 * i] = color[0];

View File

@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
*/ */
export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture; export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;
/**
* A callback that receives the computed masks from the segmentation tasks. The
* callback either receives a single element array with a category mask (as a
* `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`).
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type SegmentationMaskCallback =
(masks: SegmentationMask[], width: number, height: number) => void;
/** /**
* A callback that receives an `ImageData` object from a Vision task. The * A callback that receives an `ImageData` object from a Vision task. The
* lifetime of the underlying data is limited to the duration of the callback. * lifetime of the underlying data is limited to the duration of the callback.

View File

@ -19,7 +19,7 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
// tslint:disable:class-as-namespace Using for easier import by 3P users // tslint:disable:class-as-namespace Using for easier import by 3P users
/** /**
* A class containing the Pairs of landmark indices to be rendered with * A class containing the pairs of landmark indices to be rendered with
* connections. * connections.
*/ */
export class FaceLandmarksConnections { export class FaceLandmarksConnections {

View File

@ -129,10 +129,6 @@ export class FaceStylizer extends VisionTaskRunner {
* synchronously once the callback returns. Only use this method when the * synchronously once the callback returns. Only use this method when the
* FaceStylizer is created with the image running mode. * FaceStylizer is created with the image running mode.
* *
* The input image can be of any size. To ensure that the output image has
* reasonable quality, the stylized output image size is determined by the
* model output size.
*
* @param image An image to process. * @param image An image to process.
* @param callback The callback that is invoked with the stylized image. The * @param callback The callback that is invoked with the stylized image. The
* lifetime of the returned data is only guaranteed for the duration of the * lifetime of the returned data is only guaranteed for the duration of the
@ -153,11 +149,6 @@ export class FaceStylizer extends VisionTaskRunner {
* If both are specified, the crop around the region-of-interest is extracted * If both are specified, the crop around the region-of-interest is extracted
* first, then the specified rotation is applied to the crop. * first, then the specified rotation is applied to the crop.
* *
* The input image can be of any size. To ensure that the output image has
* reasonable quality, the stylized output image size is the smaller of the
* model output size and the size of the 'regionOfInterest' specified in
* 'imageProcessingOptions'.
*
* @param image An image to process. * @param image An image to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference. * to process the input image before running inference.
@ -192,9 +183,6 @@ export class FaceStylizer extends VisionTaskRunner {
* frame's timestamp (in milliseconds). The input timestamps must be * frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing. * monotonically increasing.
* *
* To ensure that the output image has reasonable quality, the stylized
* output image size is determined by the model output size.
*
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms. * @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the stylized image. The * @param callback The callback that is invoked with the stylized image. The
@ -221,10 +209,6 @@ export class FaceStylizer extends VisionTaskRunner {
* frame's timestamp (in milliseconds). The input timestamps must be * frame's timestamp (in milliseconds). The input timestamps must be
* monotonically increasing. * monotonically increasing.
* *
* To ensure that the output image has reasonable quality, the stylized
* output image size is the smaller of the model output size and the size of
* the 'regionOfInterest' specified in 'imageProcessingOptions'.
*
* @param videoFrame A video frame to process. * @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference. * to process the input image before running inference.
@ -278,8 +262,12 @@ export class FaceStylizer extends VisionTaskRunner {
this.graphRunner.attachImageListener( this.graphRunner.attachImageListener(
STYLIZED_IMAGE_STREAM, (image, timestamp) => { STYLIZED_IMAGE_STREAM, (image, timestamp) => {
const imageData = this.convertToImageData(image); if (image.data instanceof WebGLTexture) {
this.userCallback(imageData, image.width, image.height); this.userCallback(image.data, image.width, image.height);
} else {
const imageData = this.convertToImageData(image);
this.userCallback(imageData, image.width, image.height);
}
this.setLatestOutputTimestamp(timestamp); this.setLatestOutputTimestamp(timestamp);
}); });
this.graphRunner.attachEmptyPacketListener( this.graphRunner.attachEmptyPacketListener(

View File

@ -34,6 +34,7 @@ mediapipe_ts_library(
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",
"//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:image_processing_options",
"//mediapipe/tasks/web/vision/core:vision_task_runner", "//mediapipe/tasks/web/vision/core:vision_task_runner",
"//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections",
"//mediapipe/web/graph_runner:graph_runner_ts", "//mediapipe/web/graph_runner:graph_runner_ts",
], ],
) )

View File

@ -31,6 +31,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
@ -72,6 +73,12 @@ export class GestureRecognizer extends VisionTaskRunner {
private readonly handGestureRecognizerGraphOptions: private readonly handGestureRecognizerGraphOptions:
HandGestureRecognizerGraphOptions; HandGestureRecognizerGraphOptions;
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
*/
static HAND_CONNECTIONS = HAND_CONNECTIONS;
/** /**
* Initializes the Wasm runtime and creates a new gesture recognizer from the * Initializes the Wasm runtime and creates a new gesture recognizer from the
* provided options. * provided options.

View File

@ -16,6 +16,7 @@ mediapipe_ts_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":hand_landmarker_types", ":hand_landmarker_types",
":hand_landmarks_connections",
"//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_jspb_proto",
"//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto",
"//mediapipe/framework/formats:classification_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto",
@ -72,3 +73,9 @@ jasmine_node_test(
tags = ["nomsan"], tags = ["nomsan"],
deps = [":hand_landmarker_test_lib"], deps = [":hand_landmarker_test_lib"],
) )
mediapipe_ts_library(
name = "hand_landmarks_connections",
srcs = ["hand_landmarks_connections.ts"],
deps = ["//mediapipe/tasks/web/vision/core:types"],
)

View File

@ -27,6 +27,7 @@ import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/con
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
@ -63,6 +64,12 @@ export class HandLandmarker extends VisionTaskRunner {
HandLandmarksDetectorGraphOptions; HandLandmarksDetectorGraphOptions;
private readonly handDetectorGraphOptions: HandDetectorGraphOptions; private readonly handDetectorGraphOptions: HandDetectorGraphOptions;
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
*/
static HAND_CONNECTIONS = HAND_CONNECTIONS;
/** /**
* Initializes the Wasm runtime and creates a new `HandLandmarker` from the * Initializes the Wasm runtime and creates a new `HandLandmarker` from the
* provided options. * provided options.

View File

@ -0,0 +1,31 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {Connection} from '../../../../tasks/web/vision/core/types';
/**
* An array containing the pairs of hand landmark indices to be rendered with
* connections.
*/
export const HAND_CONNECTIONS: Connection[] = [
{start: 0, end: 1}, {start: 1, end: 2}, {start: 2, end: 3},
{start: 3, end: 4}, {start: 0, end: 5}, {start: 5, end: 6},
{start: 6, end: 7}, {start: 7, end: 8}, {start: 5, end: 9},
{start: 9, end: 10}, {start: 10, end: 11}, {start: 11, end: 12},
{start: 9, end: 13}, {start: 13, end: 14}, {start: 14, end: 15},
{start: 15, end: 16}, {start: 13, end: 17}, {start: 0, end: 17},
{start: 17, end: 18}, {start: 18, end: 19}, {start: 19, end: 20}
];

View File

@ -29,7 +29,10 @@ mediapipe_ts_library(
mediapipe_ts_declaration( mediapipe_ts_declaration(
name = "image_segmenter_types", name = "image_segmenter_types",
srcs = ["image_segmenter_options.d.ts"], srcs = [
"image_segmenter_options.d.ts",
"image_segmenter_result.d.ts",
],
deps = [ deps = [
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",

View File

@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {LabelMapItem} from '../../../../util/label_map_pb'; import {LabelMapItem} from '../../../../util/label_map_pb';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {ImageSegmenterOptions} from './image_segmenter_options'; import {ImageSegmenterOptions} from './image_segmenter_options';
import {ImageSegmenterResult} from './image_segmenter_result';
export * from './image_segmenter_options'; export * from './image_segmenter_options';
export {SegmentationMask, SegmentationMaskCallback}; export * from './image_segmenter_result';
export {SegmentationMask};
export {ImageSource}; // Used in the public API export {ImageSource}; // Used in the public API
const IMAGE_STREAM = 'image_in'; const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect'; const NORM_RECT_STREAM = 'norm_rect';
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const IMAGE_SEGMENTER_GRAPH = const IMAGE_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
'mediapipe.tasks.TensorsToSegmentationCalculator'; 'mediapipe.tasks.TensorsToSegmentationCalculator';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
// The OSS JS API does not support the builder pattern. // The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/**
* A callback that receives the computed masks from the image segmenter. The
* returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
/** Performs image segmentation on images. */ /** Performs image segmentation on images. */
export class ImageSegmenter extends VisionTaskRunner { export class ImageSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {}; private result: ImageSegmenterResult = {width: 0, height: 0};
private labels: string[] = []; private labels: string[] = [];
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private readonly options: ImageSegmenterGraphOptionsProto; private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto;
@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
this.options.setBaseOptions(new BaseOptionsProto()); this.options.setBaseOptions(new BaseOptionsProto());
} }
protected override get baseOptions(): BaseOptionsProto { protected override get baseOptions(): BaseOptionsProto {
return this.options.getBaseOptions()!; return this.options.getBaseOptions()!;
} }
@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
this.options.clearDisplayNamesLocale(); this.options.clearDisplayNamesLocale();
} }
if (options.outputType === 'CONFIDENCE_MASK') { if ('outputCategoryMask' in options) {
this.segmenterOptions.setOutputType( this.outputCategoryMask =
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
} else { }
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CATEGORY_MASK); if ('outputConfidenceMasks' in options) {
this.outputConfidenceMasks =
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
} }
return super.applyOptions(options); return super.applyOptions(options);
@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
* lifetime of the returned data is only guaranteed for the duration of the * lifetime of the returned data is only guaranteed for the duration of the
* callback. * callback.
*/ */
segment(image: ImageSource, callback: SegmentationMaskCallback): void; segment(image: ImageSource, callback: ImageSegmenterCallack): void;
/** /**
* Performs image segmentation on the provided single image and invokes the * Performs image segmentation on the provided single image and invokes the
* callback with the response. The method returns synchronously once the * callback with the response. The method returns synchronously once the
@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
*/ */
segment( segment(
image: ImageSource, imageProcessingOptions: ImageProcessingOptions, image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void; callback: ImageSegmenterCallack): void;
segment( segment(
image: ImageSource, image: ImageSource,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback, ImageSegmenterCallack,
callback?: SegmentationMaskCallback): void { callback?: ImageSegmenterCallack): void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
const userCallback =
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback!;
this.reset();
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {}; userCallback(this.result);
}
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, timestamp: number,
callback: ImageSegmenterCallack): void;
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: ImageSegmenterCallack): void;
segmentForVideo(
videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|ImageSegmenterCallack,
callback?: ImageSegmenterCallack): void {
const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions :
{};
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
const userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.reset();
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
userCallback(this.result);
} }
/** /**
@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
return this.labels; return this.labels;
} }
/** private reset(): void {
* Performs image segmentation on the provided video frame and invokes the this.result = {width: 0, height: 0};
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, timestamp: number,
callback: SegmentationMaskCallback): void;
/**
* Performs image segmentation on the provided video frame and invokes the
* callback with the response. The method returns synchronously once the
* callback returns. Only use this method when the ImageSegmenter is
* created with running mode `video`.
*
* @param videoFrame A video frame to process.
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
* to process the input image before running inference.
* @param timestamp The timestamp of the current frame, in ms.
* @param callback The callback that is invoked with the segmented masks. The
* lifetime of the returned data is only guaranteed for the duration of the
* callback.
*/
segmentForVideo(
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
timestamp: number, callback: SegmentationMaskCallback): void;
segmentForVideo(
videoFrame: ImageSource,
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
timestampOrCallback: number|SegmentationMaskCallback,
callback?: SegmentationMaskCallback): void {
const imageProcessingOptions =
typeof timestampOrImageProcessingOptions !== 'number' ?
timestampOrImageProcessingOptions :
{};
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
timestampOrImageProcessingOptions :
timestampOrCallback as number;
this.userCallback = typeof timestampOrCallback === 'function' ?
timestampOrCallback :
callback!;
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
this.userCallback = () => {};
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
const graphConfig = new CalculatorGraphConfig(); const graphConfig = new CalculatorGraphConfig();
graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(IMAGE_STREAM);
graphConfig.addInputStream(NORM_RECT_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM);
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH); segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM); segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
segmenterNode.addOutputStream(
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
segmenterNode.setOptions(calculatorOptions); segmenterNode.setOptions(calculatorOptions);
graphConfig.addNode(segmenterNode); graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageVectorListener( if (this.outputConfidenceMasks) {
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => { graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
if (masks.length === 0) { segmenterNode.addOutputStream(
this.userCallback([], 0, 0); 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
} else {
this.userCallback( this.graphRunner.attachImageVectorListener(
masks.map(m => m.data), masks[0].width, masks[0].height); CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
} this.result.confidenceMasks = masks.map(m => m.data);
this.setLatestOutputTimestamp(timestamp); if (masks.length >= 0) {
}); this.result.width = masks[0].width;
this.graphRunner.attachEmptyPacketListener( this.result.height = masks[0].height;
GROUPED_SEGMENTATIONS_STREAM, timestamp => { }
this.setLatestOutputTimestamp(timestamp);
}); this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
if (this.outputCategoryMask) {
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = mask.data;
this.result.width = mask.width;
this.result.height = mask.height;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
*/ */
displayNamesLocale?: string|undefined; displayNamesLocale?: string|undefined;
/** /** Whether to output confidence masks. Defaults to true. */
* The output type of segmentation results. outputConfidenceMasks?: boolean|undefined;
*
* The two supported modes are: /** Whether to output the category masks. Defaults to false. */
* - Category Mask: Gives a single output mask where each pixel represents outputCategoryMask?: boolean|undefined;
* the class which the pixel in the original image was
* predicted to belong to.
* - Confidence Mask: Gives a list of output masks (one for each class). For
* each mask, the pixel represents the prediction
* confidence, usually in the [0.0, 0.1] range.
*
* Defaults to `CATEGORY_MASK`.
*/
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
} }

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** The output result of ImageSegmenter. */
export declare interface ImageSegmenterResult {
/**
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
* pixel represents the prediction confidence, usually in the [0, 1] range.
*/
confidenceMasks?: Float32Array[]|WebGLTexture[];
/**
* A category mask as a Uint8ClampedArray or WebGLTexture where each
* pixel represents the class which the pixel in the original image was
* predicted to belong to.
*/
categoryMask?: Uint8ClampedArray|WebGLTexture;
/** The width of the masks. */
width: number;
/** The height of the masks. */
height: number;
}

View File

@ -18,7 +18,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
import {ImageSegmenter} from './image_segmenter'; import {ImageSegmenter} from './image_segmenter';
@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
graph: CalculatorGraphConfig|undefined; graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule; fakeWasmModule: SpyWasmModule;
imageVectorListener: categoryMaskListener:
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined; ((images: WasmImage[], timestamp: number) => void)|undefined;
constructor() { constructor() {
@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
this.fakeWasmModule = this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule; this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] = this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('category_mask');
this.categoryMaskListener = listener;
});
this.attachListenerSpies[1] =
spyOn(this.graphRunner, 'attachImageVectorListener') spyOn(this.graphRunner, 'attachImageVectorListener')
.and.callFake((stream, listener) => { .and.callFake((stream, listener) => {
expect(stream).toEqual('segmented_masks'); expect(stream).toEqual('confidence_masks');
this.imageVectorListener = listener; this.confidenceMasksListener = listener;
}); });
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
it('initializes graph', async () => { it('initializes graph', async () => {
verifyGraph(imageSegmenter); verifyGraph(imageSegmenter);
verifyListenersRegistered(imageSegmenter);
// Verify default options
expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
}); });
it('reloads graph when settings are changed', async () => { it('reloads graph when settings are changed', async () => {
await imageSegmenter.setOptions({displayNamesLocale: 'en'}); await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
verifyListenersRegistered(imageSegmenter);
await imageSegmenter.setOptions({displayNamesLocale: 'de'}); await imageSegmenter.setOptions({displayNamesLocale: 'de'});
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']); verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
verifyListenersRegistered(imageSegmenter);
}); });
it('can use custom models', async () => { it('can use custom models', async () => {
@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
}); });
it('merges options', async () => { it('merges options', async () => {
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); await imageSegmenter.setOptions(
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
await imageSegmenter.setOptions({displayNamesLocale: 'en'}); await imageSegmenter.setOptions({displayNamesLocale: 'en'});
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); verifyGraph(
imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
}); });
@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
defaultValue: unknown; defaultValue: unknown;
} }
const testCases: TestCase[] = [ const testCases: TestCase[] = [{
{ optionName: 'displayNamesLocale',
optionName: 'displayNamesLocale', fieldPath: ['displayNamesLocale'],
fieldPath: ['displayNamesLocale'], userValue: 'en',
userValue: 'en', graphValue: 'en',
graphValue: 'en', defaultValue: 'en'
defaultValue: 'en' }];
},
{
optionName: 'outputType',
fieldPath: ['segmenterOptions', 'outputType'],
userValue: 'CONFIDENCE_MASK',
graphValue: 2,
defaultValue: 1
},
];
for (const testCase of testCases) { for (const testCase of testCases) {
it(`can set ${testCase.optionName}`, async () => { it(`can set ${testCase.optionName}`, async () => {
@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
}).toThrowError('This task doesn\'t support region-of-interest.'); }).toThrowError('This task doesn\'t support region-of-interest.');
}); });
it('supports category masks', (done) => { it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]); const mask = new Uint8ClampedArray([1, 2, 3, 4]);
await imageSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false});
// Pass the test data to our listener // Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageSegmenter); expect(imageSegmenter.categoryMaskListener).toBeDefined();
imageSegmenter.imageVectorListener!( imageSegmenter.categoryMaskListener!
[ ({data: mask, width: 2, height: 2},
{data: mask, width: 2, height: 2}, /* timestamp= */ 1337);
],
/* timestamp= */ 1337);
}); });
// Invoke the image segmenter // Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); return new Promise<void>(resolve => {
expect(masks).toHaveSize(1); imageSegmenter.segment({} as HTMLImageElement, result => {
expect(masks[0]).toEqual(mask); expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(width).toEqual(2); expect(result.categoryMask).toEqual(mask);
expect(height).toEqual(2); expect(result.confidenceMasks).not.toBeDefined();
done(); expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve();
});
}); });
}); });
@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); await imageSegmenter.setOptions(
{outputCategoryMask: false, outputConfidenceMasks: true});
// Pass the test data to our listener // Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(imageSegmenter); expect(imageSegmenter.confidenceMasksListener).toBeDefined();
imageSegmenter.imageVectorListener!( imageSegmenter.confidenceMasksListener!(
[ [
{data: mask1, width: 2, height: 2}, {data: mask1, width: 2, height: 2},
{data: mask2, width: 2, height: 2}, {data: mask2, width: 2, height: 2},
@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
// Invoke the image segmenter // Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(masks).toHaveSize(2); expect(result.categoryMask).not.toBeDefined();
expect(masks[0]).toEqual(mask1); expect(result.confidenceMasks).toEqual([mask1, mask2]);
expect(masks[1]).toEqual(mask2); expect(result.width).toEqual(2);
expect(width).toEqual(2); expect(result.height).toEqual(2);
expect(height).toEqual(2); resolve();
});
});
});
it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1, 0]);
const confidenceMask1 = new Float32Array([0.0, 1.0]);
const confidenceMask2 = new Float32Array([1.0, 0.0]);
await imageSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: true});
// Pass the test data to our listener
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(imageSegmenter.categoryMaskListener).toBeDefined();
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
imageSegmenter.categoryMaskListener!
({data: categoryMask, width: 1, height: 1}, 1337);
imageSegmenter.confidenceMasksListener!(
[
{data: confidenceMask1, width: 1, height: 1},
{data: confidenceMask2, width: 1, height: 1},
],
1337);
});
return new Promise<void>(resolve => {
// Invoke the image segmenter
imageSegmenter.segment({} as HTMLImageElement, result => {
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
expect(result.categoryMask).toEqual(categoryMask);
expect(result.confidenceMasks).toEqual([
confidenceMask1, confidenceMask2
]);
expect(result.width).toEqual(1);
expect(result.height).toEqual(1);
resolve(); resolve();
}); });
}); });

View File

@ -30,7 +30,10 @@ mediapipe_ts_library(
mediapipe_ts_declaration( mediapipe_ts_declaration(
name = "interactive_segmenter_types", name = "interactive_segmenter_types",
srcs = ["interactive_segmenter_options.d.ts"], srcs = [
"interactive_segmenter_options.d.ts",
"interactive_segmenter_result.d.ts",
],
deps = [ deps = [
"//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core",
"//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/core:classifier_options",

View File

@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {Color as ColorProto} from '../../../../util/color_pb'; import {Color as ColorProto} from '../../../../util/color_pb';
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
// Placeholder for internal dependency on trusted resource url // Placeholder for internal dependency on trusted resource url
import {InteractiveSegmenterOptions} from './interactive_segmenter_options'; import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
import {InteractiveSegmenterResult} from './interactive_segmenter_result';
export * from './interactive_segmenter_options'; export * from './interactive_segmenter_options';
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest}; export * from './interactive_segmenter_result';
export {SegmentationMask, RegionOfInterest};
export {ImageSource}; export {ImageSource};
const IMAGE_IN_STREAM = 'image_in'; const IMAGE_IN_STREAM = 'image_in';
const NORM_RECT_IN_STREAM = 'norm_rect_in'; const NORM_RECT_IN_STREAM = 'norm_rect_in';
const ROI_IN_STREAM = 'roi_in'; const ROI_IN_STREAM = 'roi_in';
const IMAGE_OUT_STREAM = 'image_out'; const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
const CATEGORY_MASK_STREAM = 'category_mask';
const IMAGEA_SEGMENTER_GRAPH = const IMAGEA_SEGMENTER_GRAPH =
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
// The OSS JS API does not support the builder pattern. // The OSS JS API does not support the builder pattern.
// tslint:disable:jspb-use-builder-pattern // tslint:disable:jspb-use-builder-pattern
/**
* A callback that receives the computed masks from the interactive segmenter.
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type InteractiveSegmenterCallack =
(result: InteractiveSegmenterResult) => void;
/** /**
* Performs interactive segmentation on images. * Performs interactive segmentation on images.
* *
@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH =
* - batch is always 1 * - batch is always 1
*/ */
export class InteractiveSegmenter extends VisionTaskRunner { export class InteractiveSegmenter extends VisionTaskRunner {
private userCallback: SegmentationMaskCallback = () => {}; private result: InteractiveSegmenterResult = {width: 0, height: 0};
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
private readonly options: ImageSegmenterGraphOptionsProto; private readonly options: ImageSegmenterGraphOptionsProto;
private readonly segmenterOptions: SegmenterOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto;
@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
* @return A Promise that resolves when the settings have been applied. * @return A Promise that resolves when the settings have been applied.
*/ */
override setOptions(options: InteractiveSegmenterOptions): Promise<void> { override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
if (options.outputType === 'CONFIDENCE_MASK') { if ('outputCategoryMask' in options) {
this.segmenterOptions.setOutputType( this.outputCategoryMask =
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
} else { }
this.segmenterOptions.setOutputType(
SegmenterOptionsProto.OutputType.CATEGORY_MASK); if ('outputConfidenceMasks' in options) {
this.outputConfidenceMasks =
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
} }
return super.applyOptions(options); return super.applyOptions(options);
@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
*/ */
segment( segment(
image: ImageSource, roi: RegionOfInterest, image: ImageSource, roi: RegionOfInterest,
callback: SegmentationMaskCallback): void; callback: InteractiveSegmenterCallack): void;
/** /**
* Performs interactive segmentation on the provided single image and invokes * Performs interactive segmentation on the provided single image and invokes
* the callback with the response. The `roi` parameter is used to represent a * the callback with the response. The `roi` parameter is used to represent a
@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner {
segment( segment(
image: ImageSource, roi: RegionOfInterest, image: ImageSource, roi: RegionOfInterest,
imageProcessingOptions: ImageProcessingOptions, imageProcessingOptions: ImageProcessingOptions,
callback: SegmentationMaskCallback): void; callback: InteractiveSegmenterCallack): void;
segment( segment(
image: ImageSource, roi: RegionOfInterest, image: ImageSource, roi: RegionOfInterest,
imageProcessingOptionsOrCallback: ImageProcessingOptions| imageProcessingOptionsOrCallback: ImageProcessingOptions|
SegmentationMaskCallback, InteractiveSegmenterCallack,
callback?: SegmentationMaskCallback): void { callback?: InteractiveSegmenterCallack): void {
const imageProcessingOptions = const imageProcessingOptions =
typeof imageProcessingOptionsOrCallback !== 'function' ? typeof imageProcessingOptionsOrCallback !== 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
{}; {};
const userCallback =
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? typeof imageProcessingOptionsOrCallback === 'function' ?
imageProcessingOptionsOrCallback : imageProcessingOptionsOrCallback :
callback!; callback!;
this.reset();
this.processRenderData(roi, this.getSynctheticTimestamp()); this.processRenderData(roi, this.getSynctheticTimestamp());
this.processImageData(image, imageProcessingOptions); this.processImageData(image, imageProcessingOptions);
this.userCallback = () => {}; userCallback(this.result);
}
private reset(): void {
this.result = {width: 0, height: 0};
} }
/** Updates the MediaPipe graph configuration. */ /** Updates the MediaPipe graph configuration. */
@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner {
graphConfig.addInputStream(IMAGE_IN_STREAM); graphConfig.addInputStream(IMAGE_IN_STREAM);
graphConfig.addInputStream(ROI_IN_STREAM); graphConfig.addInputStream(ROI_IN_STREAM);
graphConfig.addInputStream(NORM_RECT_IN_STREAM); graphConfig.addInputStream(NORM_RECT_IN_STREAM);
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
const calculatorOptions = new CalculatorOptions(); const calculatorOptions = new CalculatorOptions();
calculatorOptions.setExtension( calculatorOptions.setExtension(
@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner {
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM); segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM); segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
segmenterNode.setOptions(calculatorOptions); segmenterNode.setOptions(calculatorOptions);
graphConfig.addNode(segmenterNode); graphConfig.addNode(segmenterNode);
this.graphRunner.attachImageVectorListener( if (this.outputConfidenceMasks) {
IMAGE_OUT_STREAM, (masks, timestamp) => { graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
if (masks.length === 0) { segmenterNode.addOutputStream(
this.userCallback([], 0, 0); 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
} else {
this.userCallback( this.graphRunner.attachImageVectorListener(
masks.map(m => m.data), masks[0].width, masks[0].height); CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
} this.result.confidenceMasks = masks.map(m => m.data);
this.setLatestOutputTimestamp(timestamp); if (masks.length >= 0) {
}); this.result.width = masks[0].width;
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => { this.result.height = masks[0].height;
this.setLatestOutputTimestamp(timestamp); }
});
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CONFIDENCE_MASKS_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
if (this.outputCategoryMask) {
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
this.graphRunner.attachImageListener(
CATEGORY_MASK_STREAM, (mask, timestamp) => {
this.result.categoryMask = mask.data;
this.result.width = mask.width;
this.result.height = mask.height;
this.setLatestOutputTimestamp(timestamp);
});
this.graphRunner.attachEmptyPacketListener(
CATEGORY_MASK_STREAM, timestamp => {
this.setLatestOutputTimestamp(timestamp);
});
}
const binaryGraph = graphConfig.serializeBinary(); const binaryGraph = graphConfig.serializeBinary();
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

View File

@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'
/** Options to configure the MediaPipe Interactive Segmenter Task */ /** Options to configure the MediaPipe Interactive Segmenter Task */
export interface InteractiveSegmenterOptions extends TaskRunnerOptions { export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
/** /** Whether to output confidence masks. Defaults to true. */
* The output type of segmentation results. outputConfidenceMasks?: boolean|undefined;
*
* The two supported modes are: /** Whether to output the category masks. Defaults to false. */
* - Category Mask: Gives a single output mask where each pixel represents outputCategoryMask?: boolean|undefined;
* the class which the pixel in the original image was
* predicted to belong to.
* - Confidence Mask: Gives a list of output masks (one for each class). For
* each mask, the pixel represents the prediction
* confidence, usually in the [0.0, 0.1] range.
*
* Defaults to `CATEGORY_MASK`.
*/
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
} }

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** The output result of InteractiveSegmenter. */
export declare interface InteractiveSegmenterResult {
/**
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
* pixel represents the prediction confidence, usually in the [0, 1] range.
*/
confidenceMasks?: Float32Array[]|WebGLTexture[];
/**
* A category mask as a Uint8ClampedArray or WebGLTexture where each
* pixel represents the class which the pixel in the original image was
* predicted to belong to.
*/
categoryMask?: Uint8ClampedArray|WebGLTexture;
/** The width of the masks. */
width: number;
/** The height of the masks. */
height: number;
}

View File

@ -18,7 +18,7 @@ import 'jasmine';
// Placeholder for internal dependency on encodeByteArray // Placeholder for internal dependency on encodeByteArray
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
graph: CalculatorGraphConfig|undefined; graph: CalculatorGraphConfig|undefined;
fakeWasmModule: SpyWasmModule; fakeWasmModule: SpyWasmModule;
imageVectorListener: categoryMaskListener:
((images: WasmImage, timestamp: number) => void)|undefined;
confidenceMasksListener:
((images: WasmImage[], timestamp: number) => void)|undefined; ((images: WasmImage[], timestamp: number) => void)|undefined;
lastRoi?: RenderDataProto; lastRoi?: RenderDataProto;
@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
this.fakeWasmModule = this.fakeWasmModule =
this.graphRunner.wasmModule as unknown as SpyWasmModule; this.graphRunner.wasmModule as unknown as SpyWasmModule;
this.attachListenerSpies[0] = this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
.and.callFake((stream, listener) => {
expect(stream).toEqual('category_mask');
this.categoryMaskListener = listener;
});
this.attachListenerSpies[1] =
spyOn(this.graphRunner, 'attachImageVectorListener') spyOn(this.graphRunner, 'attachImageVectorListener')
.and.callFake((stream, listener) => { .and.callFake((stream, listener) => {
expect(stream).toEqual('image_out'); expect(stream).toEqual('confidence_masks');
this.imageVectorListener = listener; this.confidenceMasksListener = listener;
}); });
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => {
it('initializes graph', async () => { it('initializes graph', async () => {
verifyGraph(interactiveSegmenter); verifyGraph(interactiveSegmenter);
verifyListenersRegistered(interactiveSegmenter);
// Verify default options
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
}); });
it('reloads graph when settings are changed', async () => { it('reloads graph when settings are changed', async () => {
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); await interactiveSegmenter.setOptions(
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]); {outputConfidenceMasks: true, outputCategoryMask: false});
verifyListenersRegistered(interactiveSegmenter); expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); await interactiveSegmenter.setOptions(
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]); {outputConfidenceMasks: false, outputCategoryMask: true});
verifyListenersRegistered(interactiveSegmenter); expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
}); });
it('can use custom models', async () => { it('can use custom models', async () => {
@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => {
]); ]);
}); });
describe('setOptions()', () => {
const fieldPath = ['segmenterOptions', 'outputType'];
it(`can set outputType`, async () => {
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
});
it(`can clear outputType`, async () => {
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
await interactiveSegmenter.setOptions({outputType: undefined});
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
});
});
it('doesn\'t support region of interest', () => { it('doesn\'t support region of interest', () => {
expect(() => { expect(() => {
interactiveSegmenter.segment( interactiveSegmenter.segment(
@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => {
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
}); });
it('supports category masks', (done) => { it('supports category mask', async () => {
const mask = new Uint8ClampedArray([1, 2, 3, 4]); const mask = new Uint8ClampedArray([1, 2, 3, 4]);
await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: false});
// Pass the test data to our listener // Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(interactiveSegmenter); expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
interactiveSegmenter.imageVectorListener!( interactiveSegmenter.categoryMaskListener!
[ ({data: mask, width: 2, height: 2},
{data: mask, width: 2, height: 2}, /* timestamp= */ 1337);
],
/* timestamp= */ 1337);
}); });
// Invoke the image segmenter // Invoke the image segmenter
interactiveSegmenter.segment( return new Promise<void>(resolve => {
{} as HTMLImageElement, ROI, (masks, width, height) => { interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(masks).toHaveSize(1); expect(result.categoryMask).toEqual(mask);
expect(masks[0]).toEqual(mask); expect(result.confidenceMasks).not.toBeDefined();
expect(width).toEqual(2); expect(result.width).toEqual(2);
expect(height).toEqual(2); expect(result.height).toEqual(2);
done(); resolve();
}); });
});
}); });
it('supports confidence masks', async () => { it('supports confidence masks', async () => {
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); await interactiveSegmenter.setOptions(
{outputCategoryMask: false, outputConfidenceMasks: true});
// Pass the test data to our listener // Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
verifyListenersRegistered(interactiveSegmenter); expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
interactiveSegmenter.imageVectorListener!( interactiveSegmenter.confidenceMasksListener!(
[ [
{data: mask1, width: 2, height: 2}, {data: mask1, width: 2, height: 2},
{data: mask2, width: 2, height: 2}, {data: mask2, width: 2, height: 2},
], ],
1337); 1337);
}); });
return new Promise<void>(resolve => {
// Invoke the image segmenter
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled();
expect(result.categoryMask).not.toBeDefined();
expect(result.confidenceMasks).toEqual([mask1, mask2]);
expect(result.width).toEqual(2);
expect(result.height).toEqual(2);
resolve();
});
});
});
it('supports combined category and confidence masks', async () => {
const categoryMask = new Uint8ClampedArray([1, 0]);
const confidenceMask1 = new Float32Array([0.0, 1.0]);
const confidenceMask2 = new Float32Array([1.0, 0.0]);
await interactiveSegmenter.setOptions(
{outputCategoryMask: true, outputConfidenceMasks: true});
// Pass the test data to our listener
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
interactiveSegmenter.categoryMaskListener!
({data: categoryMask, width: 1, height: 1}, 1337);
interactiveSegmenter.confidenceMasksListener!(
[
{data: confidenceMask1, width: 1, height: 1},
{data: confidenceMask2, width: 1, height: 1},
],
1337);
});
return new Promise<void>(resolve => { return new Promise<void>(resolve => {
// Invoke the image segmenter // Invoke the image segmenter
interactiveSegmenter.segment( interactiveSegmenter.segment(
{} as HTMLImageElement, ROI, (masks, width, height) => { {} as HTMLImageElement, ROI, result => {
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
.toHaveBeenCalled(); .toHaveBeenCalled();
expect(masks).toHaveSize(2); expect(result.categoryMask).toEqual(categoryMask);
expect(masks[0]).toEqual(mask1); expect(result.confidenceMasks).toEqual([
expect(masks[1]).toEqual(mask2); confidenceMask1, confidenceMask2
expect(width).toEqual(2); ]);
expect(height).toEqual(2); expect(result.width).toEqual(1);
expect(result.height).toEqual(1);
resolve(); resolve();
}); });
}); });

View File

@ -56,8 +56,8 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y,
VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0"; VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0";
} }
*x_px = static_cast<int32>(round(normalized_x * image_width)); *x_px = static_cast<int32_t>(round(normalized_x * image_width));
*y_px = static_cast<int32>(round(normalized_y * image_height)); *y_px = static_cast<int32_t>(round(normalized_y * image_height));
return true; return true;
} }

View File

@ -43,7 +43,7 @@ ABSL_FLAG(std::string, system_cpu_max_freq_file,
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr uint32 kBufferLength = 64; constexpr uint32_t kBufferLength = 64;
absl::StatusOr<std::string> GetFilePath(int cpu) { absl::StatusOr<std::string> GetFilePath(int cpu) {
if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) { if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
@ -54,7 +54,7 @@ absl::StatusOr<std::string> GetFilePath(int cpu) {
return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu); return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu);
} }
absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) { absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
auto path_or_status = GetFilePath(cpu); auto path_or_status = GetFilePath(cpu);
if (!path_or_status.ok()) { if (!path_or_status.ok()) {
return path_or_status.status(); return path_or_status.status();
@ -65,7 +65,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
char buffer[kBufferLength]; char buffer[kBufferLength];
file.getline(buffer, kBufferLength); file.getline(buffer, kBufferLength);
file.close(); file.close();
uint64 frequency; uint64_t frequency;
if (absl::SimpleAtoi(buffer, &frequency)) { if (absl::SimpleAtoi(buffer, &frequency)) {
return frequency; return frequency;
} else { } else {
@ -79,7 +79,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
} }
std::set<int> InferLowerOrHigherCoreIds(bool lower) { std::set<int> InferLowerOrHigherCoreIds(bool lower) {
std::vector<std::pair<int, uint64>> cpu_freq_pairs; std::vector<std::pair<int, uint64_t>> cpu_freq_pairs;
for (int cpu = 0; cpu < NumCPUCores(); ++cpu) { for (int cpu = 0; cpu < NumCPUCores(); ++cpu) {
auto freq_or_status = GetCpuMaxFrequency(cpu); auto freq_or_status = GetCpuMaxFrequency(cpu);
if (freq_or_status.ok()) { if (freq_or_status.ok()) {
@ -90,12 +90,12 @@ std::set<int> InferLowerOrHigherCoreIds(bool lower) {
return {}; return {};
} }
absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64>& left, absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64_t>& left,
const std::pair<int, uint64>& right) { const std::pair<int, uint64_t>& right) {
return (lower && left.second < right.second) || return (lower && left.second < right.second) ||
(!lower && left.second > right.second); (!lower && left.second > right.second);
}); });
uint64 edge_freq = cpu_freq_pairs[0].second; uint64_t edge_freq = cpu_freq_pairs[0].second;
std::set<int> inferred_cores; std::set<int> inferred_cores;
for (const auto& cpu_freq_pair : cpu_freq_pairs) { for (const auto& cpu_freq_pair : cpu_freq_pairs) {

View File

@ -89,12 +89,12 @@ void ImageFrameToYUVImage(const ImageFrame& image_frame, YUVImage* yuv_image) {
const int uv_stride = (uv_width + 15) & ~15; const int uv_stride = (uv_width + 15) & ~15;
const int y_size = y_stride * height; const int y_size = y_stride * height;
const int uv_size = uv_stride * uv_height; const int uv_size = uv_stride * uv_height;
uint8* data = uint8_t* data =
reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size * 2, 16)); reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size * 2, 16));
std::function<void()> deallocate = [data]() { aligned_free(data); }; std::function<void()> deallocate = [data]() { aligned_free(data); };
uint8* y = data; uint8_t* y = data;
uint8* u = y + y_size; uint8_t* u = y + y_size;
uint8* v = u + uv_size; uint8_t* v = u + uv_size;
yuv_image->Initialize(libyuv::FOURCC_I420, deallocate, // yuv_image->Initialize(libyuv::FOURCC_I420, deallocate, //
y, y_stride, // y, y_stride, //
u, uv_stride, // u, uv_stride, //
@ -123,10 +123,11 @@ void ImageFrameToYUVNV12Image(const ImageFrame& image_frame,
const int uv_stride = y_stride; const int uv_stride = y_stride;
const int uv_height = (height + 1) / 2; const int uv_height = (height + 1) / 2;
const int uv_size = uv_stride * uv_height; const int uv_size = uv_stride * uv_height;
uint8* data = reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size, 16)); uint8_t* data =
reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size, 16));
std::function<void()> deallocate = [data] { aligned_free(data); }; std::function<void()> deallocate = [data] { aligned_free(data); };
uint8* y = data; uint8_t* y = data;
uint8* uv = y + y_size; uint8_t* uv = y + y_size;
yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv, yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv,
uv_stride, nullptr, 0, width, height); uv_stride, nullptr, 0, width, height);
const int rv = libyuv::I420ToNV12( const int rv = libyuv::I420ToNV12(
@ -210,44 +211,44 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image,
} }
} }
void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b, // void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b, //
uint8* y, uint8* cb, uint8* cr) { uint8_t* y, uint8_t* cb, uint8_t* cr) {
// ITU-R BT.601 conversion from sRGB to YCbCr. // ITU-R BT.601 conversion from sRGB to YCbCr.
// FastIntRound is used rather than SafeRound since the possible // FastIntRound is used rather than SafeRound since the possible
// range of values is [16,235] for Y and [16,240] for Cb and Cr and we // range of values is [16,235] for Y and [16,240] for Cb and Cr and we
// don't care about the rounding direction for values exactly between // don't care about the rounding direction for values exactly between
// two integers. // two integers.
*y = static_cast<uint8>( *y = static_cast<uint8_t>(
mediapipe::MathUtil::FastIntRound(16.0 + // mediapipe::MathUtil::FastIntRound(16.0 + //
65.481 * r / 255.0 + // 65.481 * r / 255.0 + //
128.553 * g / 255.0 + // 128.553 * g / 255.0 + //
24.966 * b / 255.0)); 24.966 * b / 255.0));
*cb = static_cast<uint8>( *cb = static_cast<uint8_t>(
mediapipe::MathUtil::FastIntRound(128.0 + // mediapipe::MathUtil::FastIntRound(128.0 + //
-37.797 * r / 255.0 + // -37.797 * r / 255.0 + //
-74.203 * g / 255.0 + // -74.203 * g / 255.0 + //
112.0 * b / 255.0)); 112.0 * b / 255.0));
*cr = static_cast<uint8>( *cr = static_cast<uint8_t>(
mediapipe::MathUtil::FastIntRound(128.0 + // mediapipe::MathUtil::FastIntRound(128.0 + //
112.0 * r / 255.0 + // 112.0 * r / 255.0 + //
-93.786 * g / 255.0 + // -93.786 * g / 255.0 + //
-18.214 * b / 255.0)); -18.214 * b / 255.0));
} }
void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, // void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr, //
uint8* r, uint8* g, uint8* b) { uint8_t* r, uint8_t* g, uint8_t* b) {
// ITU-R BT.601 conversion from YCbCr to sRGB // ITU-R BT.601 conversion from YCbCr to sRGB
// Use SafeRound since many MPEG YCbCr values do not correspond directly // Use SafeRound since many MPEG YCbCr values do not correspond directly
// to an sRGB value. // to an sRGB value.
*r = mediapipe::MathUtil::SafeRound<uint8, double>( // *r = mediapipe::MathUtil::SafeRound<uint8_t, double>( //
255.0 / 219.0 * (y - 16.0) + // 255.0 / 219.0 * (y - 16.0) + //
255.0 / 112.0 * 0.701 * (cr - 128.0)); 255.0 / 112.0 * 0.701 * (cr - 128.0));
*g = mediapipe::MathUtil::SafeRound<uint8, double>( *g = mediapipe::MathUtil::SafeRound<uint8_t, double>(
255.0 / 219.0 * (y - 16.0) - // 255.0 / 219.0 * (y - 16.0) - //
255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) - // 255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) - //
255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0)); 255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0));
*b = mediapipe::MathUtil::SafeRound<uint8, double>( // *b = mediapipe::MathUtil::SafeRound<uint8_t, double>( //
255.0 / 219.0 * (y - 16.0) + // 255.0 / 219.0 * (y - 16.0) + //
255.0 / 112.0 * 0.886 * (cb - 128.0)); 255.0 / 112.0 * 0.886 * (cb - 128.0));
} }
@ -260,15 +261,15 @@ void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, //
cv::Mat GetSrgbToLinearRgb16Lut() { cv::Mat GetSrgbToLinearRgb16Lut() {
cv::Mat lut(1, 256, CV_16UC1); cv::Mat lut(1, 256, CV_16UC1);
uint16* ptr = lut.ptr<uint16>(); uint16_t* ptr = lut.ptr<uint16_t>();
constexpr double kUint8Max = 255.0; constexpr double kUint8Max = 255.0;
constexpr double kUint16Max = 65535.0; constexpr double kUint16Max = 65535.0;
for (int i = 0; i < 256; ++i) { for (int i = 0; i < 256; ++i) {
if (i < 0.04045 * kUint8Max) { if (i < 0.04045 * kUint8Max) {
ptr[i] = static_cast<uint16>( ptr[i] = static_cast<uint16_t>(
(static_cast<double>(i) / kUint8Max / 12.92) * kUint16Max + .5); (static_cast<double>(i) / kUint8Max / 12.92) * kUint16Max + .5);
} else { } else {
ptr[i] = static_cast<uint16>( ptr[i] = static_cast<uint16_t>(
pow((static_cast<double>(i) / kUint8Max + 0.055) / 1.055, 2.4) * pow((static_cast<double>(i) / kUint8Max + 0.055) / 1.055, 2.4) *
kUint16Max + kUint16Max +
.5); .5);
@ -279,15 +280,15 @@ cv::Mat GetSrgbToLinearRgb16Lut() {
cv::Mat GetLinearRgb16ToSrgbLut() { cv::Mat GetLinearRgb16ToSrgbLut() {
cv::Mat lut(1, 65536, CV_8UC1); cv::Mat lut(1, 65536, CV_8UC1);
uint8* ptr = lut.ptr<uint8>(); uint8_t* ptr = lut.ptr<uint8_t>();
constexpr double kUint8Max = 255.0; constexpr double kUint8Max = 255.0;
constexpr double kUint16Max = 65535.0; constexpr double kUint16Max = 65535.0;
for (int i = 0; i < 65536; ++i) { for (int i = 0; i < 65536; ++i) {
if (i < 0.0031308 * kUint16Max) { if (i < 0.0031308 * kUint16Max) {
ptr[i] = static_cast<uint8>( ptr[i] = static_cast<uint8_t>(
(static_cast<double>(i) / kUint16Max * 12.92) * kUint8Max + .5); (static_cast<double>(i) / kUint16Max * 12.92) * kUint8Max + .5);
} else { } else {
ptr[i] = static_cast<uint8>( ptr[i] = static_cast<uint8_t>(
(1.055 * pow(static_cast<double>(i) / kUint16Max, 1.0 / 2.4) - .055) * (1.055 * pow(static_cast<double>(i) / kUint16Max, 1.0 / 2.4) - .055) *
kUint8Max + kUint8Max +
.5); .5);
@ -306,13 +307,13 @@ void LinearRgb16ToSrgb(const cv::Mat& source, cv::Mat* destination) {
destination->create(source.size(), CV_8UC(source.channels())); destination->create(source.size(), CV_8UC(source.channels()));
static const cv::Mat kLut = GetLinearRgb16ToSrgbLut(); static const cv::Mat kLut = GetLinearRgb16ToSrgbLut();
const uint8* lookup_table_ptr = kLut.ptr<uint8>(); const uint8_t* lookup_table_ptr = kLut.ptr<uint8_t>();
const int num_channels = source.channels(); const int num_channels = source.channels();
for (int row = 0; row < source.rows; ++row) { for (int row = 0; row < source.rows; ++row) {
for (int col = 0; col < source.cols; ++col) { for (int col = 0; col < source.cols; ++col) {
for (int channel = 0; channel < num_channels; ++channel) { for (int channel = 0; channel < num_channels; ++channel) {
uint8* ptr = destination->ptr<uint8>(row); uint8_t* ptr = destination->ptr<uint8_t>(row);
const uint16* ptr16 = source.ptr<uint16>(row); const uint16_t* ptr16 = source.ptr<uint16_t>(row);
ptr[col * num_channels + channel] = ptr[col * num_channels + channel] =
lookup_table_ptr[ptr16[col * num_channels + channel]]; lookup_table_ptr[ptr16[col * num_channels + channel]];
} }

View File

@ -43,14 +43,14 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
Packet MakeImageFramePacket(cv::Mat input, int timestamp) { Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
ImageFrame input_image(GetImageFormat(input.channels()), input.cols, ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
input.rows, input.step, input.data, [](uint8*) {}); input.rows, input.step, input.data, [](uint8_t*) {});
return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0)); return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0));
} }
Packet MakeImagePacket(cv::Mat input, int timestamp) { Packet MakeImagePacket(cv::Mat input, int timestamp) {
mediapipe::Image input_image(std::make_shared<mediapipe::ImageFrame>( mediapipe::Image input_image(std::make_shared<mediapipe::ImageFrame>(
GetImageFormat(input.channels()), input.cols, input.rows, input.step, GetImageFormat(input.channels()), input.cols, input.rows, input.step,
input.data, [](uint8*) {})); input.data, [](uint8_t*) {}));
return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(0)); return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(0));
} }

View File

@ -25,7 +25,7 @@
namespace mediapipe { namespace mediapipe {
absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles( absl::StatusOr<proto_ns::Map<int64_t, LabelMapItem>> BuildLabelMapFromFiles(
absl::string_view labels_file_contents, absl::string_view labels_file_contents,
absl::string_view display_names_file) { absl::string_view display_names_file) {
if (labels_file_contents.empty()) { if (labels_file_contents.empty()) {
@ -68,7 +68,7 @@ absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
label_map_items[i].set_display_name(display_names[i]); label_map_items[i].set_display_name(display_names[i]);
} }
} }
proto_ns::Map<int64, LabelMapItem> label_map; proto_ns::Map<int64_t, LabelMapItem> label_map;
for (int i = 0; i < label_map_items.size(); ++i) { for (int i = 0; i < label_map_items.size(); ++i) {
label_map[i] = label_map_items[i]; label_map[i] = label_map_items[i];
} }

View File

@ -45,12 +45,16 @@ filegroup(
"include/flatbuffers/bfbs_generator.h", "include/flatbuffers/bfbs_generator.h",
"include/flatbuffers/buffer.h", "include/flatbuffers/buffer.h",
"include/flatbuffers/buffer_ref.h", "include/flatbuffers/buffer_ref.h",
"include/flatbuffers/code_generator.h",
"include/flatbuffers/code_generators.h", "include/flatbuffers/code_generators.h",
"include/flatbuffers/default_allocator.h", "include/flatbuffers/default_allocator.h",
"include/flatbuffers/detached_buffer.h", "include/flatbuffers/detached_buffer.h",
"include/flatbuffers/flatbuffer_builder.h", "include/flatbuffers/flatbuffer_builder.h",
"include/flatbuffers/flatbuffers.h", "include/flatbuffers/flatbuffers.h",
"include/flatbuffers/flatc.h",
"include/flatbuffers/flex_flat_util.h",
"include/flatbuffers/flexbuffers.h", "include/flatbuffers/flexbuffers.h",
"include/flatbuffers/grpc.h",
"include/flatbuffers/hash.h", "include/flatbuffers/hash.h",
"include/flatbuffers/idl.h", "include/flatbuffers/idl.h",
"include/flatbuffers/minireflect.h", "include/flatbuffers/minireflect.h",

View File

@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
def repo(): def repo():
third_party_http_archive( third_party_http_archive(
name = "flatbuffers", name = "flatbuffers",
strip_prefix = "flatbuffers-2.0.6", strip_prefix = "flatbuffers-23.1.21",
sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9", sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
urls = [ urls = [
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v2.0.6.tar.gz", "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
"https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz", "https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
], ],
build_file = "//third_party/flatbuffers:BUILD.bazel", build_file = "//third_party/flatbuffers:BUILD.bazel",
delete = ["build_defs.bzl", "BUILD.bazel"], delete = ["build_defs.bzl", "BUILD.bazel"],

View File

@ -12,72 +12,72 @@ def wasm_files():
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_js", name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646", sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac", sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a", sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7", sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_js", name = "com_google_mediapipe_wasm_text_wasm_internal_js",
sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2", sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65", sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871", sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd", sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_js", name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7", sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
sha256 = "03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3", sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0", sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"],
) )
http_file( http_file(
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
sha256 = "103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592", sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f",
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"], urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"],
) )