Merge branch 'google:master' into pose-landmarker-python
This commit is contained in:
commit
39742b6641
|
@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
|||
} else if (packet_options.has_string_value()) {
|
||||
packet.Set<std::string>();
|
||||
} else if (packet_options.has_uint64_value()) {
|
||||
packet.Set<uint64>();
|
||||
packet.Set<uint64_t>();
|
||||
} else if (packet_options.has_classification_list_value()) {
|
||||
packet.Set<ClassificationList>();
|
||||
} else if (packet_options.has_landmark_list_value()) {
|
||||
|
@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
|
|||
} else if (packet_options.has_string_value()) {
|
||||
packet.Set(MakePacket<std::string>(packet_options.string_value()));
|
||||
} else if (packet_options.has_uint64_value()) {
|
||||
packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
|
||||
packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
|
||||
} else if (packet_options.has_classification_list_value()) {
|
||||
packet.Set(MakePacket<ClassificationList>(
|
||||
packet_options.classification_list_value()));
|
||||
|
|
|
@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
// Use this when ALLOW/DISALLOW input is provided as a side packet.
|
||||
void RunTimeStep(int64 timestamp, bool stream_payload) {
|
||||
void RunTimeStep(int64_t timestamp, bool stream_payload) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
|
||||
MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
|
||||
}
|
||||
|
||||
// Use this when ALLOW/DISALLOW input is provided as an input stream.
|
||||
void RunTimeStep(int64 timestamp, const std::string& control_tag,
|
||||
void RunTimeStep(int64_t timestamp, const std::string& control_tag,
|
||||
bool control) {
|
||||
runner_->MutableInputs()->Get("", 0).packets.push_back(
|
||||
MakePacket<bool>(true).At(Timestamp(timestamp)));
|
||||
|
@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
|
|||
}
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
|
|||
}
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
|
|||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
|
|||
)");
|
||||
runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
|
|||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
|
|||
output_stream: "test_output"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
|
||||
|
@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", false);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "ALLOW", true);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "ALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", true);
|
||||
constexpr int64 kTimestampValue1 = 43;
|
||||
constexpr int64_t kTimestampValue1 = 43;
|
||||
RunTimeStep(kTimestampValue1, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue2 = 44;
|
||||
constexpr int64_t kTimestampValue2 = 44;
|
||||
RunTimeStep(kTimestampValue2, "DISALLOW", false);
|
||||
constexpr int64 kTimestampValue3 = 45;
|
||||
constexpr int64_t kTimestampValue3 = 45;
|
||||
RunTimeStep(kTimestampValue3, "DISALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "DISALLOW", false);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
|
|||
output_stream: "STATE_CHANGE:state_changed"
|
||||
)");
|
||||
|
||||
constexpr int64 kTimestampValue0 = 42;
|
||||
constexpr int64_t kTimestampValue0 = 42;
|
||||
RunTimeStep(kTimestampValue0, "ALLOW", true);
|
||||
|
||||
const std::vector<Packet>& output =
|
||||
|
|
|
@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
|
|||
using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
|
||||
REGISTER_CALCULATOR(StringToUintCalculator);
|
||||
|
||||
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
|
||||
using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
|
||||
REGISTER_CALCULATOR(StringToInt32Calculator);
|
||||
|
||||
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
|
||||
using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
|
||||
REGISTER_CALCULATOR(StringToUint32Calculator);
|
||||
|
||||
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
|
||||
using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
|
||||
REGISTER_CALCULATOR(StringToInt64Calculator);
|
||||
|
||||
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
|
||||
using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
|
||||
REGISTER_CALCULATOR(StringToUint64Calculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
|
|||
const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
|
||||
frame_ptr->Height(), frame_ptr->WidthStep(),
|
||||
const_cast<uint8_t*>(frame_ptr->PixelData()),
|
||||
[](uint8* data){});
|
||||
[](uint8_t* data){});
|
||||
ASSIGN_OR_RETURN(auto result,
|
||||
runner->Run(image_frame, matrix, size, border_mode));
|
||||
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
|
||||
|
|
|
@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
|
|||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
|
||||
// Record the most recent first kept timestamp on any stream.
|
||||
for (const auto& stream : input_stream_managers_) {
|
||||
int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
|
||||
? target_queue_size_
|
||||
: trigger_queue_size_ - 1;
|
||||
int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
|
||||
? target_queue_size_
|
||||
: trigger_queue_size_ - 1;
|
||||
if (stream->QueueSize() > queue_size) {
|
||||
kept_timestamp_ = std::max(
|
||||
kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
|
||||
|
@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
|
|||
}
|
||||
|
||||
private:
|
||||
int32 trigger_queue_size_;
|
||||
int32 target_queue_size_;
|
||||
int32_t trigger_queue_size_;
|
||||
int32_t target_queue_size_;
|
||||
bool fixed_min_size_;
|
||||
// Indicates that GetNodeReadiness has returned kReadyForProcess once, and
|
||||
// the corresponding call to FillInputSet has not yet completed.
|
||||
|
|
|
@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
|
|||
// TODO: Investigate this option in more detail, esp. on Safari.
|
||||
attrs.preserveDrawingBuffer = 0;
|
||||
|
||||
// Since the Emscripten canvas target finding function is visible from here,
|
||||
// we hijack findCanvasEventTarget directly for enforcing old Module.canvas
|
||||
// behavior if the user desires, falling back to the new DOM element CSS
|
||||
// selector behavior next if that is specified, and finally just allowing the
|
||||
// lookup to proceed on a null target.
|
||||
// TODO: Ensure this works with all options (in particular,
|
||||
// multithreading options, like the special-case combination of USE_PTHREADS
|
||||
// and OFFSCREEN_FRAMEBUFFER)
|
||||
// clang-format off
|
||||
EM_ASM(
|
||||
let init_once = true;
|
||||
if (init_once) {
|
||||
const cachedFindCanvasEventTarget = findCanvasEventTarget;
|
||||
|
||||
if (typeof cachedFindCanvasEventTarget !== 'function') {
|
||||
if (typeof console !== 'undefined') {
|
||||
console.error('Expected Emscripten global function '
|
||||
+ '"findCanvasEventTarget" not found. WebGL context creation '
|
||||
+ 'may fail.');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
findCanvasEventTarget = function(target) {
|
||||
if (target == 0) {
|
||||
if (Module && Module.canvas) {
|
||||
return Module.canvas;
|
||||
} else if (Module && Module.canvasCssSelector) {
|
||||
return cachedFindCanvasEventTarget(Module.canvasCssSelector);
|
||||
}
|
||||
if (typeof console !== 'undefined') {
|
||||
console.warn('Module properties canvas and canvasCssSelector not ' +
|
||||
'found during WebGL context creation.');
|
||||
}
|
||||
}
|
||||
// We still go through with the find attempt, although for most use
|
||||
// cases it will not succeed, just in case the user does want to fall-
|
||||
// back.
|
||||
return cachedFindCanvasEventTarget(target);
|
||||
}; // NOLINT: Necessary semicolon.
|
||||
init_once = false;
|
||||
}
|
||||
);
|
||||
// clang-format on
|
||||
|
||||
// Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
|
||||
// looks for our #canvas target in Module.canvas, where we expect it to be.
|
||||
// -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
|
||||
// event target behavior, but it was never supposed to be tapping into our
|
||||
// canvas anyways. See b/278155946 for more background.
|
||||
EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
|
||||
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
|
||||
emscripten_webgl_create_context(nullptr, &attrs);
|
||||
emscripten_webgl_create_context("#canvas", &attrs);
|
||||
|
||||
// Check for failure
|
||||
if (context_handle <= 0) {
|
||||
|
|
|
@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
|
|||
int actual_ws = image_frame.WidthStep();
|
||||
int alignment = 0;
|
||||
std::unique_ptr<ImageFrame> temp;
|
||||
const uint8* data = image_frame.PixelData();
|
||||
const uint8_t* data = image_frame.PixelData();
|
||||
|
||||
// Let's see if the pixel data is tightly aligned to one of the alignments
|
||||
// supported by OpenGL, preferring 4 if possible since it's the default.
|
||||
|
|
|
@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
|||
GpuBufferFormat format) {
|
||||
libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
|
||||
int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
|
||||
auto y_data = std::make_unique<uint8[]>(y_stride * height);
|
||||
auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
|
||||
switch (fourcc) {
|
||||
case libyuv::FOURCC_NV12:
|
||||
case libyuv::FOURCC_NV21: {
|
||||
|
@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
|||
int uv_width = 2 * std::ceil(0.5f * width);
|
||||
int uv_height = std::ceil(0.5f * height);
|
||||
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
||||
auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
||||
auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||
yuv_image_ = std::make_shared<YUVImage>(
|
||||
fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
|
||||
nullptr, 0, width, height);
|
||||
|
@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
|
|||
int uv_width = std::ceil(0.5f * width);
|
||||
int uv_height = std::ceil(0.5f * height);
|
||||
int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
|
||||
auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
||||
auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
|
||||
auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||
auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
|
||||
yuv_image_ = std::make_shared<YUVImage>(
|
||||
fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
|
||||
std::move(v_data), uv_stride, width, height);
|
||||
|
|
|
@ -16,6 +16,7 @@ import csv
|
|||
import filecmp
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock as unittest_mock
|
||||
|
||||
import tensorflow as tf
|
||||
|
@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
|
|||
from mediapipe.tasks.python.test import test_utils
|
||||
|
||||
|
||||
@unittest.skip('b/275624089')
|
||||
class TextClassifierTest(tf.test.TestCase):
|
||||
|
||||
_AVERAGE_WORD_EMBEDDING_JSON_FILE = (
|
||||
|
|
|
@ -175,11 +175,7 @@ py_test(
|
|||
data = [":testdata"],
|
||||
tags = ["requires-net:external"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":hyperparameters",
|
||||
":model_spec",
|
||||
":object_detector",
|
||||
":object_detector_options",
|
||||
":object_detector_import",
|
||||
"//mediapipe/tasks/python/test:test_utils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
|
|||
from absl.testing import parameterized
|
||||
import tensorflow as tf
|
||||
|
||||
from mediapipe.model_maker.python.vision.object_detector import dataset
|
||||
from mediapipe.model_maker.python.vision.object_detector import hyperparameters
|
||||
from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector
|
||||
from mediapipe.model_maker.python.vision.object_detector import object_detector_options
|
||||
from mediapipe.model_maker.python.vision import object_detector
|
||||
from mediapipe.tasks.python.test import test_utils as task_test_utils
|
||||
|
||||
|
||||
|
@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
super().setUp()
|
||||
dataset_folder = task_test_utils.get_test_data_path('coco_data')
|
||||
cache_dir = self.create_tempdir()
|
||||
self.data = dataset.Dataset.from_coco_folder(
|
||||
self.data = object_detector.Dataset.from_coco_folder(
|
||||
dataset_folder, cache_dir=cache_dir
|
||||
)
|
||||
# Mock tempfile.gettempdir() to be unique for each test to avoid race
|
||||
|
@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.addCleanup(mock_gettempdir.stop)
|
||||
|
||||
def test_object_detector(self):
|
||||
hparams = hyperparameters.HParams(
|
||||
hparams = object_detector.HParams(
|
||||
epochs=1,
|
||||
batch_size=2,
|
||||
learning_rate=0.9,
|
||||
shuffle=False,
|
||||
export_dir=self.create_tempdir(),
|
||||
)
|
||||
options = object_detector_options.ObjectDetectorOptions(
|
||||
supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
|
||||
options = object_detector.ObjectDetectorOptions(
|
||||
supported_model=object_detector.SupportedModels.MOBILENET_V2,
|
||||
hparams=hparams,
|
||||
)
|
||||
# Test `create``
|
||||
model = object_detector.ObjectDetector.create(
|
||||
|
@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
|
|||
self.assertGreater(os.path.getsize(output_metadata_file), 0)
|
||||
|
||||
# Test `quantization_aware_training`
|
||||
qat_hparams = hyperparameters.QATHParams(
|
||||
qat_hparams = object_detector.QATHParams(
|
||||
learning_rate=0.9,
|
||||
batch_size=2,
|
||||
epochs=1,
|
||||
|
|
|
@ -24,8 +24,8 @@ namespace mediapipe {
|
|||
|
||||
void FrameAnnotationTracker::AddDetectionResult(
|
||||
const FrameAnnotation& frame_annotation) {
|
||||
const int64 time_us =
|
||||
static_cast<int64>(std::round(frame_annotation.timestamp()));
|
||||
const int64_t time_us =
|
||||
static_cast<int64_t>(std::round(frame_annotation.timestamp()));
|
||||
for (const auto& object_annotation : frame_annotation.annotations()) {
|
||||
detected_objects_[time_us + object_annotation.object_id()] =
|
||||
object_annotation;
|
||||
|
@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
|
|||
absl::flat_hash_set<int>* cancel_object_ids) {
|
||||
CHECK(cancel_object_ids != nullptr);
|
||||
FrameAnnotation frame_annotation;
|
||||
std::vector<int64> keys_to_be_deleted;
|
||||
std::vector<int64_t> keys_to_be_deleted;
|
||||
for (const auto& detected_obj : detected_objects_) {
|
||||
const int object_id = detected_obj.second.object_id();
|
||||
if (cancel_object_ids->contains(object_id)) {
|
||||
|
|
|
@ -78,6 +78,7 @@ cc_library(
|
|||
hdrs = ["mediapipe_builtin_op_resolver.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
|
||||
"//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
|
||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
|
||||
"//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
|
||||
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
|
||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
|
||||
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
|
||||
|
@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
|
|||
AddCustom("KmeansEmbeddingLookup",
|
||||
mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
|
||||
// For the UniversalSentenceEncoder model.
|
||||
AddCustom("TFSentencepieceTokenizeOp",
|
||||
mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
|
||||
AddCustom("RaggedTensorToTensor",
|
||||
mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
|
||||
}
|
||||
|
|
|
@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
filegroup(
|
||||
name = "testdata",
|
||||
srcs = glob([
|
||||
"testdata/**",
|
||||
]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "config_fbs",
|
||||
srcs = ["config.fbs"],
|
||||
|
@ -80,3 +87,86 @@ cc_test(
|
|||
"//mediapipe/framework/port:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sentencepiece_constants",
|
||||
hdrs = ["sentencepiece_constants.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "model_converter",
|
||||
srcs = [
|
||||
"model_converter.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"model_converter.h",
|
||||
],
|
||||
deps = [
|
||||
":config",
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":sentencepiece_constants",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "optimized_encoder",
|
||||
srcs = [
|
||||
"optimized_encoder.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"optimized_encoder.h",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie",
|
||||
":encoder_config",
|
||||
":utils",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sentencepiece_tokenizer_tflite",
|
||||
srcs = ["sentencepiece_tokenizer_tflite.cc"],
|
||||
hdrs = ["sentencepiece_tokenizer_tflite.h"],
|
||||
visibility = [
|
||||
"//visibility:public",
|
||||
],
|
||||
deps =
|
||||
[
|
||||
":optimized_encoder",
|
||||
"@flatbuffers",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite:string_util",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
|
||||
"@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "optimized_encoder_test",
|
||||
srcs = [
|
||||
"optimized_encoder_test.cc",
|
||||
],
|
||||
data = [
|
||||
":testdata",
|
||||
],
|
||||
deps = [
|
||||
":double_array_trie_builder",
|
||||
":encoder_config",
|
||||
":model_converter",
|
||||
":optimized_encoder",
|
||||
"//mediapipe/framework/deps:file_path",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_sentencepiece//src:sentencepiece_cc_proto",
|
||||
"@com_google_sentencepiece//src:sentencepiece_processor",
|
||||
"@org_tensorflow//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
|
||||
#include "src/sentencepiece_model.pb.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
|
||||
DecodePrecompiledCharsmap(
|
||||
const ::sentencepiece::NormalizerSpec& normalizer_spec) {
|
||||
// This function "undoes" encoding done by
|
||||
// sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
|
||||
const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
|
||||
const uint32_t trie_size =
|
||||
*reinterpret_cast<const uint32_t*>(precompiled_map);
|
||||
const uint32_t* trie_ptr =
|
||||
reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
|
||||
const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
|
||||
precompiled_map + sizeof(uint32_t) + trie_size);
|
||||
const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
|
||||
sizeof(uint32_t) - trie_size;
|
||||
return std::make_tuple(
|
||||
std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
|
||||
std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||
const std::string& model_config_str, int encoding_offset) {
|
||||
::sentencepiece::ModelProto model_config;
|
||||
if (!model_config.ParseFromString(model_config_str)) {
|
||||
return absl::InvalidArgumentError(
|
||||
"Invalid configuration, can't parse SentencePiece model config " +
|
||||
model_config.InitializationErrorString());
|
||||
}
|
||||
// Convert sentencepieces.
|
||||
std::vector<std::string> pieces;
|
||||
pieces.reserve(model_config.pieces_size());
|
||||
std::vector<float> scores;
|
||||
scores.reserve(model_config.pieces_size());
|
||||
std::vector<int> ids;
|
||||
ids.reserve(model_config.pieces_size());
|
||||
float min_score = 0.0;
|
||||
int index = 0;
|
||||
for (const auto& piece : model_config.pieces()) {
|
||||
switch (piece.type()) {
|
||||
case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
|
||||
case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
|
||||
pieces.push_back(piece.piece());
|
||||
ids.push_back(index);
|
||||
if (piece.score() < min_score) {
|
||||
min_score = piece.score();
|
||||
}
|
||||
break;
|
||||
case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
|
||||
case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
|
||||
// Ignore unknown and control codes.
|
||||
break;
|
||||
default:
|
||||
return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
|
||||
piece.piece());
|
||||
}
|
||||
scores.push_back(piece.score());
|
||||
++index;
|
||||
}
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
|
||||
const auto pieces_score_vector = builder.CreateVector(scores);
|
||||
TrieBuilder pieces_trie_builder(builder);
|
||||
pieces_trie_builder.add_nodes(pieces_trie_vector);
|
||||
const auto pieces_trie_fbs = pieces_trie_builder.Finish();
|
||||
|
||||
// Converting normalization.
|
||||
const auto normalization =
|
||||
DecodePrecompiledCharsmap(model_config.normalizer_spec());
|
||||
const auto normalization_trie = std::get<0>(normalization);
|
||||
const auto normalization_strings = std::get<1>(normalization);
|
||||
const auto normalization_trie_vector =
|
||||
builder.CreateVector(normalization_trie);
|
||||
TrieBuilder normalization_trie_builder(builder);
|
||||
normalization_trie_builder.add_nodes(normalization_trie_vector);
|
||||
const auto normalization_trie_fbs = normalization_trie_builder.Finish();
|
||||
const auto normalization_strings_fbs =
|
||||
builder.CreateVector(normalization_strings);
|
||||
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
|
||||
ecb.add_start_code(model_config.trainer_spec().bos_id());
|
||||
ecb.add_end_code(model_config.trainer_spec().eos_id());
|
||||
ecb.add_unknown_code(model_config.trainer_spec().unk_id());
|
||||
ecb.add_unknown_penalty(min_score - kUnkPenalty);
|
||||
ecb.add_encoding_offset(encoding_offset);
|
||||
ecb.add_pieces(pieces_trie_fbs);
|
||||
ecb.add_pieces_scores(pieces_score_vector);
|
||||
ecb.add_remove_extra_whitespaces(
|
||||
model_config.normalizer_spec().remove_extra_whitespaces());
|
||||
ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
|
||||
ecb.add_escape_whitespaces(
|
||||
model_config.normalizer_spec().escape_whitespaces());
|
||||
ecb.add_normalized_prefixes(normalization_trie_fbs);
|
||||
ecb.add_normalized_replacements(normalization_strings_fbs);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
|
||||
builder.GetSize());
|
||||
}
|
||||
|
||||
std::string ConvertSentencepieceModel(const std::string& model_string) {
|
||||
const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
|
||||
assert(result.status().ok());
|
||||
return result.value();
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,33 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// Converts Sentencepiece configuration to flatbuffer format.
|
||||
// encoding_offset is used by some encoders that combine different encodings.
|
||||
absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
|
||||
const std::string& model_config_str, int encoding_offset = 0);
|
||||
std::string ConvertSentencepieceModel(const std::string& model_string);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
|
|
@ -0,0 +1,236 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
namespace {
|
||||
|
||||
const char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
template <typename processing_callback>
|
||||
std::tuple<std::string, std::vector<int>> process_string(
|
||||
const std::string& input, const std::vector<int>& offsets,
|
||||
const processing_callback& pc) {
|
||||
std::string result_string;
|
||||
result_string.reserve(input.size());
|
||||
std::vector<int> result_offsets;
|
||||
result_offsets.reserve(offsets.size());
|
||||
for (int i = 0, j = 0; i < input.size();) {
|
||||
auto result = pc(input.data() + i, input.size() - i);
|
||||
auto consumed = std::get<0>(result);
|
||||
auto new_string = std::get<1>(result);
|
||||
if (consumed == 0) {
|
||||
// Skip the current byte and move forward.
|
||||
result_string.push_back(input[i]);
|
||||
result_offsets.push_back(offsets[j]);
|
||||
i++;
|
||||
j++;
|
||||
continue;
|
||||
}
|
||||
result_string.append(new_string.data(), new_string.length());
|
||||
for (int i = 0; i < new_string.length(); ++i) {
|
||||
result_offsets.push_back(offsets[j]);
|
||||
}
|
||||
j += consumed;
|
||||
i += consumed;
|
||||
}
|
||||
return std::make_tuple(result_string, result_offsets);
|
||||
}
|
||||
|
||||
inline char is_whitespace(char c) {
|
||||
return c == ' ' || c == '\t' || c == '\r' || c == '\n';
|
||||
}
|
||||
|
||||
std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
|
||||
int len) {
|
||||
if (len == 0 || !is_whitespace(*data)) {
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
int num_consumed = 1;
|
||||
for (; num_consumed < len && is_whitespace(data[num_consumed]);
|
||||
++num_consumed) {
|
||||
}
|
||||
return num_consumed > 1
|
||||
? std::make_tuple(num_consumed, utils::string_view(" ", 1))
|
||||
: std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
|
||||
std::tuple<int, utils::string_view> find_replacement(
|
||||
const char* data, int len, const DoubleArrayTrie& dat,
|
||||
const flatbuffers::Vector<int8_t>& replacements) {
|
||||
const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
|
||||
if (!max_match.empty()) {
|
||||
// Because flatbuffer byte is signed char which is not the same as char,
|
||||
// there is the reinterpret_cast here.
|
||||
const char* replaced_string_ptr =
|
||||
reinterpret_cast<const char*>(replacements.data() + max_match.id);
|
||||
return std::make_tuple(max_match.match_length,
|
||||
utils::string_view(replaced_string_ptr));
|
||||
}
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||
const std::string& in_string, const EncoderConfig& config) {
|
||||
std::vector<int> output_offsets;
|
||||
std::string result = in_string;
|
||||
output_offsets.reserve(in_string.length());
|
||||
for (int i = 0; i < in_string.length(); ++i) {
|
||||
output_offsets.push_back(i);
|
||||
}
|
||||
if (in_string.empty()) {
|
||||
return std::make_tuple(result, output_offsets);
|
||||
}
|
||||
if (config.add_dummy_prefix()) {
|
||||
result.insert(result.begin(), ' ');
|
||||
output_offsets.insert(output_offsets.begin(), 0);
|
||||
}
|
||||
// Greedely replace normalized_prefixes with normalized_replacements
|
||||
if (config.normalized_prefixes() != nullptr &&
|
||||
config.normalized_replacements() != nullptr) {
|
||||
const DoubleArrayTrie normalized_prefixes_matcher(
|
||||
config.normalized_prefixes()->nodes());
|
||||
const auto norm_replace = [&config, &normalized_prefixes_matcher](
|
||||
const char* data, int len) {
|
||||
return find_replacement(data, len, normalized_prefixes_matcher,
|
||||
*config.normalized_replacements());
|
||||
};
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, norm_replace);
|
||||
}
|
||||
if (config.remove_extra_whitespaces()) {
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, remove_extra_whitespaces);
|
||||
if (!result.empty() && is_whitespace(result.back())) {
|
||||
result.pop_back();
|
||||
output_offsets.pop_back();
|
||||
}
|
||||
}
|
||||
if (config.escape_whitespaces()) {
|
||||
const auto replace_whitespaces = [](const char* data, int len) {
|
||||
if (len > 0 && is_whitespace(*data)) {
|
||||
return std::make_tuple(1, utils::string_view(kSpaceSymbol));
|
||||
}
|
||||
return std::make_tuple(0, utils::string_view(nullptr, 0));
|
||||
};
|
||||
std::tie(result, output_offsets) =
|
||||
process_string(result, output_offsets, replace_whitespaces);
|
||||
}
|
||||
|
||||
return std::make_tuple(result, output_offsets);
|
||||
}
|
||||
|
||||
EncoderResult EncodeNormalizedString(const std::string& str,
|
||||
const std::vector<int>& offsets,
|
||||
const EncoderConfig& config, bool add_bos,
|
||||
bool add_eos, bool reverse) {
|
||||
const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
|
||||
const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
|
||||
const int unknown_code = config.unknown_code();
|
||||
const float unknown_penalty = config.unknown_penalty();
|
||||
struct LatticeElement {
|
||||
float score = 0;
|
||||
int code = -1;
|
||||
int prev_position = -1;
|
||||
LatticeElement(float score_, int code_, int prev_position_)
|
||||
: score(score_), code(code_), prev_position(prev_position_) {}
|
||||
LatticeElement() {}
|
||||
};
|
||||
const int length = str.length();
|
||||
std::vector<LatticeElement> lattice(length + 1);
|
||||
for (int i = 0; i < length; ++i) {
|
||||
if (i > 0 && lattice[i].prev_position < 0) {
|
||||
// This state is unreachable.
|
||||
continue;
|
||||
}
|
||||
if (unknown_code >= 0) {
|
||||
// Put unknown code.
|
||||
const float penalized_score = lattice[i].score + unknown_penalty;
|
||||
const int pos = i + 1;
|
||||
LatticeElement& current_element = lattice[pos];
|
||||
if (current_element.prev_position < 0 ||
|
||||
current_element.score < penalized_score) {
|
||||
current_element = LatticeElement(
|
||||
penalized_score, unknown_code,
|
||||
// If the current state is already reached by unknown code, merge
|
||||
// states.
|
||||
lattice[i].code == unknown_code ? lattice[i].prev_position : i);
|
||||
}
|
||||
}
|
||||
auto lattice_update = [&lattice, i,
|
||||
piece_scores](const DoubleArrayTrie::Match& m) {
|
||||
LatticeElement& target_element = lattice[i + m.match_length];
|
||||
const float score = lattice[i].score + (*piece_scores)[m.id];
|
||||
if (target_element.prev_position < 0 || target_element.score < score) {
|
||||
target_element = LatticeElement(score, m.id, i);
|
||||
}
|
||||
};
|
||||
piece_matcher.IteratePrefixMatches(
|
||||
utils::string_view(str.data() + i, length - i), lattice_update);
|
||||
}
|
||||
|
||||
EncoderResult result;
|
||||
if (add_eos) {
|
||||
result.codes.push_back(config.end_code());
|
||||
result.offsets.push_back(length);
|
||||
}
|
||||
if (lattice[length].prev_position >= 0) {
|
||||
for (int pos = length; pos > 0;) {
|
||||
auto code = lattice[pos].code;
|
||||
if (code != config.unknown_code()) {
|
||||
code += config.encoding_offset();
|
||||
}
|
||||
result.codes.push_back(code);
|
||||
pos = lattice[pos].prev_position;
|
||||
result.offsets.push_back(offsets[pos]);
|
||||
}
|
||||
}
|
||||
if (add_bos) {
|
||||
result.codes.push_back(config.start_code());
|
||||
result.offsets.push_back(0);
|
||||
}
|
||||
if (!reverse) {
|
||||
std::reverse(result.codes.begin(), result.codes.end());
|
||||
std::reverse(result.offsets.begin(), result.offsets.end());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||
bool add_bos, bool add_eos, bool reverse) {
|
||||
// Get the config from the buffer.
|
||||
const EncoderConfig* config = GetEncoderConfig(config_buffer);
|
||||
if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
|
||||
EncoderResult result;
|
||||
result.type = EncoderResultType::WRONG_CONFIG;
|
||||
return result;
|
||||
}
|
||||
std::string normalized_string;
|
||||
std::vector<int> offsets;
|
||||
std::tie(normalized_string, offsets) = NormalizeString(string, *config);
|
||||
return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
|
||||
add_eos, reverse);
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,46 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
||||
|
||||
// Sentencepiece encoder optimized with memmapped model.
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
|
||||
|
||||
struct EncoderResult {
|
||||
EncoderResultType type = EncoderResultType::SUCCESS;
|
||||
std::vector<int> codes;
|
||||
std::vector<int> offsets;
|
||||
};
|
||||
std::tuple<std::string, std::vector<int>> NormalizeString(
|
||||
const std::string& in_string, const EncoderConfig& config);
|
||||
|
||||
// Encodes one string and returns ids and offsets. Takes the configuration as a
|
||||
// type-erased buffer.
|
||||
EncoderResult EncodeString(const std::string& string, const void* config_buffer,
|
||||
bool add_bos, bool add_eos, bool reverse);
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
|
|
@ -0,0 +1,171 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
|
||||
#include "src/sentencepiece.pb.h"
|
||||
#include "src/sentencepiece_processor.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
namespace internal {
|
||||
|
||||
tensorflow::Status TFReadFileToString(const std::string& filepath,
|
||||
std::string* data) {
|
||||
return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
|
||||
data);
|
||||
}
|
||||
|
||||
absl::Status StdReadFileToString(const std::string& filepath,
|
||||
std::string* data) {
|
||||
std::ifstream infile(filepath);
|
||||
if (!infile.is_open()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrFormat("Error when opening %s", filepath));
|
||||
}
|
||||
std::string contents((std::istreambuf_iterator<char>(infile)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
data->append(contents);
|
||||
infile.close();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} // namespace internal
|
||||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
|
||||
static char kConfigFilePath[] =
|
||||
"/mediapipe/tasks/cc/text/custom_ops/"
|
||||
"sentencepiece/testdata/sentencepiece.model";
|
||||
|
||||
TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_remove_extra_whitespaces(true);
|
||||
ecb.add_add_dummy_prefix(true);
|
||||
ecb.add_escape_whitespaces(true);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
{
|
||||
const auto result = NormalizeString("x y", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
|
||||
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
|
||||
}
|
||||
{
|
||||
const auto result = NormalizeString("\tx y\n", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
|
||||
EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimizedEncoder, NormalizeStringReplacement) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
|
||||
const char norm_replacements[] = "A1\0A2\0A3\0A4";
|
||||
const auto trie_vector =
|
||||
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
|
||||
const auto norm_r = builder.CreateVector<int8_t>(
|
||||
reinterpret_cast<const signed char*>(norm_replacements),
|
||||
sizeof(norm_replacements));
|
||||
TrieBuilder trie_builder(builder);
|
||||
trie_builder.add_nodes(trie_vector);
|
||||
const auto norm_p = trie_builder.Finish();
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_remove_extra_whitespaces(false);
|
||||
ecb.add_normalized_prefixes(norm_p);
|
||||
ecb.add_normalized_replacements(norm_r);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
{
|
||||
const auto result = NormalizeString("ABAABAAABAAAA", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, "A1BA2BA3BA4");
|
||||
EXPECT_THAT(offsets,
|
||||
::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
|
||||
"X"};
|
||||
const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
|
||||
const auto trie_vector =
|
||||
builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
|
||||
const auto norm_r = builder.CreateVector<int8_t>(
|
||||
reinterpret_cast<const signed char*>(norm_replacements),
|
||||
sizeof(norm_replacements));
|
||||
TrieBuilder trie_builder(builder);
|
||||
trie_builder.add_nodes(trie_vector);
|
||||
const auto norm_p = trie_builder.Finish();
|
||||
EncoderConfigBuilder ecb(builder);
|
||||
ecb.add_remove_extra_whitespaces(true);
|
||||
ecb.add_normalized_prefixes(norm_p);
|
||||
ecb.add_normalized_replacements(norm_r);
|
||||
FinishEncoderConfigBuffer(builder, ecb.Finish());
|
||||
const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
|
||||
{
|
||||
const auto result = NormalizeString("XXABAABAAABAAAA", *config);
|
||||
const auto res_string = std::get<0>(result);
|
||||
const auto offsets = std::get<1>(result);
|
||||
EXPECT_EQ(res_string, " A1BA2BA3BA4");
|
||||
EXPECT_THAT(offsets,
|
||||
::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(OptimizedEncoder, ConfigConverter) {
|
||||
std::string config;
|
||||
auto status =
|
||||
internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
|
||||
ASSERT_TRUE(status.ok());
|
||||
|
||||
::sentencepiece::SentencePieceProcessor processor;
|
||||
ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
|
||||
const auto converted_model = ConvertSentencepieceModel(config);
|
||||
const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
|
||||
const auto encoded =
|
||||
EncodeString(test_string, converted_model.data(), false, false, false);
|
||||
ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());
|
||||
|
||||
::sentencepiece::SentencePieceText reference_encoded;
|
||||
ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
|
||||
EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
|
||||
for (int i = 0; i < encoded.codes.size(); ++i) {
|
||||
EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
|
||||
EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
|
@ -0,0 +1,38 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
||||
|
||||
namespace mediapipe::tflite_operations::sentencepiece {
|
||||
|
||||
// The constant is copied from
|
||||
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
|
||||
constexpr float kUnkPenalty = 10.0;
|
||||
|
||||
// These constants are copied from
|
||||
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
|
||||
//
|
||||
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
|
||||
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";
|
||||
|
||||
// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
|
||||
// since this character can be useful both for user and
|
||||
// developer. We can easily figure out that <unk> is emitted.
|
||||
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";
|
||||
|
||||
} // namespace mediapipe::tflite_operations::sentencepiece
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
|
|
@ -0,0 +1,129 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
|
||||
|
||||
#include "flatbuffers/flexbuffers.h"
|
||||
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/context.h"
|
||||
#include "tensorflow/lite/kernels/internal/tensor.h"
|
||||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
|
||||
namespace mediapipe::tflite_operations {
|
||||
namespace sentencepiece::tokenizer {
|
||||
namespace {
|
||||
|
||||
using ::tflite::SetTensorToDynamic;
|
||||
|
||||
constexpr int kSPModelIndex = 0;
|
||||
constexpr int kInputIndex = 1;
|
||||
constexpr int kAddBOSInput = 4;
|
||||
constexpr int kAddEOSInput = 5;
|
||||
constexpr int kReverseInput = 6;
|
||||
|
||||
constexpr int kOutputValuesInd = 0;
|
||||
constexpr int kOutputSplitsInd = 1;
|
||||
|
||||
TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
|
||||
TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
|
||||
int index = 0;
|
||||
for (const int size : sizes) {
|
||||
array_size->data[index++] = size;
|
||||
}
|
||||
return array_size;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Initializes text encoder object from serialized parameters.
|
||||
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
|
||||
size_t /*length*/) {
|
||||
return nullptr;
|
||||
}
|
||||
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}
|
||||
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
// TODO: Add checks for input and output tensors.
|
||||
TfLiteTensor& output_values =
|
||||
context->tensors[node->outputs->data[kOutputValuesInd]];
|
||||
SetTensorToDynamic(&output_values);
|
||||
|
||||
TfLiteTensor& output_splits =
|
||||
context->tensors[node->outputs->data[kOutputSplitsInd]];
|
||||
SetTensorToDynamic(&output_splits);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
const TfLiteTensor& model_tensor =
|
||||
context->tensors[node->inputs->data[kSPModelIndex]];
|
||||
const auto model_buffer_data = model_tensor.data.data;
|
||||
const TfLiteTensor& input_text =
|
||||
context->tensors[node->inputs->data[kInputIndex]];
|
||||
|
||||
const TfLiteTensor add_bos_tensor =
|
||||
context->tensors[node->inputs->data[kAddBOSInput]];
|
||||
const bool add_bos = add_bos_tensor.data.b[0];
|
||||
const TfLiteTensor add_eos_tensor =
|
||||
context->tensors[node->inputs->data[kAddEOSInput]];
|
||||
const bool add_eos = add_eos_tensor.data.b[0];
|
||||
const TfLiteTensor reverse_tensor =
|
||||
context->tensors[node->inputs->data[kReverseInput]];
|
||||
const bool reverse = reverse_tensor.data.b[0];
|
||||
|
||||
std::vector<int32> encoded;
|
||||
std::vector<int32> splits;
|
||||
const int num_strings = tflite::GetStringCount(&input_text);
|
||||
for (int i = 0; i < num_strings; ++i) {
|
||||
const auto strref = tflite::GetString(&input_text, i);
|
||||
const auto res = EncodeString(std::string(strref.str, strref.len),
|
||||
model_buffer_data, add_bos, add_eos, reverse);
|
||||
TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
|
||||
"Sentencepiece conversion failed");
|
||||
std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
|
||||
splits.emplace_back(encoded.size());
|
||||
}
|
||||
|
||||
TfLiteTensor& output_values =
|
||||
context->tensors[node->outputs->data[kOutputValuesInd]];
|
||||
TF_LITE_ENSURE_OK(context,
|
||||
context->ResizeTensor(
|
||||
context, &output_values,
|
||||
CreateSizeArray({static_cast<int>(encoded.size())})));
|
||||
int32_t* output_values_flat = output_values.data.i32;
|
||||
std::copy(encoded.begin(), encoded.end(), output_values_flat);
|
||||
TfLiteTensor& output_splits =
|
||||
context->tensors[node->outputs->data[kOutputSplitsInd]];
|
||||
TF_LITE_ENSURE_OK(
|
||||
context, context->ResizeTensor(
|
||||
context, &output_splits,
|
||||
CreateSizeArray({static_cast<int>(splits.size() + 1)})));
|
||||
int32_t* output_splits_flat = output_splits.data.i32;
|
||||
*output_splits_flat = 0;
|
||||
std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
|
||||
return kTfLiteOk;
|
||||
}
|
||||
} // namespace sentencepiece::tokenizer
|
||||
|
||||
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
|
||||
static TfLiteRegistration r = {
|
||||
sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
|
||||
sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
|
||||
return &r;
|
||||
}
|
||||
|
||||
} // namespace mediapipe::tflite_operations
|
|
@ -0,0 +1,27 @@
|
|||
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
||||
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
||||
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
||||
namespace mediapipe::tflite_operations {
|
||||
|
||||
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
|
||||
|
||||
} // namespace mediapipe::tflite_operations
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
|
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
BIN
mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
vendored
Normal file
Binary file not shown.
|
@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
|
|||
// Embedding model with regex preprocessing.
|
||||
constexpr char kRegexOneEmbeddingModel[] =
|
||||
"regex_one_embedding_with_metadata.tflite";
|
||||
constexpr char kUniversalSentenceEncoderModel[] =
|
||||
"universal_sentence_encoder_qa_with_metadata.tflite";
|
||||
|
||||
// Tolerance for embedding vector coordinate values.
|
||||
constexpr float kEpsilon = 1e-4;
|
||||
|
@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
|
|||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result0,
|
||||
text_embedder->Embed("it's a charming and often affecting journey"));
|
||||
ASSERT_EQ(result0.embeddings.size(), 1);
|
||||
ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
|
||||
ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto result1, text_embedder->Embed("what a great and fantastic trip"));
|
||||
ASSERT_EQ(result1.embeddings.size(), 1);
|
||||
ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
|
||||
ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
|
@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
|
|||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
|
||||
auto options = std::make_unique<TextEmbedderOptions>();
|
||||
options->base_options.model_asset_path =
|
||||
JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
|
||||
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
|
||||
TextEmbedder::Create(std::move(options)));
|
||||
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result0,
|
||||
text_embedder->Embed("When you go to this restaurant, they hold the "
|
||||
"pancake upside-down before they hand it "
|
||||
"to you. It's a great gimmick."));
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
TextEmbedderResult result1,
|
||||
text_embedder->Embed(
|
||||
"Let's make a plan to steal the declaration of independence."));
|
||||
|
||||
// Check cosine similarity.
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
|
||||
result1.embeddings[0]));
|
||||
EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
|
||||
|
||||
MP_ASSERT_OK(text_embedder->Close());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe::tasks::text::text_embedder
|
||||
|
|
|
@ -23,18 +23,12 @@ cc_library(
|
|||
srcs = ["face_stylizer_graph.cc"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator",
|
||||
"//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:warp_affine_calculator",
|
||||
"//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
|
||||
"//mediapipe/calculators/image:image_clone_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:inference_calculator",
|
||||
"//mediapipe/calculators/util:detections_to_rects_calculator",
|
||||
"//mediapipe/calculators/util:face_to_rect_calculator",
|
||||
"//mediapipe/calculators/util:from_image_calculator",
|
||||
"//mediapipe/calculators/util:inverse_matrix_calculator",
|
||||
"//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
|
||||
"//mediapipe/calculators/util:to_image_calculator",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
|
@ -53,7 +47,6 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
|
||||
|
|
|
@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The input image can be of any size with format RGB or RGBA.
|
||||
// When no face is detected on the input image, the method returns a
|
||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||
// face. To ensure that the output image has reasonable quality, the stylized
|
||||
// output image size is the smaller of the model output size and the size of
|
||||
// the 'region_of_interest' specified in 'image_processing_options'.
|
||||
// face. The stylized output image size is the same as the model output size.
|
||||
absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
|
||||
mediapipe::Image image,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
|
@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// must be monotonically increasing.
|
||||
// When no face is detected on the input image, the method returns a
|
||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||
// face. To ensure that the output image has reasonable quality, the stylized
|
||||
// output image size is the smaller of the model output size and the size of
|
||||
// the 'region_of_interest' specified in 'image_processing_options'.
|
||||
// face. The stylized output image size is the same as the model output size.
|
||||
absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
|
||||
mediapipe::Image image, int64_t timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions> image_processing_options =
|
||||
|
@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The "result_callback" provides:
|
||||
// - When no face is detected on the input image, the method returns a
|
||||
// std::nullopt. Otherwise, returns the stylized image of the most visible
|
||||
// face. To ensure that the output image has reasonable quality, the
|
||||
// stylized output image size is the smaller of the model output size and
|
||||
// the size of the 'region_of_interest' specified in
|
||||
// 'image_processing_options'.
|
||||
// face. The stylized output image size is the same as the model output
|
||||
// size.
|
||||
// - The input timestamp in milliseconds.
|
||||
absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
|
||||
std::optional<core::ImageProcessingOptions>
|
||||
|
|
|
@ -19,8 +19,7 @@ limitations under the License.
|
|||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/calculators/core/split_vector_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
|
||||
#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
|
||||
#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
|
||||
#include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
|
@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
|
|||
image_in >> preprocessing.In(kImageTag);
|
||||
face_rect >> preprocessing.In(kNormRectTag);
|
||||
auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
|
||||
auto transform_matrix = preprocessing.Out(kMatrixTag);
|
||||
|
||||
// Adds inference subgraph and connects its input stream to the output
|
||||
// tensors produced by the ImageToTensorCalculator.
|
||||
|
@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
|
|||
model_output_tensors >> tensors_to_image.In(kTensorsTag);
|
||||
auto tensor_image = tensors_to_image.Out(kImageTag);
|
||||
|
||||
auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
|
||||
transform_matrix >> inverse_matrix.In(kMatrixTag);
|
||||
auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
|
||||
auto& image_converter = graph.AddNode("ImageCloneCalculator");
|
||||
image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
|
||||
.set_output_on_gpu(false);
|
||||
tensor_image >> image_converter.In("");
|
||||
|
||||
auto& warp_affine = graph.AddNode("WarpAffineCalculator");
|
||||
auto& warp_affine_options =
|
||||
warp_affine.GetOptions<WarpAffineCalculatorOptions>();
|
||||
warp_affine_options.set_border_mode(
|
||||
WarpAffineCalculatorOptions::BORDER_ZERO);
|
||||
warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
|
||||
tensor_image >> warp_affine.In(kImageTag);
|
||||
inverse_transform_matrix >> warp_affine.In(kMatrixTag);
|
||||
image_size >> warp_affine.In(kOutputSizeTag);
|
||||
auto image_to_crop = warp_affine.Out(kImageTag);
|
||||
|
||||
// The following calculators are for cropping and resizing the output image
|
||||
// based on the roi and the model output size. As the WarpAffineCalculator
|
||||
// rotates the image based on the transform matrix, the rotation info in the
|
||||
// rect proto is stripped to prevent the ImageCroppingCalculator from
|
||||
// performing extra rotation.
|
||||
auto& strip_rotation =
|
||||
graph.AddNode("mediapipe.tasks.StripRotationCalculator");
|
||||
face_rect >> strip_rotation.In(kNormRectTag);
|
||||
auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
|
||||
auto& from_image = graph.AddNode("FromImageCalculator");
|
||||
image_to_crop >> from_image.In(kImageTag);
|
||||
auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
|
||||
auto& image_cropping_opts =
|
||||
image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
|
||||
image_cropping_opts.set_output_max_width(
|
||||
image_to_tensor_options.output_tensor_width());
|
||||
image_cropping_opts.set_output_max_height(
|
||||
image_to_tensor_options.output_tensor_height());
|
||||
norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
|
||||
auto& to_image = graph.AddNode("ToImageCalculator");
|
||||
// ImageCroppingCalculator currently doesn't support mediapipe::Image, the
|
||||
// graph selects its cpu or gpu path based on the image preprocessing
|
||||
// backend.
|
||||
if (use_gpu) {
|
||||
from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
|
||||
image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
|
||||
} else {
|
||||
from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
|
||||
image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
|
||||
}
|
||||
|
||||
return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
|
||||
return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
|
||||
/*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
|
||||
}
|
||||
};
|
||||
|
|
|
@ -100,6 +100,7 @@ cc_library(
|
|||
"//mediapipe/util:graph_builder_utils",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
@ -90,7 +90,7 @@ NS_SWIFT_NAME(ClassificationResult)
|
|||
* amount of data to process might exceed the maximum size that the model can process: to solve
|
||||
* this, the input data is split into multiple chunks starting at different timestamps.
|
||||
*/
|
||||
@property(nonatomic, readonly) NSInteger timestampMs;
|
||||
@property(nonatomic, readonly) NSInteger timestampInMilliseconds;
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPClassificationResult` with the given array of classifications and time
|
||||
|
@ -98,14 +98,15 @@ NS_SWIFT_NAME(ClassificationResult)
|
|||
*
|
||||
* @param classifications An Array of `MPPClassifications` objects containing the predicted
|
||||
* categories for each head of the model.
|
||||
* @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) of the start of the chunk of data
|
||||
* corresponding to these results.
|
||||
*
|
||||
* @return An instance of `MPPClassificationResult` initialized with the given array of
|
||||
* classifications and timestampMs.
|
||||
* classifications and timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||
timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -38,11 +38,11 @@
|
|||
@implementation MPPClassificationResult
|
||||
|
||||
- (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_classifications = classifications;
|
||||
_timestampMs = timestampMs;
|
||||
_timestampInMilliseconds = timestampInMilliseconds;
|
||||
}
|
||||
|
||||
return self;
|
||||
|
|
|
@ -33,7 +33,7 @@ NS_SWIFT_NAME(EmbeddingResult)
|
|||
* cases, the amount of data to process might exceed the maximum size that the model can process. To
|
||||
* solve this, the input data is split into multiple chunks starting at different timestamps.
|
||||
*/
|
||||
@property(nonatomic, readonly) NSInteger timestampMs;
|
||||
@property(nonatomic, readonly) NSInteger timestampInMilliseconds;
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in
|
||||
|
@ -41,14 +41,14 @@ NS_SWIFT_NAME(EmbeddingResult)
|
|||
*
|
||||
* @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each
|
||||
* head of the model.
|
||||
* @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data
|
||||
* corresponding to these results. Pass `0` if timestamp is absent.
|
||||
* @param timestampInMilliseconds The optional timestamp (in milliseconds) of the start of the chunk
|
||||
* of data corresponding to these results. Pass `0` if timestamp is absent.
|
||||
*
|
||||
* @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and
|
||||
* timestampMs.
|
||||
* timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
|
||||
timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -17,11 +17,11 @@
|
|||
@implementation MPPEmbeddingResult
|
||||
|
||||
- (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_embeddings = embeddings;
|
||||
_timestampMs = timestampMs;
|
||||
_timestampInMilliseconds = timestampInMilliseconds;
|
||||
}
|
||||
|
||||
return self;
|
||||
|
|
|
@ -55,13 +55,13 @@ using ClassificationResultProto =
|
|||
[classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]];
|
||||
}
|
||||
|
||||
NSInteger timestampMs = 0;
|
||||
NSInteger timestampInMilliseconds = 0;
|
||||
if (classificationResultProto.has_timestamp_ms()) {
|
||||
timestampMs = (NSInteger)classificationResultProto.timestamp_ms();
|
||||
timestampInMilliseconds = (NSInteger)classificationResultProto.timestamp_ms();
|
||||
}
|
||||
|
||||
return [[MPPClassificationResult alloc] initWithClassifications:classifications
|
||||
timestampMs:timestampMs];
|
||||
timestampInMilliseconds:timestampInMilliseconds];
|
||||
;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,12 +31,13 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::
|
|||
[embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]];
|
||||
}
|
||||
|
||||
NSInteger timestampMs = 0;
|
||||
NSInteger timestampInMilliseconds = 0;
|
||||
if (embeddingResultProto.has_timestamp_ms()) {
|
||||
timestampMs = (NSInteger)embeddingResultProto.timestamp_ms();
|
||||
timestampInMilliseconds = (NSInteger)embeddingResultProto.timestamp_ms();
|
||||
}
|
||||
|
||||
return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs];
|
||||
return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings
|
||||
timestampInMilliseconds:timestampInMilliseconds];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -26,11 +26,12 @@ NS_SWIFT_NAME(TaskResult)
|
|||
/**
|
||||
* Timestamp that is associated with the task result object.
|
||||
*/
|
||||
@property(nonatomic, assign, readonly) NSInteger timestampMs;
|
||||
@property(nonatomic, assign, readonly) NSInteger timestampInMilliseconds;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
|
||||
- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -16,16 +16,16 @@
|
|||
|
||||
@implementation MPPTaskResult
|
||||
|
||||
- (instancetype)initWithTimestampMs:(NSInteger)timestampMs {
|
||||
- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_timestampMs = timestampMs;
|
||||
_timestampInMilliseconds = timestampInMilliseconds;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (id)copyWithZone:(NSZone *)zone {
|
||||
return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs];
|
||||
return [[MPPTaskResult alloc] initWithTimestampInMilliseconds:self.timestampInMilliseconds];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -487,7 +487,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
|
||||
NSError *liveStreamApiCallError;
|
||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image
|
||||
timestampMs:0
|
||||
timestampInMilliseconds:0
|
||||
error:&liveStreamApiCallError]);
|
||||
|
||||
NSError *expectedLiveStreamApiCallError =
|
||||
|
@ -501,7 +501,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
||||
|
||||
NSError *videoApiCallError;
|
||||
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]);
|
||||
XCTAssertFalse([imageClassifier classifyVideoFrame:image
|
||||
timestampInMilliseconds:0
|
||||
error:&videoApiCallError]);
|
||||
|
||||
NSError *expectedVideoApiCallError =
|
||||
[NSError errorWithDomain:kExpectedErrorDomain
|
||||
|
@ -524,7 +526,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
|
||||
NSError *liveStreamApiCallError;
|
||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image
|
||||
timestampMs:0
|
||||
timestampInMilliseconds:0
|
||||
error:&liveStreamApiCallError]);
|
||||
|
||||
NSError *expectedLiveStreamApiCallError =
|
||||
|
@ -575,7 +577,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
||||
|
||||
NSError *videoApiCallError;
|
||||
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]);
|
||||
XCTAssertFalse([imageClassifier classifyVideoFrame:image
|
||||
timestampInMilliseconds:0
|
||||
error:&videoApiCallError]);
|
||||
|
||||
NSError *expectedVideoApiCallError =
|
||||
[NSError errorWithDomain:kExpectedErrorDomain
|
||||
|
@ -601,7 +605,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image
|
||||
timestampMs:i
|
||||
timestampInMilliseconds:i
|
||||
error:nil];
|
||||
[self assertImageClassifierResult:imageClassifierResult
|
||||
hasExpectedCategoriesCount:maxResults
|
||||
|
@ -630,10 +634,10 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
|
||||
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
|
||||
|
||||
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:1 error:nil]);
|
||||
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]);
|
||||
|
||||
NSError *error;
|
||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampMs:0 error:&error]);
|
||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampInMilliseconds:0 error:&error]);
|
||||
|
||||
NSError *expectedError =
|
||||
[NSError errorWithDomain:kExpectedErrorDomain
|
||||
|
@ -668,7 +672,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
|||
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]);
|
||||
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextClassifierResult)
|
|||
*
|
||||
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
|
||||
* per classifier head.
|
||||
* @param timestampMs The timestamp for this result.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||
*
|
||||
* @return An instance of `MPPTextClassifierResult` initialized with the given
|
||||
* `MPPClassificationResult` and timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||
timestampMs:(NSInteger)timestampMs;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
@implementation MPPTextClassifierResult
|
||||
|
||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
self = [super initWithTimestampMs:timestampMs];
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||
if (self) {
|
||||
_classificationResult = classificationResult;
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ using ::mediapipe::Packet;
|
|||
|
||||
return [[MPPTextClassifierResult alloc]
|
||||
initWithClassificationResult:classificationResult
|
||||
timestampMs:(NSInteger)(packet.Timestamp().Value() /
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
}
|
||||
|
||||
|
|
|
@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextEmbedderResult)
|
|||
*
|
||||
* @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results
|
||||
* per classifier head.
|
||||
* @param timestampMs The timestamp for this result.
|
||||
* @param timestampInMilliseconds The timestamp (in millisecondss) for this result.
|
||||
*
|
||||
* @return An instance of `MPPTextEmbedderResult` initialized with the given
|
||||
* `MPPEmbeddingResult` and timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
|
||||
timestampMs:(NSInteger)timestampMs;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
@implementation MPPTextEmbedderResult
|
||||
|
||||
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
self = [super initWithTimestampMs:timestampMs];
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||
if (self) {
|
||||
_embeddingResult = embeddingResult;
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ using ::mediapipe::Packet;
|
|||
|
||||
return [[MPPTextEmbedderResult alloc]
|
||||
initWithEmbeddingResult:embeddingResult
|
||||
timestampMs:(NSInteger)(packet.Timestamp().Value() /
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
}
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@
|
|||
* timestamp.
|
||||
*
|
||||
* @param image The image to send to the MediaPipe graph.
|
||||
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
|
||||
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
|
||||
* error will be saved.
|
||||
*
|
||||
|
@ -49,7 +49,7 @@
|
|||
* occurred during the conversion.
|
||||
*/
|
||||
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error;
|
||||
|
||||
/**
|
||||
|
@ -66,11 +66,11 @@
|
|||
* specified timestamp.
|
||||
*
|
||||
* @param image The `NormalizedRect` to send to the MediaPipe graph.
|
||||
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
|
||||
*
|
||||
* @return The MediaPipe packet containing the normalized rect.
|
||||
*/
|
||||
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect
|
||||
timestampMs:(NSInteger)timestampMs;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||
|
||||
@end
|
||||
|
|
|
@ -42,7 +42,7 @@ using ::mediapipe::Timestamp;
|
|||
}
|
||||
|
||||
+ (Packet)createPacketWithMPPImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error];
|
||||
|
||||
|
@ -51,7 +51,7 @@ using ::mediapipe::Timestamp;
|
|||
}
|
||||
|
||||
return MakePacket<Image>(std::move(imageFrame))
|
||||
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
|
||||
.At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
|
||||
}
|
||||
|
||||
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect {
|
||||
|
@ -59,9 +59,9 @@ using ::mediapipe::Timestamp;
|
|||
}
|
||||
|
||||
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
return MakePacket<NormalizedRect>(std::move(normalizedRect))
|
||||
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
|
||||
.At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
|
||||
gestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super initWithTimestampMs:timestampInMilliseconds];
|
||||
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||
if (self) {
|
||||
_landmarks = landmarks;
|
||||
_worldLandmarks = worldLandmarks;
|
||||
|
|
|
@ -122,17 +122,17 @@ NS_SWIFT_NAME(ImageClassifier)
|
|||
* `MPPRunningModeVideo`.
|
||||
*
|
||||
* @param image The `MPPImage` on which image classification is to be performed.
|
||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
||||
* monotonically increasing.
|
||||
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||
* timestamps must be monotonically increasing.
|
||||
* @param error An optional error parameter populated when there is an error in performing image
|
||||
* classification on the input video frame.
|
||||
*
|
||||
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
|
||||
*/
|
||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(classify(videoFrame:timestampMs:));
|
||||
NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:));
|
||||
|
||||
/**
|
||||
* Performs image classification on the provided video frame of type `MPPImage` cropped to the
|
||||
|
@ -145,8 +145,8 @@ NS_SWIFT_NAME(ImageClassifier)
|
|||
*
|
||||
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
||||
* performed.
|
||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
||||
* monotonically increasing.
|
||||
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||
* timestamps must be monotonically increasing.
|
||||
* @param roi A `CGRect` specifying the region of interest within the video frame of type
|
||||
* `MPPImage`, on which image classification should be performed.
|
||||
* @param error An optional error parameter populated when there is an error in performing image
|
||||
|
@ -155,10 +155,10 @@ NS_SWIFT_NAME(ImageClassifier)
|
|||
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
|
||||
*/
|
||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:));
|
||||
NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:regionOfInterest:));
|
||||
|
||||
/**
|
||||
* Sends live stream image data of type `MPPImage` to perform image classification using the whole
|
||||
|
@ -172,16 +172,17 @@ NS_SWIFT_NAME(ImageClassifier)
|
|||
*
|
||||
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
||||
* performed.
|
||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
||||
* to the image classifier. The input timestamps must be monotonically increasing.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||
* image is sent to the image classifier. The input timestamps must be monotonically increasing.
|
||||
* @param error An optional error parameter populated when there is an error in performing image
|
||||
* classification on the input live stream image data.
|
||||
*
|
||||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||
*/
|
||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:));
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:));
|
||||
|
||||
/**
|
||||
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
|
||||
|
@ -195,8 +196,8 @@ NS_SWIFT_NAME(ImageClassifier)
|
|||
*
|
||||
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
||||
* performed.
|
||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
||||
* to the image classifier. The input timestamps must be monotonically increasing.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||
* image is sent to the image classifier. The input timestamps must be monotonically increasing.
|
||||
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
|
||||
* of type `MPPImage`, on which image classification should be performed.
|
||||
* @param error An optional error parameter populated when there is an error in performing image
|
||||
|
@ -205,10 +206,10 @@ NS_SWIFT_NAME(ImageClassifier)
|
|||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||
*/
|
||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:));
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:regionOfInterest:));
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -149,7 +149,7 @@ static NSString *const kTaskGraphName =
|
|||
}
|
||||
|
||||
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<NormalizedRect> rect =
|
||||
|
@ -162,14 +162,15 @@ static NSString *const kTaskGraphName =
|
|||
}
|
||||
|
||||
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
error:error];
|
||||
if (imagePacket.IsEmpty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
||||
timestampMs:timestampMs];
|
||||
Packet normalizedRectPacket =
|
||||
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
||||
timestampInMilliseconds:timestampInMilliseconds];
|
||||
|
||||
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
|
||||
return inputPacketMap;
|
||||
|
@ -180,11 +181,11 @@ static NSString *const kTaskGraphName =
|
|||
}
|
||||
|
||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:roi
|
||||
error:error];
|
||||
if (!inputPacketMap.has_value()) {
|
||||
|
@ -204,20 +205,20 @@ static NSString *const kTaskGraphName =
|
|||
}
|
||||
|
||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
return [self classifyVideoFrame:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:CGRectZero
|
||||
error:error];
|
||||
}
|
||||
|
||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:roi
|
||||
error:error];
|
||||
if (!inputPacketMap.has_value()) {
|
||||
|
@ -228,10 +229,10 @@ static NSString *const kTaskGraphName =
|
|||
}
|
||||
|
||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
error:(NSError **)error {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
return [self classifyAsyncImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:CGRectZero
|
||||
error:error];
|
||||
}
|
||||
|
|
|
@ -31,13 +31,13 @@ NS_SWIFT_NAME(ImageClassifierResult)
|
|||
*
|
||||
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
|
||||
* per classifier head.
|
||||
* @param timestampMs The timestamp for this result.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||
*
|
||||
* @return An instance of `MPPImageClassifierResult` initialized with the given
|
||||
* `MPPClassificationResult` and timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||
timestampMs:(NSInteger)timestampMs;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
@implementation MPPImageClassifierResult
|
||||
|
||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
self = [super initWithTimestampMs:timestampMs];
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||
if (self) {
|
||||
_classificationResult = classificationResult;
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ using ::mediapipe::Packet;
|
|||
|
||||
return [[MPPImageClassifierResult alloc]
|
||||
initWithClassificationResult:classificationResult
|
||||
timestampMs:(NSInteger)(packet.Timestamp().Value() /
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
}
|
||||
|
||||
|
|
|
@ -36,13 +36,13 @@ NS_SWIFT_NAME(ObjectDetectionResult)
|
|||
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
|
||||
* expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width)
|
||||
* x [0,image_height)`, which are the dimensions of the underlying image data.
|
||||
* @param timestampMs The timestamp for this result.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||
*
|
||||
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
|
||||
* and timestamp (in milliseconds).
|
||||
*/
|
||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
||||
timestampMs:(NSInteger)timestampMs;
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
@implementation MPPObjectDetectionResult
|
||||
|
||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
||||
timestampMs:(NSInteger)timestampMs {
|
||||
self = [super initWithTimestampMs:timestampMs];
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||
if (self) {
|
||||
_detections = detections;
|
||||
}
|
||||
|
|
|
@ -138,8 +138,8 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
* `MPPRunningModeVideo`.
|
||||
*
|
||||
* @param image The `MPPImage` on which object detection is to be performed.
|
||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
||||
* monotonically increasing.
|
||||
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||
* timestamps must be monotonically increasing.
|
||||
* @param error An optional error parameter populated when there is an error in performing object
|
||||
* detection on the input image.
|
||||
*
|
||||
|
@ -149,9 +149,9 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
* image data.
|
||||
*/
|
||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detect(videoFrame:timestampMs:));
|
||||
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
|
||||
|
||||
/**
|
||||
* Performs object detection on the provided video frame of type `MPPImage` cropped to the
|
||||
|
@ -164,8 +164,8 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
*
|
||||
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
||||
* performed.
|
||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
||||
* monotonically increasing.
|
||||
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||
* timestamps must be monotonically increasing.
|
||||
* @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
|
||||
* object detection should be performed.
|
||||
*
|
||||
|
@ -178,10 +178,10 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
* image data.
|
||||
*/
|
||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detect(videoFrame:timestampMs:regionOfInterest:));
|
||||
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:regionOfInterest:));
|
||||
|
||||
/**
|
||||
* Sends live stream image data of type `MPPImage` to perform object detection using the whole
|
||||
|
@ -195,16 +195,17 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
*
|
||||
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
||||
* performed.
|
||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
||||
* to the object detector. The input timestamps must be monotonically increasing.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||
* image is sent to the object detector. The input timestamps must be monotonically increasing.
|
||||
* @param error An optional error parameter populated when there is an error in performing object
|
||||
* detection on the input live stream image data.
|
||||
*
|
||||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||
*/
|
||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
error:(NSError **)error NS_SWIFT_NAME(detectAsync(image:timestampMs:));
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
|
||||
|
||||
/**
|
||||
* Sends live stream image data of type `MPPImage` to perform object detection, cropped to the
|
||||
|
@ -218,8 +219,8 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
*
|
||||
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
||||
* performed.
|
||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
||||
* to the object detector. The input timestamps must be monotonically increasing.
|
||||
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||
* image is sent to the object detector. The input timestamps must be monotonically increasing.
|
||||
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
|
||||
* of type `MPPImage`, on which iobject detection should be performed.
|
||||
* @param error An optional error parameter populated when there is an error in performing object
|
||||
|
@ -228,10 +229,10 @@ NS_SWIFT_NAME(ObjectDetector)
|
|||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||
*/
|
||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detectAsync(image:timestampMs:regionOfInterest:));
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error
|
||||
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:regionOfInterest:));
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
|
|
|
@ -157,7 +157,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
|||
}
|
||||
|
||||
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<NormalizedRect> rect =
|
||||
|
@ -170,14 +170,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
|||
}
|
||||
|
||||
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
error:error];
|
||||
if (imagePacket.IsEmpty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
||||
timestampMs:timestampMs];
|
||||
Packet normalizedRectPacket =
|
||||
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
||||
timestampInMilliseconds:timestampInMilliseconds];
|
||||
|
||||
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
|
||||
return inputPacketMap;
|
||||
|
@ -188,11 +189,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
|||
}
|
||||
|
||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:roi
|
||||
error:error];
|
||||
if (!inputPacketMap.has_value()) {
|
||||
|
@ -212,20 +213,20 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
|||
}
|
||||
|
||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
return [self detectInVideoFrame:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:CGRectZero
|
||||
error:error];
|
||||
}
|
||||
|
||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
regionOfInterest:(CGRect)roi
|
||||
error:(NSError **)error {
|
||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:roi
|
||||
error:error];
|
||||
if (!inputPacketMap.has_value()) {
|
||||
|
@ -236,10 +237,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
|||
}
|
||||
|
||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||
timestampMs:(NSInteger)timestampMs
|
||||
error:(NSError **)error {
|
||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||
error:(NSError **)error {
|
||||
return [self detectAsyncInImage:image
|
||||
timestampMs:timestampMs
|
||||
timestampInMilliseconds:timestampInMilliseconds
|
||||
regionOfInterest:CGRectZero
|
||||
error:error];
|
||||
}
|
||||
|
|
|
@ -38,8 +38,9 @@ using ::mediapipe::Packet;
|
|||
}
|
||||
|
||||
return [[MPPObjectDetectionResult alloc]
|
||||
initWithDetections:detections
|
||||
timestampMs:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)];
|
||||
initWithDetections:detections
|
||||
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||
kMicroSecondsPerMilliSecond)];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -198,9 +198,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
||||
* size of the stylized output is based the model output size and can be smaller than the input
|
||||
* image.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created
|
||||
|
@ -220,9 +220,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
||||
* the stylized output image size is the smaller of the model output size and the size of the
|
||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
|
@ -256,9 +256,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
||||
* size of the stylized output is based the model output size and can be smaller than the input
|
||||
* image.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||
|
@ -281,9 +281,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
||||
* the stylized output image size is the smaller of the model output size and the size of the
|
||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
|
@ -320,9 +320,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
||||
* size of the stylized output is based the model output size and can be smaller than the input
|
||||
* image.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
|
@ -346,9 +346,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
||||
* the stylized output image size is the smaller of the model output size and the size of the
|
||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}. *
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
|
@ -387,9 +387,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
||||
* size of the stylized output is based the model output size and can be smaller than the input
|
||||
* image.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
|
@ -414,9 +414,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
||||
* the stylized output image size is the smaller of the model output size and the size of the
|
||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param timestampMs the input timestamp (in milliseconds).
|
||||
|
@ -445,9 +445,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
*
|
||||
* <p>{@link FaceStylizer} supports the following color space types:
|
||||
*
|
||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
||||
* size of the stylized output is based the model output * size and can be smaller than the input
|
||||
* image.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* <ul>
|
||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
|
@ -475,9 +475,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
|||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||
* </ul>
|
||||
*
|
||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
||||
* the stylized output image size is the smaller of the model output size and the size of the
|
||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
||||
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||
* is detected on the input image, returns {@code Optional.empty()}.
|
||||
*
|
||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||
|
|
|
@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
||||
"ROI:" + ROI_IN_STREAM_NAME,
|
||||
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||
private static final List<String> OUTPUT_STREAMS =
|
||||
Collections.unmodifiableList(
|
||||
Arrays.asList(
|
||||
"GROUPED_SEGMENTATION:segmented_mask_out",
|
||||
"IMAGE:image_out",
|
||||
"SEGMENTATION:0:segmentation"));
|
||||
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
||||
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
||||
private static final String TASK_GRAPH_NAME =
|
||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
|
@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
*/
|
||||
public static InteractiveSegmenter createFromOptions(
|
||||
Context context, InteractiveSegmenterOptions segmenterOptions) {
|
||||
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
|
||||
throw new IllegalArgumentException(
|
||||
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||
}
|
||||
List<String> outputStreams = new ArrayList<>();
|
||||
outputStreams.add("IMAGE:image_out");
|
||||
if (segmenterOptions.outputConfidenceMasks()) {
|
||||
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
||||
}
|
||||
final int confidenceMasksOutStreamIndex = outputStreams.size() - 1;
|
||||
if (segmenterOptions.outputCategoryMask()) {
|
||||
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||
}
|
||||
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||
|
||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||
handler.setOutputPacketConverter(
|
||||
|
@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
@Override
|
||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||
throws MediaPipeException {
|
||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
||||
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
||||
return ImageSegmenterResult.create(
|
||||
Optional.empty(),
|
||||
Optional.empty(),
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
||||
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
||||
}
|
||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
||||
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
||||
int imageFormat =
|
||||
segmenterOptions.outputType()
|
||||
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK
|
||||
? MPImage.IMAGE_FORMAT_VEC32F1
|
||||
: MPImage.IMAGE_FORMAT_ALPHA;
|
||||
int imageListSize =
|
||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
|
||||
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
|
||||
if (!segmenterOptions.resultListener().isPresent()) {
|
||||
for (int i = 0; i < imageListSize; i++) {
|
||||
buffersArray[i] =
|
||||
ByteBuffer.allocateDirect(
|
||||
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
|
||||
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
|
||||
// memory.
|
||||
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||
Optional<List<MPImage>> confidenceMasks = Optional.empty();
|
||||
if (segmenterOptions.outputConfidenceMasks()) {
|
||||
confidenceMasks = Optional.of(new ArrayList<>());
|
||||
int width =
|
||||
PacketGetter.getImageWidthFromImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex));
|
||||
int height =
|
||||
PacketGetter.getImageHeightFromImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex));
|
||||
int imageListSize =
|
||||
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
|
||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||
// confidence masks are float type image.
|
||||
final int numBytes = 4;
|
||||
if (copyImage) {
|
||||
for (int i = 0; i < imageListSize; i++) {
|
||||
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
|
||||
}
|
||||
}
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting confidence masks.");
|
||||
}
|
||||
for (ByteBuffer buffer : buffersArray) {
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
|
||||
confidenceMasks.get().add(builder.build());
|
||||
}
|
||||
}
|
||||
if (!PacketGetter.getImageList(
|
||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
|
||||
buffersArray,
|
||||
!segmenterOptions.resultListener().isPresent())) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting segmented masks. It usually results from incorrect"
|
||||
+ " options of unsupported OutputType of given model.");
|
||||
}
|
||||
for (ByteBuffer buffer : buffersArray) {
|
||||
Optional<MPImage> categoryMask = Optional.empty();
|
||||
if (segmenterOptions.outputCategoryMask()) {
|
||||
int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
|
||||
int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
|
||||
ByteBuffer buffer;
|
||||
if (copyImage) {
|
||||
buffer = ByteBuffer.allocateDirect(width * height);
|
||||
if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
|
||||
throw new MediaPipeException(
|
||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||
"There is an error getting category mask.");
|
||||
}
|
||||
} else {
|
||||
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
|
||||
}
|
||||
ByteBufferImageBuilder builder =
|
||||
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
|
||||
segmentedMasks.add(builder.build());
|
||||
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||
categoryMask = Optional.of(builder.build());
|
||||
}
|
||||
|
||||
return ImageSegmenterResult.create(
|
||||
Optional.of(segmentedMasks),
|
||||
Optional.empty(),
|
||||
confidenceMasks,
|
||||
categoryMask,
|
||||
BaseVisionTaskApi.generateResultTimestampMs(
|
||||
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
||||
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
.setTaskRunningModeName(RunningMode.IMAGE.name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
.setOutputStreams(outputStreams)
|
||||
.setTaskOptions(segmenterOptions)
|
||||
.setEnableFlowLimiting(false)
|
||||
.build(),
|
||||
|
@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
/** Sets the base options for the image segmenter task. */
|
||||
public abstract Builder setBaseOptions(BaseOptions value);
|
||||
|
||||
/** The output type from image segmenter. */
|
||||
public abstract Builder setOutputType(OutputType value);
|
||||
/** Sets whether to output confidence masks. Default to true. */
|
||||
public abstract Builder setOutputConfidenceMasks(boolean value);
|
||||
|
||||
/** Sets whether to output category mask. Default to false. */
|
||||
public abstract Builder setOutputCategoryMask(boolean value);
|
||||
|
||||
/**
|
||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||
|
@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
|
||||
abstract BaseOptions baseOptions();
|
||||
|
||||
abstract OutputType outputType();
|
||||
abstract boolean outputConfidenceMasks();
|
||||
|
||||
abstract boolean outputCategoryMask();
|
||||
|
||||
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
||||
|
||||
abstract Optional<ErrorListener> errorListener();
|
||||
|
||||
/** The output type of segmentation results. */
|
||||
public enum OutputType {
|
||||
// Gives a single output mask where each pixel represents the class which
|
||||
// the pixel in the original image was predicted to belong to.
|
||||
CATEGORY_MASK,
|
||||
// Gives a list of output masks where, for each mask, each pixel represents
|
||||
// the prediction confidence, usually in the [0, 1] range.
|
||||
CONFIDENCE_MASK
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
|
||||
.setOutputType(OutputType.CATEGORY_MASK);
|
||||
.setOutputConfidenceMasks(true)
|
||||
.setOutputCategoryMask(false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
|||
|
||||
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
||||
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
||||
if (outputType() == OutputType.CONFIDENCE_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
|
||||
} else if (outputType() == OutputType.CATEGORY_MASK) {
|
||||
segmenterOptionsBuilder.setOutputType(
|
||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
||||
}
|
||||
|
||||
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
||||
return CalculatorOptions.newBuilder()
|
||||
.setExtension(
|
||||
|
|
|
@ -234,8 +234,8 @@ public class FaceStylizerTest {
|
|||
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
|
||||
MPImage stylizedImage = actualResult.stylizedImage().get();
|
||||
assertThat(stylizedImage).isNotNull();
|
||||
assertThat(stylizedImage.getWidth()).isEqualTo(83);
|
||||
assertThat(stylizedImage.getHeight()).isEqualTo(83);
|
||||
assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
|
||||
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -53,18 +53,15 @@ public class InteractiveSegmenterTest {
|
|||
InteractiveSegmenterOptions options =
|
||||
InteractiveSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK)
|
||||
.setOutputConfidenceMasks(false)
|
||||
.setOutputCategoryMask(true)
|
||||
.build();
|
||||
InteractiveSegmenter imageSegmenter =
|
||||
InteractiveSegmenter.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
MPImage image = getImageFromAsset(inputImageName);
|
||||
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
||||
// TODO update to correct category mask output.
|
||||
// After InteractiveSegmenter updated according to (b/276519300), update this to use
|
||||
// categoryMask field instead of confidenceMasks.
|
||||
List<MPImage> segmentations = actualResult.confidenceMasks().get();
|
||||
assertThat(segmentations.size()).isEqualTo(1);
|
||||
assertThat(actualResult.categoryMask().isPresent()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -75,15 +72,17 @@ public class InteractiveSegmenterTest {
|
|||
InteractiveSegmenterOptions options =
|
||||
InteractiveSegmenterOptions.builder()
|
||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
||||
.setOutputConfidenceMasks(true)
|
||||
.setOutputCategoryMask(false)
|
||||
.build();
|
||||
InteractiveSegmenter imageSegmenter =
|
||||
InteractiveSegmenter.createFromOptions(
|
||||
ApplicationProvider.getApplicationContext(), options);
|
||||
ImageSegmenterResult actualResult =
|
||||
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
||||
List<MPImage> segmentations = actualResult.confidenceMasks().get();
|
||||
assertThat(segmentations.size()).isEqualTo(2);
|
||||
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
|
||||
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
|
||||
assertThat(confidenceMasks.size()).isEqualTo(2);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -204,6 +204,11 @@ This can be useful for resetting a stateful task graph to process new data.
|
|||
Raises:
|
||||
RuntimeError: The underlying medipaipe graph fails to reset and restart.
|
||||
)doc");
|
||||
|
||||
task_runner.def(
|
||||
"get_graph_config",
|
||||
[](TaskRunner* self) { return self->GetGraphConfig(); },
|
||||
R"doc(Returns the canonicalized CalculatorGraphConfig of the underlying graph.)doc");
|
||||
}
|
||||
|
||||
} // namespace python
|
||||
|
|
|
@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
|
|||
|
||||
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
|
||||
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
|
||||
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
||||
# Tolerance for embedding vector coordinate values.
|
||||
_EPSILON = 1e-4
|
||||
|
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
|
|||
16,
|
||||
(0.549632, 0.552879),
|
||||
),
|
||||
(
|
||||
False,
|
||||
False,
|
||||
_USE_MODEL_FILE,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.851961,
|
||||
100,
|
||||
(1.422951, 1.404664),
|
||||
),
|
||||
(
|
||||
True,
|
||||
False,
|
||||
_USE_MODEL_FILE,
|
||||
ModelFileType.FILE_CONTENT,
|
||||
0.851961,
|
||||
100,
|
||||
(0.127049, 0.125416),
|
||||
),
|
||||
)
|
||||
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
|
||||
expected_similarity, expected_size, expected_first_values):
|
||||
|
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
|
|||
16,
|
||||
(0.549632, 0.552879),
|
||||
),
|
||||
(
|
||||
False,
|
||||
False,
|
||||
_USE_MODEL_FILE,
|
||||
ModelFileType.FILE_NAME,
|
||||
0.851961,
|
||||
100,
|
||||
(1.422951, 1.404664),
|
||||
),
|
||||
(
|
||||
True,
|
||||
False,
|
||||
_USE_MODEL_FILE,
|
||||
ModelFileType.FILE_CONTENT,
|
||||
0.851961,
|
||||
100,
|
||||
(0.127049, 0.125416),
|
||||
),
|
||||
)
|
||||
def test_embed_in_context(self, l2_normalize, quantize, model_name,
|
||||
model_file_type, expected_similarity, expected_size,
|
||||
|
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
|
|||
@parameterized.parameters(
|
||||
# TODO: The similarity should likely be lower
|
||||
(_BERT_MODEL_FILE, 0.980880),
|
||||
(_USE_MODEL_FILE, 0.780334),
|
||||
)
|
||||
def test_embed_with_different_themes(self, model_file, expected_similarity):
|
||||
# Creates embedder.
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
import enum
|
||||
import os
|
||||
from typing import List
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest
|
||||
|
@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils
|
|||
from mediapipe.tasks.python.vision import image_segmenter
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Image = image_module.Image
|
||||
_ImageFormat = image_frame.ImageFormat
|
||||
_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
|
||||
_Activation = image_segmenter.ImageSegmenterOptions.Activation
|
||||
_ImageSegmenter = image_segmenter.ImageSegmenter
|
||||
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
|
||||
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||
|
@ -42,11 +40,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
|||
_MODEL_FILE = 'deeplabv3.tflite'
|
||||
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
|
||||
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
||||
_CAT_IMAGE = 'cat.jpg'
|
||||
_CAT_MASK = 'cat_mask.jpg'
|
||||
_MASK_MAGNIFICATION_FACTOR = 10
|
||||
_MASK_SIMILARITY_THRESHOLD = 0.98
|
||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||
|
||||
|
||||
def _calculate_soft_iou(m1, m2):
|
||||
intersection_sum = np.sum(m1 * m2)
|
||||
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
|
||||
|
||||
if union_sum > 0:
|
||||
return intersection_sum / union_sum
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
|
||||
actual_mask = actual_mask.numpy_view()
|
||||
expected_mask = expected_mask.numpy_view() / 255.0
|
||||
|
||||
return (
|
||||
actual_mask.shape == expected_mask.shape
|
||||
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
|
||||
)
|
||||
|
||||
|
||||
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
||||
actual_mask_pixels = actual_mask.numpy_view().flatten()
|
||||
expected_mask_pixels = expected_mask.numpy_view().flatten()
|
||||
|
@ -56,8 +76,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask):
|
|||
|
||||
for index in range(num_pixels):
|
||||
consistent_pixels += (
|
||||
actual_mask_pixels[index] *
|
||||
_MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index])
|
||||
actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR
|
||||
== expected_mask_pixels[index]
|
||||
)
|
||||
|
||||
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
|
||||
|
||||
|
@ -73,16 +94,27 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
super().setUp()
|
||||
# Load the test input image.
|
||||
self.test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
|
||||
)
|
||||
# Loads ground truth segmentation file.
|
||||
gt_segmentation_data = cv2.imread(
|
||||
test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)),
|
||||
cv2.IMREAD_GRAYSCALE)
|
||||
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)
|
||||
),
|
||||
cv2.IMREAD_GRAYSCALE,
|
||||
)
|
||||
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||
self.model_path = test_utils.get_test_data_path(
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
|
||||
)
|
||||
|
||||
def _load_segmentation_mask(self, file_path: str):
|
||||
# Loads ground truth segmentation file.
|
||||
gt_segmentation_data = cv2.imread(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
|
||||
cv2.IMREAD_GRAYSCALE,
|
||||
)
|
||||
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||
|
||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||
# Creates with default option and valid model file successfully.
|
||||
|
@ -98,9 +130,11 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
|
||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||
):
|
||||
base_options = _BaseOptions(
|
||||
model_asset_path='/path/to/invalid/model.tflite')
|
||||
model_asset_path='/path/to/invalid/model.tflite'
|
||||
)
|
||||
options = _ImageSegmenterOptions(base_options=base_options)
|
||||
_ImageSegmenter.create_from_options(options)
|
||||
|
||||
|
@ -112,8 +146,9 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||
|
||||
@parameterized.parameters((ModelFileType.FILE_NAME,),
|
||||
(ModelFileType.FILE_CONTENT,))
|
||||
@parameterized.parameters(
|
||||
(ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,)
|
||||
)
|
||||
def test_segment_succeeds_with_category_mask(self, model_file_type):
|
||||
# Creates segmenter.
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
|
@ -127,22 +162,27 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
||||
base_options=base_options,
|
||||
output_category_mask=True,
|
||||
output_confidence_masks=False,
|
||||
)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
|
||||
# Performs image segmentation on the input.
|
||||
category_masks = segmenter.segment(self.test_image)
|
||||
self.assertLen(category_masks, 1)
|
||||
category_mask = category_masks[0]
|
||||
segmentation_result = segmenter.segment(self.test_image)
|
||||
category_mask = segmentation_result.category_mask
|
||||
result_pixels = category_mask.numpy_view().flatten()
|
||||
|
||||
# Check if data type of `category_mask` is correct.
|
||||
self.assertEqual(result_pixels.dtype, np.uint8)
|
||||
|
||||
self.assertTrue(
|
||||
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
|
||||
f'Number of pixels in the candidate mask differing from that of the '
|
||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
||||
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||
(
|
||||
'Number of pixels in the candidate mask differing from that of the'
|
||||
f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||
),
|
||||
)
|
||||
|
||||
# Closes the segmenter explicitly when the segmenter is not used in
|
||||
# a context.
|
||||
|
@ -152,74 +192,46 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
# Creates segmenter.
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
|
||||
# Run segmentation on the model in CATEGORY_MASK mode.
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
category_masks = segmenter.segment(self.test_image)
|
||||
category_mask = category_masks[0].numpy_view()
|
||||
# Load the cat image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||
)
|
||||
|
||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options,
|
||||
output_type=_OutputType.CONFIDENCE_MASK,
|
||||
activation=_Activation.SOFTMAX)
|
||||
segmenter = _ImageSegmenter.create_from_options(options)
|
||||
confidence_masks = segmenter.segment(self.test_image)
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
)
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
confidence_masks, 21,
|
||||
'Number of confidence masks must match with number of categories.')
|
||||
|
||||
# Gather the confidence masks in a single array `confidence_mask_array`.
|
||||
confidence_mask_array = np.array(
|
||||
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
|
||||
|
||||
# Check if data type of `confidence_masks` are correct.
|
||||
self.assertEqual(confidence_mask_array.dtype, np.float32)
|
||||
|
||||
# Compute the category mask from the created confidence mask.
|
||||
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
|
||||
self.assertListEqual(
|
||||
calculated_category_mask.tolist(), category_mask.tolist(),
|
||||
'Confidence mask does not match with the category mask.')
|
||||
|
||||
# Closes the segmenter explicitly when the segmenter is not used in
|
||||
# a context.
|
||||
segmenter.close()
|
||||
|
||||
@parameterized.parameters((ModelFileType.FILE_NAME),
|
||||
(ModelFileType.FILE_CONTENT))
|
||||
def test_segment_in_context(self, model_file_type):
|
||||
if model_file_type is ModelFileType.FILE_NAME:
|
||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
||||
with open(self.model_path, 'rb') as f:
|
||||
model_contents = f.read()
|
||||
base_options = _BaseOptions(model_asset_buffer=model_contents)
|
||||
else:
|
||||
# Should never happen
|
||||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
# Performs image segmentation on the input.
|
||||
category_masks = segmenter.segment(self.test_image)
|
||||
self.assertLen(category_masks, 1)
|
||||
segmentation_result = segmenter.segment(test_image)
|
||||
confidence_masks = segmentation_result.confidence_masks
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
confidence_masks,
|
||||
21,
|
||||
'Number of confidence masks must match with number of categories.',
|
||||
)
|
||||
|
||||
# Loads ground truth segmentation file.
|
||||
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||
|
||||
self.assertTrue(
|
||||
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
|
||||
f'Number of pixels in the candidate mask differing from that of the '
|
||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
||||
_similar_to_float_mask(
|
||||
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
def test_missing_result_callback(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'result callback must be provided'):
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback must be provided'
|
||||
):
|
||||
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
||||
pass
|
||||
|
||||
|
@ -228,130 +240,236 @@ class ImageSegmenterTest(parameterized.TestCase):
|
|||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=running_mode,
|
||||
result_callback=mock.MagicMock())
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'result callback should not be provided'):
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'result callback should not be provided'
|
||||
):
|
||||
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
||||
pass
|
||||
|
||||
def test_calling_segment_for_video_in_image_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the video mode'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
segmenter.segment_for_video(self.test_image, 0)
|
||||
|
||||
def test_calling_segment_async_in_image_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.IMAGE)
|
||||
running_mode=_RUNNING_MODE.IMAGE,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the live stream mode'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
segmenter.segment_async(self.test_image, 0)
|
||||
|
||||
def test_calling_segment_in_video_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the image mode'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
segmenter.segment(self.test_image)
|
||||
|
||||
def test_calling_segment_async_in_video_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the live stream mode'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the live stream mode'
|
||||
):
|
||||
segmenter.segment_async(self.test_image, 0)
|
||||
|
||||
def test_segment_for_video_with_out_of_order_timestamp(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
unused_result = segmenter.segment_for_video(self.test_image, 1)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
segmenter.segment_for_video(self.test_image, 0)
|
||||
|
||||
def test_segment_for_video(self):
|
||||
def test_segment_for_video_in_category_mask_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
output_type=_OutputType.CATEGORY_MASK,
|
||||
running_mode=_RUNNING_MODE.VIDEO)
|
||||
output_category_mask=True,
|
||||
output_confidence_masks=False,
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
for timestamp in range(0, 300, 30):
|
||||
category_masks = segmenter.segment_for_video(self.test_image, timestamp)
|
||||
self.assertLen(category_masks, 1)
|
||||
segmentation_result = segmenter.segment_for_video(
|
||||
self.test_image, timestamp
|
||||
)
|
||||
category_mask = segmentation_result.category_mask
|
||||
self.assertTrue(
|
||||
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
|
||||
f'Number of pixels in the candidate mask differing from that of the '
|
||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
||||
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||
(
|
||||
'Number of pixels in the candidate mask differing from that of'
|
||||
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||
),
|
||||
)
|
||||
|
||||
def test_segment_for_video_in_confidence_mask_mode(self):
|
||||
# Load the cat image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||
)
|
||||
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.VIDEO,
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
for timestamp in range(0, 300, 30):
|
||||
segmentation_result = segmenter.segment_for_video(test_image, timestamp)
|
||||
confidence_masks = segmentation_result.confidence_masks
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
confidence_masks,
|
||||
21,
|
||||
'Number of confidence masks must match with number of categories.',
|
||||
)
|
||||
|
||||
# Loads ground truth segmentation file.
|
||||
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||
self.assertTrue(
|
||||
_similar_to_float_mask(
|
||||
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||
)
|
||||
)
|
||||
|
||||
def test_calling_segment_in_live_stream_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the image mode'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the image mode'
|
||||
):
|
||||
segmenter.segment(self.test_image)
|
||||
|
||||
def test_calling_segment_for_video_in_live_stream_mode(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
r'not initialized with the video mode'):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'not initialized with the video mode'
|
||||
):
|
||||
segmenter.segment_for_video(self.test_image, 0)
|
||||
|
||||
def test_segment_async_calls_with_illegal_timestamp(self):
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=mock.MagicMock())
|
||||
result_callback=mock.MagicMock(),
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
segmenter.segment_async(self.test_image, 100)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
||||
ValueError, r'Input timestamp must be monotonically increasing'
|
||||
):
|
||||
segmenter.segment_async(self.test_image, 0)
|
||||
|
||||
def test_segment_async_calls(self):
|
||||
def test_segment_async_calls_in_category_mask_mode(self):
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(result: List[image_module.Image], output_image: _Image,
|
||||
timestamp_ms: int):
|
||||
def check_result(
|
||||
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
# Get the output category mask.
|
||||
category_mask = result[0]
|
||||
category_mask = result.category_mask
|
||||
self.assertEqual(output_image.width, self.test_image.width)
|
||||
self.assertEqual(output_image.height, self.test_image.height)
|
||||
self.assertEqual(output_image.width, self.test_seg_image.width)
|
||||
self.assertEqual(output_image.height, self.test_seg_image.height)
|
||||
self.assertTrue(
|
||||
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||
f'Number of pixels in the candidate mask differing from that of the '
|
||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
||||
(
|
||||
'Number of pixels in the candidate mask differing from that of'
|
||||
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||
),
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
output_type=_OutputType.CATEGORY_MASK,
|
||||
output_category_mask=True,
|
||||
output_confidence_masks=False,
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
result_callback=check_result)
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
for timestamp in range(0, 300, 30):
|
||||
segmenter.segment_async(self.test_image, timestamp)
|
||||
|
||||
def test_segment_async_calls_in_confidence_mask_mode(self):
|
||||
# Load the cat image.
|
||||
test_image = _Image.create_from_file(
|
||||
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||
)
|
||||
|
||||
# Loads ground truth segmentation file.
|
||||
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||
observed_timestamp_ms = -1
|
||||
|
||||
def check_result(
|
||||
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
|
||||
):
|
||||
# Get the output category mask.
|
||||
confidence_masks = result.confidence_masks
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
confidence_masks,
|
||||
21,
|
||||
'Number of confidence masks must match with number of categories.',
|
||||
)
|
||||
self.assertEqual(output_image.width, test_image.width)
|
||||
self.assertEqual(output_image.height, test_image.height)
|
||||
self.assertTrue(
|
||||
_similar_to_float_mask(
|
||||
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||
)
|
||||
)
|
||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||
self.observed_timestamp_ms = timestamp_ms
|
||||
|
||||
options = _ImageSegmenterOptions(
|
||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
result_callback=check_result,
|
||||
)
|
||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||
for timestamp in range(0, 300, 30):
|
||||
segmenter.segment_async(test_image, timestamp)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils
|
|||
from mediapipe.tasks.python.vision import interactive_segmenter
|
||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
|
||||
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_Image = image_module.Image
|
||||
_ImageFormat = image_frame.ImageFormat
|
||||
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint
|
||||
_Rect = rect.Rect
|
||||
_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType
|
||||
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
|
||||
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
|
||||
_RegionOfInterest = interactive_segmenter.RegionOfInterest
|
||||
|
@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
|||
raise ValueError('model_file_type is invalid.')
|
||||
|
||||
options = _InteractiveSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK
|
||||
base_options=base_options,
|
||||
output_category_mask=True,
|
||||
output_confidence_masks=False,
|
||||
)
|
||||
segmenter = _InteractiveSegmenter.create_from_options(options)
|
||||
|
||||
# Performs image segmentation on the input.
|
||||
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
|
||||
category_masks = segmenter.segment(self.test_image, roi)
|
||||
self.assertLen(category_masks, 1)
|
||||
category_mask = category_masks[0]
|
||||
segmentation_result = segmenter.segment(self.test_image, roi)
|
||||
category_mask = segmentation_result.category_mask
|
||||
result_pixels = category_mask.numpy_view().flatten()
|
||||
|
||||
# Check if data type of `category_mask` is correct.
|
||||
|
@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
|||
|
||||
self.assertTrue(
|
||||
_similar_to_uint8_mask(
|
||||
category_masks[0], test_seg_image, similarity_threshold
|
||||
category_mask, test_seg_image, similarity_threshold
|
||||
),
|
||||
(
|
||||
'Number of pixels in the candidate mask differing from that of the'
|
||||
|
@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
|||
|
||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||
options = _InteractiveSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
|
||||
base_options=base_options,
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
)
|
||||
|
||||
with _InteractiveSegmenter.create_from_options(options) as segmenter:
|
||||
# Perform segmentation
|
||||
confidence_masks = segmenter.segment(self.test_image, roi)
|
||||
segmentation_result = segmenter.segment(self.test_image, roi)
|
||||
confidence_masks = segmentation_result.confidence_masks
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
|
@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
|||
|
||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||
options = _InteractiveSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
|
||||
base_options=base_options,
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
)
|
||||
|
||||
with _InteractiveSegmenter.create_from_options(options) as segmenter:
|
||||
# Perform segmentation
|
||||
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
|
||||
confidence_masks = segmenter.segment(
|
||||
segmentation_result = segmenter.segment(
|
||||
self.test_image, roi, image_processing_options
|
||||
)
|
||||
confidence_masks = segmentation_result.confidence_masks
|
||||
|
||||
# Check if confidence mask shape is correct.
|
||||
self.assertLen(
|
||||
|
@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
|||
|
||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||
options = _InteractiveSegmenterOptions(
|
||||
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
|
||||
base_options=base_options,
|
||||
output_category_mask=False,
|
||||
output_confidence_masks=True,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
|
|
|
@ -32,6 +32,7 @@ FaceDetectorResult = face_detector.FaceDetectorResult
|
|||
FaceLandmarker = face_landmarker.FaceLandmarker
|
||||
FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions
|
||||
FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult
|
||||
FaceLandmarksConnections = face_landmarker.FaceLandmarksConnections
|
||||
FaceStylizer = face_stylizer.FaceStylizer
|
||||
FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
||||
GestureRecognizer = gesture_recognizer.GestureRecognizer
|
||||
|
|
|
@ -208,6 +208,11 @@ class BaseVisionTaskApi(object):
|
|||
"""
|
||||
self._runner.close()
|
||||
|
||||
def get_graph_config(self) -> calculator_pb2.CalculatorGraphConfig:
|
||||
"""Returns the canonicalized CalculatorGraphConfig of the underlying graph.
|
||||
"""
|
||||
return self._runner.get_graph_config()
|
||||
|
||||
def __enter__(self):
|
||||
"""Return `self` upon entering the runtime context."""
|
||||
return self
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -176,16 +176,13 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
Only use this method when the FaceStylizer is created with the image
|
||||
running mode.
|
||||
|
||||
To ensure that the output image has reasonable quality, the stylized output
|
||||
image size is the smaller of the model output size and the size of the
|
||||
`region_of_interest` specified in `image_processing_options`.
|
||||
|
||||
Args:
|
||||
image: MediaPipe Image.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Returns:
|
||||
The stylized image of the most visible face. None if no face is detected
|
||||
The stylized image of the most visible face. The stylized output image
|
||||
size is the same as the model output size. None if no face is detected
|
||||
on the input image.
|
||||
|
||||
Raises:
|
||||
|
@ -217,17 +214,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
milliseconds) along with the video frame. The input timestamps should be
|
||||
monotonically increasing for adjacent calls of this method.
|
||||
|
||||
To ensure that the output image has reasonable quality, the stylized output
|
||||
image size is the smaller of the model output size and the size of the
|
||||
`region_of_interest` specified in `image_processing_options`.
|
||||
|
||||
Args:
|
||||
image: MediaPipe Image.
|
||||
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||
image_processing_options: Options for image processing.
|
||||
|
||||
Returns:
|
||||
The stylized image of the most visible face. None if no face is detected
|
||||
The stylized image of the most visible face. The stylized output image
|
||||
size is the same as the model output size. None if no face is detected
|
||||
on the input image.
|
||||
|
||||
Raises:
|
||||
|
@ -266,12 +260,9 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
|||
images if needed. In other words, it's not guaranteed to have output per
|
||||
input image.
|
||||
|
||||
To ensure that the stylized image has reasonable quality, the stylized
|
||||
output image size is the smaller of the model output size and the size of
|
||||
the `region_of_interest` specified in `image_processing_options`.
|
||||
|
||||
The `result_callback` provides:
|
||||
- The stylized image of the most visible face. None if no face is detected
|
||||
- The stylized image of the most visible face. The stylized output image
|
||||
size is the same as the model output size. None if no face is detected
|
||||
on the input image.
|
||||
- The input image that the face stylizer runs on.
|
||||
- The input timestamp in milliseconds.
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
"""MediaPipe image segmenter task."""
|
||||
|
||||
import dataclasses
|
||||
import enum
|
||||
from typing import Callable, List, Mapping, Optional
|
||||
|
||||
from mediapipe.python import packet_creator
|
||||
|
@ -31,7 +30,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
|||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||
|
||||
ImageSegmenterResult = List[image_module.Image]
|
||||
_NormalizedRect = rect.NormalizedRect
|
||||
_BaseOptions = base_options_module.BaseOptions
|
||||
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
|
||||
|
@ -42,8 +40,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
|||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
|
||||
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
|
||||
_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
|
||||
_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
|
||||
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
|
||||
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_IMAGE_TAG = 'IMAGE'
|
||||
|
@ -53,6 +53,21 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
|
|||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageSegmenterResult:
|
||||
"""Output result of ImageSegmenter.
|
||||
|
||||
confidence_masks: multiple masks of float image where, for each mask, each
|
||||
pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||
|
||||
category_mask: a category mask of uint8 image where each pixel represents the
|
||||
class which the pixel in the original image was predicted to belong to.
|
||||
"""
|
||||
|
||||
confidence_masks: Optional[List[image_module.Image]] = None
|
||||
category_mask: Optional[image_module.Image] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImageSegmenterOptions:
|
||||
"""Options for the image segmenter task.
|
||||
|
@ -64,28 +79,17 @@ class ImageSegmenterOptions:
|
|||
objects on single image inputs. 2) The video mode for segmenting objects
|
||||
on the decoded frames of a video. 3) The live stream mode for segmenting
|
||||
objects on a live stream of input data, such as from camera.
|
||||
output_type: The output mask type allows specifying the type of
|
||||
post-processing to perform on the raw model results.
|
||||
activation: Activation function to apply to input tensor.
|
||||
output_confidence_masks: Whether to output confidence masks.
|
||||
output_category_mask: Whether to output category mask.
|
||||
result_callback: The user-defined result callback for processing live stream
|
||||
data. The result callback should only be specified when the running mode
|
||||
is set to the live stream mode.
|
||||
"""
|
||||
|
||||
class OutputType(enum.Enum):
|
||||
UNSPECIFIED = 0
|
||||
CATEGORY_MASK = 1
|
||||
CONFIDENCE_MASK = 2
|
||||
|
||||
class Activation(enum.Enum):
|
||||
NONE = 0
|
||||
SIGMOID = 1
|
||||
SOFTMAX = 2
|
||||
|
||||
base_options: _BaseOptions
|
||||
running_mode: _RunningMode = _RunningMode.IMAGE
|
||||
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
|
||||
activation: Optional[Activation] = Activation.NONE
|
||||
output_confidence_masks: bool = True
|
||||
output_category_mask: bool = False
|
||||
result_callback: Optional[
|
||||
Callable[[ImageSegmenterResult, image_module.Image, int], None]
|
||||
] = None
|
||||
|
@ -97,9 +101,7 @@ class ImageSegmenterOptions:
|
|||
base_options_proto.use_stream_mode = (
|
||||
False if self.running_mode == _RunningMode.IMAGE else True
|
||||
)
|
||||
segmenter_options_proto = _SegmenterOptionsProto(
|
||||
output_type=self.output_type.value, activation=self.activation.value
|
||||
)
|
||||
segmenter_options_proto = _SegmenterOptionsProto()
|
||||
return _ImageSegmenterGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
segmenter_options=segmenter_options_proto,
|
||||
|
@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
def packets_callback(output_packets: Mapping[str, packet.Packet]):
|
||||
if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
|
||||
return
|
||||
segmentation_result = packet_getter.get_image_list(
|
||||
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
|
||||
)
|
||||
|
||||
segmentation_result = ImageSegmenterResult()
|
||||
|
||||
if options.output_confidence_masks:
|
||||
segmentation_result.confidence_masks = packet_getter.get_image_list(
|
||||
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
|
||||
)
|
||||
|
||||
if options.output_category_mask:
|
||||
segmentation_result.category_mask = packet_getter.get_image(
|
||||
output_packets[_CATEGORY_MASK_STREAM_NAME]
|
||||
)
|
||||
|
||||
image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
|
||||
timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
|
||||
timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
|
||||
options.result_callback(
|
||||
segmentation_result,
|
||||
image,
|
||||
timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
|
||||
)
|
||||
|
||||
output_streams = [
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
|
||||
]
|
||||
|
||||
if options.output_confidence_masks:
|
||||
output_streams.append(
|
||||
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
|
||||
)
|
||||
|
||||
if options.output_category_mask:
|
||||
output_streams.append(
|
||||
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
|
||||
)
|
||||
|
||||
task_info = _TaskInfo(
|
||||
task_graph=_TASK_GRAPH_NAME,
|
||||
input_streams=[
|
||||
':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
|
||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=[
|
||||
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=output_streams,
|
||||
task_options=options,
|
||||
)
|
||||
return cls(
|
||||
|
@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
normalized_rect.to_pb2()
|
||||
),
|
||||
})
|
||||
segmentation_result = packet_getter.get_image_list(
|
||||
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
|
||||
)
|
||||
segmentation_result = ImageSegmenterResult()
|
||||
|
||||
if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
|
||||
segmentation_result.confidence_masks = packet_getter.get_image_list(
|
||||
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
|
||||
)
|
||||
|
||||
if _CATEGORY_MASK_STREAM_NAME in output_packets:
|
||||
segmentation_result.category_mask = packet_getter.get_image(
|
||||
output_packets[_CATEGORY_MASK_STREAM_NAME]
|
||||
)
|
||||
|
||||
return segmentation_result
|
||||
|
||||
def segment_for_video(
|
||||
|
@ -285,9 +317,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
normalized_rect.to_pb2()
|
||||
).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
|
||||
})
|
||||
segmentation_result = packet_getter.get_image_list(
|
||||
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
|
||||
)
|
||||
segmentation_result = ImageSegmenterResult()
|
||||
|
||||
if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
|
||||
segmentation_result.confidence_masks = packet_getter.get_image_list(
|
||||
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
|
||||
)
|
||||
|
||||
if _CATEGORY_MASK_STREAM_NAME in output_packets:
|
||||
segmentation_result.category_mask = packet_getter.get_image(
|
||||
output_packets[_CATEGORY_MASK_STREAM_NAME]
|
||||
)
|
||||
|
||||
return segmentation_result
|
||||
|
||||
def segment_async(
|
||||
|
|
|
@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
|||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||
_TaskInfo = task_info_module.TaskInfo
|
||||
|
||||
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
|
||||
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
|
||||
_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
|
||||
_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
|
||||
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
|
||||
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
|
||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||
_ROI_STREAM_NAME = 'roi_in'
|
||||
|
@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = (
|
|||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InteractiveSegmenterResult:
|
||||
"""Output result of InteractiveSegmenter.
|
||||
|
||||
confidence_masks: multiple masks of float image where, for each mask, each
|
||||
pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||
|
||||
category_mask: a category mask of uint8 image where each pixel represents the
|
||||
class which the pixel in the original image was predicted to belong to.
|
||||
"""
|
||||
|
||||
confidence_masks: Optional[List[image_module.Image]] = None
|
||||
category_mask: Optional[image_module.Image] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class InteractiveSegmenterOptions:
|
||||
"""Options for the interactive segmenter task.
|
||||
|
||||
Attributes:
|
||||
base_options: Base options for the interactive segmenter task.
|
||||
output_type: The output mask type allows specifying the type of
|
||||
post-processing to perform on the raw model results.
|
||||
output_confidence_masks: Whether to output confidence masks.
|
||||
output_category_mask: Whether to output category mask.
|
||||
"""
|
||||
|
||||
class OutputType(enum.Enum):
|
||||
UNSPECIFIED = 0
|
||||
CATEGORY_MASK = 1
|
||||
CONFIDENCE_MASK = 2
|
||||
|
||||
base_options: _BaseOptions
|
||||
output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
|
||||
output_confidence_masks: bool = True
|
||||
output_category_mask: bool = False
|
||||
|
||||
@doc_controls.do_not_generate_docs
|
||||
def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
|
||||
"""Generates an InteractiveSegmenterOptions protobuf object."""
|
||||
base_options_proto = self.base_options.to_pb2()
|
||||
base_options_proto.use_stream_mode = False
|
||||
segmenter_options_proto = _SegmenterOptionsProto(
|
||||
output_type=self.output_type.value
|
||||
)
|
||||
segmenter_options_proto = _SegmenterOptionsProto()
|
||||
return _ImageSegmenterGraphOptionsProto(
|
||||
base_options=base_options_proto,
|
||||
segmenter_options=segmenter_options_proto,
|
||||
|
@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
RuntimeError: If other types of error occurred.
|
||||
"""
|
||||
|
||||
output_streams = [
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
|
||||
]
|
||||
|
||||
if options.output_confidence_masks:
|
||||
output_streams.append(
|
||||
':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
|
||||
)
|
||||
|
||||
if options.output_category_mask:
|
||||
output_streams.append(
|
||||
':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
|
||||
)
|
||||
|
||||
task_info = _TaskInfo(
|
||||
task_graph=_TASK_GRAPH_NAME,
|
||||
input_streams=[
|
||||
|
@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
':'.join([_ROI_TAG, _ROI_STREAM_NAME]),
|
||||
':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=[
|
||||
':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
|
||||
':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
|
||||
],
|
||||
output_streams=output_streams,
|
||||
task_options=options,
|
||||
)
|
||||
return cls(
|
||||
|
@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
image: image_module.Image,
|
||||
roi: RegionOfInterest,
|
||||
image_processing_options: Optional[_ImageProcessingOptions] = None,
|
||||
) -> List[image_module.Image]:
|
||||
) -> InteractiveSegmenterResult:
|
||||
"""Performs the actual segmentation task on the provided MediaPipe Image.
|
||||
|
||||
The image can be of any size with format RGB.
|
||||
|
@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
|
|||
normalized_rect.to_pb2()
|
||||
),
|
||||
})
|
||||
segmentation_result = packet_getter.get_image_list(
|
||||
output_packets[_SEGMENTATION_OUT_STREAM_NAME]
|
||||
)
|
||||
segmentation_result = InteractiveSegmenterResult()
|
||||
|
||||
if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
|
||||
segmentation_result.confidence_masks = packet_getter.get_image_list(
|
||||
output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
|
||||
)
|
||||
|
||||
if _CATEGORY_MASK_STREAM_NAME in output_packets:
|
||||
segmentation_result.category_mask = packet_getter.get_image(
|
||||
output_packets[_CATEGORY_MASK_STREAM_NAME]
|
||||
)
|
||||
|
||||
return segmentation_result
|
||||
|
|
|
@ -59,13 +59,12 @@ export function drawCategoryMask(
|
|||
const isFloatArray = image instanceof Float32Array;
|
||||
for (let i = 0; i < image.length; i++) {
|
||||
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
|
||||
const color = COLOR_MAP[colorIndex];
|
||||
let color = COLOR_MAP[colorIndex % COLOR_MAP.length];
|
||||
|
||||
// When we're given a confidence mask by accident, we just log and return.
|
||||
// TODO: We should fix this.
|
||||
if (!color) {
|
||||
// TODO: We should fix this.
|
||||
console.warn('No color for ', colorIndex);
|
||||
return;
|
||||
color = COLOR_MAP[colorIndex % COLOR_MAP.length];
|
||||
}
|
||||
|
||||
rgbaArray[4 * i] = color[0];
|
||||
|
|
11
mediapipe/tasks/web/vision/core/types.d.ts
vendored
11
mediapipe/tasks/web/vision/core/types.d.ts
vendored
|
@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
|
|||
*/
|
||||
export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;
|
||||
|
||||
/**
|
||||
* A callback that receives the computed masks from the segmentation tasks. The
|
||||
* callback either receives a single element array with a category mask (as a
|
||||
* `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`).
|
||||
* The returned data is only valid for the duration of the callback. If
|
||||
* asynchronous processing is needed, all data needs to be copied before the
|
||||
* callback returns.
|
||||
*/
|
||||
export type SegmentationMaskCallback =
|
||||
(masks: SegmentationMask[], width: number, height: number) => void;
|
||||
|
||||
/**
|
||||
* A callback that receives an `ImageData` object from a Vision task. The
|
||||
* lifetime of the underlying data is limited to the duration of the callback.
|
||||
|
|
|
@ -19,7 +19,7 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
|
|||
// tslint:disable:class-as-namespace Using for easier import by 3P users
|
||||
|
||||
/**
|
||||
* A class containing the Pairs of landmark indices to be rendered with
|
||||
* A class containing the pairs of landmark indices to be rendered with
|
||||
* connections.
|
||||
*/
|
||||
export class FaceLandmarksConnections {
|
||||
|
|
|
@ -129,10 +129,6 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
* synchronously once the callback returns. Only use this method when the
|
||||
* FaceStylizer is created with the image running mode.
|
||||
*
|
||||
* The input image can be of any size. To ensure that the output image has
|
||||
* reasonable quality, the stylized output image size is determined by the
|
||||
* model output size.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param callback The callback that is invoked with the stylized image. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
|
@ -153,11 +149,6 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
* If both are specified, the crop around the region-of-interest is extracted
|
||||
* first, then the specified rotation is applied to the crop.
|
||||
*
|
||||
* The input image can be of any size. To ensure that the output image has
|
||||
* reasonable quality, the stylized output image size is the smaller of the
|
||||
* model output size and the size of the 'regionOfInterest' specified in
|
||||
* 'imageProcessingOptions'.
|
||||
*
|
||||
* @param image An image to process.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
|
@ -192,9 +183,6 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
* frame's timestamp (in milliseconds). The input timestamps must be
|
||||
* monotonically increasing.
|
||||
*
|
||||
* To ensure that the output image has reasonable quality, the stylized
|
||||
* output image size is determined by the model output size.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the stylized image. The
|
||||
|
@ -221,10 +209,6 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
* frame's timestamp (in milliseconds). The input timestamps must be
|
||||
* monotonically increasing.
|
||||
*
|
||||
* To ensure that the output image has reasonable quality, the stylized
|
||||
* output image size is the smaller of the model output size and the size of
|
||||
* the 'regionOfInterest' specified in 'imageProcessingOptions'.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
|
@ -278,8 +262,12 @@ export class FaceStylizer extends VisionTaskRunner {
|
|||
|
||||
this.graphRunner.attachImageListener(
|
||||
STYLIZED_IMAGE_STREAM, (image, timestamp) => {
|
||||
const imageData = this.convertToImageData(image);
|
||||
this.userCallback(imageData, image.width, image.height);
|
||||
if (image.data instanceof WebGLTexture) {
|
||||
this.userCallback(image.data, image.width, image.height);
|
||||
} else {
|
||||
const imageData = this.convertToImageData(image);
|
||||
this.userCallback(imageData, image.width, image.height);
|
||||
}
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
|
|
|
@ -34,6 +34,7 @@ mediapipe_ts_library(
|
|||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||
"//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections",
|
||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -31,6 +31,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/
|
|||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
|
@ -72,6 +73,12 @@ export class GestureRecognizer extends VisionTaskRunner {
|
|||
private readonly handGestureRecognizerGraphOptions:
|
||||
HandGestureRecognizerGraphOptions;
|
||||
|
||||
/**
|
||||
* An array containing the pairs of hand landmark indices to be rendered with
|
||||
* connections.
|
||||
*/
|
||||
static HAND_CONNECTIONS = HAND_CONNECTIONS;
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new gesture recognizer from the
|
||||
* provided options.
|
||||
|
|
|
@ -16,6 +16,7 @@ mediapipe_ts_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":hand_landmarker_types",
|
||||
":hand_landmarks_connections",
|
||||
"//mediapipe/framework:calculator_jspb_proto",
|
||||
"//mediapipe/framework:calculator_options_jspb_proto",
|
||||
"//mediapipe/framework/formats:classification_jspb_proto",
|
||||
|
@ -72,3 +73,9 @@ jasmine_node_test(
|
|||
tags = ["nomsan"],
|
||||
deps = [":hand_landmarker_test_lib"],
|
||||
)
|
||||
|
||||
mediapipe_ts_library(
|
||||
name = "hand_landmarks_connections",
|
||||
srcs = ["hand_landmarks_connections.ts"],
|
||||
deps = ["//mediapipe/tasks/web/vision/core:types"],
|
||||
)
|
||||
|
|
|
@ -27,6 +27,7 @@ import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/con
|
|||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
|
@ -63,6 +64,12 @@ export class HandLandmarker extends VisionTaskRunner {
|
|||
HandLandmarksDetectorGraphOptions;
|
||||
private readonly handDetectorGraphOptions: HandDetectorGraphOptions;
|
||||
|
||||
/**
|
||||
* An array containing the pairs of hand landmark indices to be rendered with
|
||||
* connections.
|
||||
*/
|
||||
static HAND_CONNECTIONS = HAND_CONNECTIONS;
|
||||
|
||||
/**
|
||||
* Initializes the Wasm runtime and creates a new `HandLandmarker` from the
|
||||
* provided options.
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import {Connection} from '../../../../tasks/web/vision/core/types';
|
||||
|
||||
/**
|
||||
* An array containing the pairs of hand landmark indices to be rendered with
|
||||
* connections.
|
||||
*/
|
||||
export const HAND_CONNECTIONS: Connection[] = [
|
||||
{start: 0, end: 1}, {start: 1, end: 2}, {start: 2, end: 3},
|
||||
{start: 3, end: 4}, {start: 0, end: 5}, {start: 5, end: 6},
|
||||
{start: 6, end: 7}, {start: 7, end: 8}, {start: 5, end: 9},
|
||||
{start: 9, end: 10}, {start: 10, end: 11}, {start: 11, end: 12},
|
||||
{start: 9, end: 13}, {start: 13, end: 14}, {start: 14, end: 15},
|
||||
{start: 15, end: 16}, {start: 13, end: 17}, {start: 0, end: 17},
|
||||
{start: 17, end: 18}, {start: 18, end: 19}, {start: 19, end: 20}
|
||||
];
|
|
@ -29,7 +29,10 @@ mediapipe_ts_library(
|
|||
|
||||
mediapipe_ts_declaration(
|
||||
name = "image_segmenter_types",
|
||||
srcs = ["image_segmenter_options.d.ts"],
|
||||
srcs = [
|
||||
"image_segmenter_options.d.ts",
|
||||
"image_segmenter_result.d.ts",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
|
|
|
@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
|||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||
import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {LabelMapItem} from '../../../../util/label_map_pb';
|
||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||
import {ImageSegmenterResult} from './image_segmenter_result';
|
||||
|
||||
export * from './image_segmenter_options';
|
||||
export {SegmentationMask, SegmentationMaskCallback};
|
||||
export * from './image_segmenter_result';
|
||||
export {SegmentationMask};
|
||||
export {ImageSource}; // Used in the public API
|
||||
|
||||
const IMAGE_STREAM = 'image_in';
|
||||
const NORM_RECT_STREAM = 'norm_rect';
|
||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const IMAGE_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
|
||||
const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||
'mediapipe.tasks.TensorsToSegmentationCalculator';
|
||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* A callback that receives the computed masks from the image segmenter. The
|
||||
* returned data is only valid for the duration of the callback. If
|
||||
* asynchronous processing is needed, all data needs to be copied before the
|
||||
* callback returns.
|
||||
*/
|
||||
export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
|
||||
|
||||
/** Performs image segmentation on images. */
|
||||
export class ImageSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private result: ImageSegmenterResult = {width: 0, height: 0};
|
||||
private labels: string[] = [];
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
|
@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
this.options.setBaseOptions(new BaseOptionsProto());
|
||||
}
|
||||
|
||||
|
||||
protected override get baseOptions(): BaseOptionsProto {
|
||||
return this.options.getBaseOptions()!;
|
||||
}
|
||||
|
@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
this.options.clearDisplayNamesLocale();
|
||||
}
|
||||
|
||||
if (options.outputType === 'CONFIDENCE_MASK') {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
||||
} else {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
||||
if ('outputCategoryMask' in options) {
|
||||
this.outputCategoryMask =
|
||||
options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
}
|
||||
|
||||
if ('outputConfidenceMasks' in options) {
|
||||
this.outputConfidenceMasks =
|
||||
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
}
|
||||
|
||||
return super.applyOptions(options);
|
||||
|
@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segment(image: ImageSource, callback: SegmentationMaskCallback): void;
|
||||
segment(image: ImageSource, callback: ImageSegmenterCallack): void;
|
||||
/**
|
||||
* Performs image segmentation on the provided single image and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
|
@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
*/
|
||||
segment(
|
||||
image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
callback: ImageSegmenterCallack): void;
|
||||
segment(
|
||||
image: ImageSource,
|
||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||
SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
ImageSegmenterCallack,
|
||||
callback?: ImageSegmenterCallack): void {
|
||||
const imageProcessingOptions =
|
||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
const userCallback =
|
||||
typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
|
||||
this.reset();
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
this.userCallback = () => {};
|
||||
userCallback(this.result);
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource, timestamp: number,
|
||||
callback: ImageSegmenterCallack): void;
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||
timestamp: number, callback: ImageSegmenterCallack): void;
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource,
|
||||
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
|
||||
timestampOrCallback: number|ImageSegmenterCallack,
|
||||
callback?: ImageSegmenterCallack): void {
|
||||
const imageProcessingOptions =
|
||||
typeof timestampOrImageProcessingOptions !== 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
{};
|
||||
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
timestampOrCallback as number;
|
||||
const userCallback = typeof timestampOrCallback === 'function' ?
|
||||
timestampOrCallback :
|
||||
callback!;
|
||||
|
||||
this.reset();
|
||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||
userCallback(this.result);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
return this.labels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource, timestamp: number,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
/**
|
||||
* Performs image segmentation on the provided video frame and invokes the
|
||||
* callback with the response. The method returns synchronously once the
|
||||
* callback returns. Only use this method when the ImageSegmenter is
|
||||
* created with running mode `video`.
|
||||
*
|
||||
* @param videoFrame A video frame to process.
|
||||
* @param imageProcessingOptions the `ImageProcessingOptions` specifying how
|
||||
* to process the input image before running inference.
|
||||
* @param timestamp The timestamp of the current frame, in ms.
|
||||
* @param callback The callback that is invoked with the segmented masks. The
|
||||
* lifetime of the returned data is only guaranteed for the duration of the
|
||||
* callback.
|
||||
*/
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
|
||||
timestamp: number, callback: SegmentationMaskCallback): void;
|
||||
segmentForVideo(
|
||||
videoFrame: ImageSource,
|
||||
timestampOrImageProcessingOptions: number|ImageProcessingOptions,
|
||||
timestampOrCallback: number|SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
const imageProcessingOptions =
|
||||
typeof timestampOrImageProcessingOptions !== 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
{};
|
||||
const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
|
||||
timestampOrImageProcessingOptions :
|
||||
timestampOrCallback as number;
|
||||
|
||||
this.userCallback = typeof timestampOrCallback === 'function' ?
|
||||
timestampOrCallback :
|
||||
callback!;
|
||||
this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
|
||||
this.userCallback = () => {};
|
||||
private reset(): void {
|
||||
this.result = {width: 0, height: 0};
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
|
@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
const graphConfig = new CalculatorGraphConfig();
|
||||
graphConfig.addInputStream(IMAGE_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_STREAM);
|
||||
graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
|
@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
|
|||
segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
|
||||
segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
|
||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
|
||||
segmenterNode.addOutputStream(
|
||||
'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
|
||||
segmenterNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(segmenterNode);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
|
||||
if (masks.length === 0) {
|
||||
this.userCallback([], 0, 0);
|
||||
} else {
|
||||
this.userCallback(
|
||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||
}
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
GROUPED_SEGMENTATIONS_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
if (this.outputConfidenceMasks) {
|
||||
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
||||
segmenterNode.addOutputStream(
|
||||
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||
this.result.confidenceMasks = masks.map(m => m.data);
|
||||
if (masks.length >= 0) {
|
||||
this.result.width = masks[0].width;
|
||||
this.result.height = masks[0].height;
|
||||
}
|
||||
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
if (this.outputCategoryMask) {
|
||||
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
||||
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
||||
|
||||
this.graphRunner.attachImageListener(
|
||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||
this.result.categoryMask = mask.data;
|
||||
this.result.width = mask.width;
|
||||
this.result.height = mask.height;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
CATEGORY_MASK_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
|
|
|
@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
|
|||
*/
|
||||
displayNamesLocale?: string|undefined;
|
||||
|
||||
/**
|
||||
* The output type of segmentation results.
|
||||
*
|
||||
* The two supported modes are:
|
||||
* - Category Mask: Gives a single output mask where each pixel represents
|
||||
* the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
||||
* each mask, the pixel represents the prediction
|
||||
* confidence, usually in the [0.0, 0.1] range.
|
||||
*
|
||||
* Defaults to `CATEGORY_MASK`.
|
||||
*/
|
||||
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
||||
/** Whether to output confidence masks. Defaults to true. */
|
||||
outputConfidenceMasks?: boolean|undefined;
|
||||
|
||||
/** Whether to output the category masks. Defaults to false. */
|
||||
outputCategoryMask?: boolean|undefined;
|
||||
}
|
||||
|
|
37
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
vendored
Normal file
37
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/** The output result of ImageSegmenter. */
|
||||
export declare interface ImageSegmenterResult {
|
||||
/**
|
||||
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
|
||||
* pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||
*/
|
||||
confidenceMasks?: Float32Array[]|WebGLTexture[];
|
||||
|
||||
/**
|
||||
* A category mask as a Uint8ClampedArray or WebGLTexture where each
|
||||
* pixel represents the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
*/
|
||||
categoryMask?: Uint8ClampedArray|WebGLTexture;
|
||||
|
||||
/** The width of the masks. */
|
||||
width: number;
|
||||
|
||||
/** The height of the masks. */
|
||||
height: number;
|
||||
}
|
|
@ -18,7 +18,7 @@ import 'jasmine';
|
|||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||
|
||||
import {ImageSegmenter} from './image_segmenter';
|
||||
|
@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
graph: CalculatorGraphConfig|undefined;
|
||||
|
||||
fakeWasmModule: SpyWasmModule;
|
||||
imageVectorListener:
|
||||
categoryMaskListener:
|
||||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
|
||||
constructor() {
|
||||
|
@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
|||
this.fakeWasmModule =
|
||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||
|
||||
this.attachListenerSpies[0] =
|
||||
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('category_mask');
|
||||
this.categoryMaskListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[1] =
|
||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('segmented_masks');
|
||||
this.imageVectorListener = listener;
|
||||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
|
@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
|
|||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(imageSegmenter);
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
|
||||
// Verify default options
|
||||
expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
|
||||
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
|
||||
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
});
|
||||
|
||||
it('can use custom models', async () => {
|
||||
|
@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
|
|||
});
|
||||
|
||||
it('merges options', async () => {
|
||||
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||
await imageSegmenter.setOptions(
|
||||
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||
verifyGraph(
|
||||
imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
|
||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||
});
|
||||
|
||||
|
@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
|
|||
defaultValue: unknown;
|
||||
}
|
||||
|
||||
const testCases: TestCase[] = [
|
||||
{
|
||||
optionName: 'displayNamesLocale',
|
||||
fieldPath: ['displayNamesLocale'],
|
||||
userValue: 'en',
|
||||
graphValue: 'en',
|
||||
defaultValue: 'en'
|
||||
},
|
||||
{
|
||||
optionName: 'outputType',
|
||||
fieldPath: ['segmenterOptions', 'outputType'],
|
||||
userValue: 'CONFIDENCE_MASK',
|
||||
graphValue: 2,
|
||||
defaultValue: 1
|
||||
},
|
||||
];
|
||||
const testCases: TestCase[] = [{
|
||||
optionName: 'displayNamesLocale',
|
||||
fieldPath: ['displayNamesLocale'],
|
||||
userValue: 'en',
|
||||
graphValue: 'en',
|
||||
defaultValue: 'en'
|
||||
}];
|
||||
|
||||
for (const testCase of testCases) {
|
||||
it(`can set ${testCase.optionName}`, async () => {
|
||||
|
@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
|
|||
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||
});
|
||||
|
||||
it('supports category masks', (done) => {
|
||||
it('supports category mask', async () => {
|
||||
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
||||
|
||||
await imageSegmenter.setOptions(
|
||||
{outputCategoryMask: true, outputConfidenceMasks: false});
|
||||
|
||||
// Pass the test data to our listener
|
||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
imageSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask, width: 2, height: 2},
|
||||
],
|
||||
/* timestamp= */ 1337);
|
||||
expect(imageSegmenter.categoryMaskListener).toBeDefined();
|
||||
imageSegmenter.categoryMaskListener!
|
||||
({data: mask, width: 2, height: 2},
|
||||
/* timestamp= */ 1337);
|
||||
});
|
||||
|
||||
// Invoke the image segmenter
|
||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(1);
|
||||
expect(masks[0]).toEqual(mask);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
done();
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(result.categoryMask).toEqual(mask);
|
||||
expect(result.confidenceMasks).not.toBeDefined();
|
||||
expect(result.width).toEqual(2);
|
||||
expect(result.height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
|
|||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||
|
||||
await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
await imageSegmenter.setOptions(
|
||||
{outputCategoryMask: false, outputConfidenceMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(imageSegmenter);
|
||||
imageSegmenter.imageVectorListener!(
|
||||
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||
imageSegmenter.confidenceMasksListener!(
|
||||
[
|
||||
{data: mask1, width: 2, height: 2},
|
||||
{data: mask2, width: 2, height: 2},
|
||||
|
@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {
|
|||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(2);
|
||||
expect(masks[0]).toEqual(mask1);
|
||||
expect(masks[1]).toEqual(mask2);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
expect(result.categoryMask).not.toBeDefined();
|
||||
expect(result.confidenceMasks).toEqual([mask1, mask2]);
|
||||
expect(result.width).toEqual(2);
|
||||
expect(result.height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('supports combined category and confidence masks', async () => {
|
||||
const categoryMask = new Uint8ClampedArray([1, 0]);
|
||||
const confidenceMask1 = new Float32Array([0.0, 1.0]);
|
||||
const confidenceMask2 = new Float32Array([1.0, 0.0]);
|
||||
|
||||
await imageSegmenter.setOptions(
|
||||
{outputCategoryMask: true, outputConfidenceMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
expect(imageSegmenter.categoryMaskListener).toBeDefined();
|
||||
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||
imageSegmenter.categoryMaskListener!
|
||||
({data: categoryMask, width: 1, height: 1}, 1337);
|
||||
imageSegmenter.confidenceMasksListener!(
|
||||
[
|
||||
{data: confidenceMask1, width: 1, height: 1},
|
||||
{data: confidenceMask2, width: 1, height: 1},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||
expect(result.categoryMask).toEqual(categoryMask);
|
||||
expect(result.confidenceMasks).toEqual([
|
||||
confidenceMask1, confidenceMask2
|
||||
]);
|
||||
expect(result.width).toEqual(1);
|
||||
expect(result.height).toEqual(1);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -30,7 +30,10 @@ mediapipe_ts_library(
|
|||
|
||||
mediapipe_ts_declaration(
|
||||
name = "interactive_segmenter_types",
|
||||
srcs = ["interactive_segmenter_options.d.ts"],
|
||||
srcs = [
|
||||
"interactive_segmenter_options.d.ts",
|
||||
"interactive_segmenter_result.d.ts",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/tasks/web/core",
|
||||
"//mediapipe/tasks/web/core:classifier_options",
|
||||
|
|
|
@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
|||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||
import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||
import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types';
|
||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||
import {Color as ColorProto} from '../../../../util/color_pb';
|
||||
import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
|
@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
|
|||
// Placeholder for internal dependency on trusted resource url
|
||||
|
||||
import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
|
||||
import {InteractiveSegmenterResult} from './interactive_segmenter_result';
|
||||
|
||||
export * from './interactive_segmenter_options';
|
||||
export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
|
||||
export * from './interactive_segmenter_result';
|
||||
export {SegmentationMask, RegionOfInterest};
|
||||
export {ImageSource};
|
||||
|
||||
const IMAGE_IN_STREAM = 'image_in';
|
||||
const NORM_RECT_IN_STREAM = 'norm_rect_in';
|
||||
const ROI_IN_STREAM = 'roi_in';
|
||||
const IMAGE_OUT_STREAM = 'image_out';
|
||||
const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
|
||||
const CATEGORY_MASK_STREAM = 'category_mask';
|
||||
const IMAGEA_SEGMENTER_GRAPH =
|
||||
'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
|
||||
const DEFAULT_OUTPUT_CATEGORY_MASK = false;
|
||||
const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;
|
||||
|
||||
// The OSS JS API does not support the builder pattern.
|
||||
// tslint:disable:jspb-use-builder-pattern
|
||||
|
||||
/**
|
||||
* A callback that receives the computed masks from the interactive segmenter.
|
||||
* The returned data is only valid for the duration of the callback. If
|
||||
* asynchronous processing is needed, all data needs to be copied before the
|
||||
* callback returns.
|
||||
*/
|
||||
export type InteractiveSegmenterCallack =
|
||||
(result: InteractiveSegmenterResult) => void;
|
||||
|
||||
/**
|
||||
* Performs interactive segmentation on images.
|
||||
*
|
||||
|
@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH =
|
|||
* - batch is always 1
|
||||
*/
|
||||
export class InteractiveSegmenter extends VisionTaskRunner {
|
||||
private userCallback: SegmentationMaskCallback = () => {};
|
||||
private result: InteractiveSegmenterResult = {width: 0, height: 0};
|
||||
private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
private readonly options: ImageSegmenterGraphOptionsProto;
|
||||
private readonly segmenterOptions: SegmenterOptionsProto;
|
||||
|
||||
|
@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
* @return A Promise that resolves when the settings have been applied.
|
||||
*/
|
||||
override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
|
||||
if (options.outputType === 'CONFIDENCE_MASK') {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
|
||||
} else {
|
||||
this.segmenterOptions.setOutputType(
|
||||
SegmenterOptionsProto.OutputType.CATEGORY_MASK);
|
||||
if ('outputCategoryMask' in options) {
|
||||
this.outputCategoryMask =
|
||||
options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
|
||||
}
|
||||
|
||||
if ('outputConfidenceMasks' in options) {
|
||||
this.outputConfidenceMasks =
|
||||
options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
|
||||
}
|
||||
|
||||
return super.applyOptions(options);
|
||||
|
@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
*/
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
callback: InteractiveSegmenterCallack): void;
|
||||
/**
|
||||
* Performs interactive segmentation on the provided single image and invokes
|
||||
* the callback with the response. The `roi` parameter is used to represent a
|
||||
|
@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptions: ImageProcessingOptions,
|
||||
callback: SegmentationMaskCallback): void;
|
||||
callback: InteractiveSegmenterCallack): void;
|
||||
segment(
|
||||
image: ImageSource, roi: RegionOfInterest,
|
||||
imageProcessingOptionsOrCallback: ImageProcessingOptions|
|
||||
SegmentationMaskCallback,
|
||||
callback?: SegmentationMaskCallback): void {
|
||||
InteractiveSegmenterCallack,
|
||||
callback?: InteractiveSegmenterCallack): void {
|
||||
const imageProcessingOptions =
|
||||
typeof imageProcessingOptionsOrCallback !== 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
{};
|
||||
|
||||
this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
const userCallback =
|
||||
typeof imageProcessingOptionsOrCallback === 'function' ?
|
||||
imageProcessingOptionsOrCallback :
|
||||
callback!;
|
||||
|
||||
this.reset();
|
||||
this.processRenderData(roi, this.getSynctheticTimestamp());
|
||||
this.processImageData(image, imageProcessingOptions);
|
||||
this.userCallback = () => {};
|
||||
userCallback(this.result);
|
||||
}
|
||||
|
||||
private reset(): void {
|
||||
this.result = {width: 0, height: 0};
|
||||
}
|
||||
|
||||
/** Updates the MediaPipe graph configuration. */
|
||||
|
@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
graphConfig.addInputStream(IMAGE_IN_STREAM);
|
||||
graphConfig.addInputStream(ROI_IN_STREAM);
|
||||
graphConfig.addInputStream(NORM_RECT_IN_STREAM);
|
||||
graphConfig.addOutputStream(IMAGE_OUT_STREAM);
|
||||
|
||||
const calculatorOptions = new CalculatorOptions();
|
||||
calculatorOptions.setExtension(
|
||||
|
@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner {
|
|||
segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
|
||||
segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
|
||||
segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
|
||||
segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
|
||||
segmenterNode.setOptions(calculatorOptions);
|
||||
|
||||
graphConfig.addNode(segmenterNode);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
IMAGE_OUT_STREAM, (masks, timestamp) => {
|
||||
if (masks.length === 0) {
|
||||
this.userCallback([], 0, 0);
|
||||
} else {
|
||||
this.userCallback(
|
||||
masks.map(m => m.data), masks[0].width, masks[0].height);
|
||||
}
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
if (this.outputConfidenceMasks) {
|
||||
graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
|
||||
segmenterNode.addOutputStream(
|
||||
'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
|
||||
|
||||
this.graphRunner.attachImageVectorListener(
|
||||
CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
|
||||
this.result.confidenceMasks = masks.map(m => m.data);
|
||||
if (masks.length >= 0) {
|
||||
this.result.width = masks[0].width;
|
||||
this.result.height = masks[0].height;
|
||||
}
|
||||
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
CONFIDENCE_MASKS_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
if (this.outputCategoryMask) {
|
||||
graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
|
||||
segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
|
||||
|
||||
this.graphRunner.attachImageListener(
|
||||
CATEGORY_MASK_STREAM, (mask, timestamp) => {
|
||||
this.result.categoryMask = mask.data;
|
||||
this.result.width = mask.width;
|
||||
this.result.height = mask.height;
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
this.graphRunner.attachEmptyPacketListener(
|
||||
CATEGORY_MASK_STREAM, timestamp => {
|
||||
this.setLatestOutputTimestamp(timestamp);
|
||||
});
|
||||
}
|
||||
|
||||
const binaryGraph = graphConfig.serializeBinary();
|
||||
this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
|
||||
|
|
|
@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'
|
|||
|
||||
/** Options to configure the MediaPipe Interactive Segmenter Task */
|
||||
export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
|
||||
/**
|
||||
* The output type of segmentation results.
|
||||
*
|
||||
* The two supported modes are:
|
||||
* - Category Mask: Gives a single output mask where each pixel represents
|
||||
* the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
* - Confidence Mask: Gives a list of output masks (one for each class). For
|
||||
* each mask, the pixel represents the prediction
|
||||
* confidence, usually in the [0.0, 0.1] range.
|
||||
*
|
||||
* Defaults to `CATEGORY_MASK`.
|
||||
*/
|
||||
outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
|
||||
/** Whether to output confidence masks. Defaults to true. */
|
||||
outputConfidenceMasks?: boolean|undefined;
|
||||
|
||||
/** Whether to output the category masks. Defaults to false. */
|
||||
outputCategoryMask?: boolean|undefined;
|
||||
}
|
||||
|
|
37
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts
vendored
Normal file
37
mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts
vendored
Normal file
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/** The output result of InteractiveSegmenter. */
|
||||
export declare interface InteractiveSegmenterResult {
|
||||
/**
|
||||
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
|
||||
* pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||
*/
|
||||
confidenceMasks?: Float32Array[]|WebGLTexture[];
|
||||
|
||||
/**
|
||||
* A category mask as a Uint8ClampedArray or WebGLTexture where each
|
||||
* pixel represents the class which the pixel in the original image was
|
||||
* predicted to belong to.
|
||||
*/
|
||||
categoryMask?: Uint8ClampedArray|WebGLTexture;
|
||||
|
||||
/** The width of the masks. */
|
||||
width: number;
|
||||
|
||||
/** The height of the masks. */
|
||||
height: number;
|
||||
}
|
|
@ -18,7 +18,7 @@ import 'jasmine';
|
|||
|
||||
// Placeholder for internal dependency on encodeByteArray
|
||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||
import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
|
||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||
|
||||
|
@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
graph: CalculatorGraphConfig|undefined;
|
||||
|
||||
fakeWasmModule: SpyWasmModule;
|
||||
imageVectorListener:
|
||||
categoryMaskListener:
|
||||
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||
confidenceMasksListener:
|
||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||
lastRoi?: RenderDataProto;
|
||||
|
||||
|
@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
|
|||
this.fakeWasmModule =
|
||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||
|
||||
this.attachListenerSpies[0] =
|
||||
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('category_mask');
|
||||
this.categoryMaskListener = listener;
|
||||
});
|
||||
this.attachListenerSpies[1] =
|
||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||
.and.callFake((stream, listener) => {
|
||||
expect(stream).toEqual('image_out');
|
||||
this.imageVectorListener = listener;
|
||||
expect(stream).toEqual('confidence_masks');
|
||||
this.confidenceMasksListener = listener;
|
||||
});
|
||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||
|
@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => {
|
|||
|
||||
it('initializes graph', async () => {
|
||||
verifyGraph(interactiveSegmenter);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
|
||||
// Verify default options
|
||||
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
});
|
||||
|
||||
it('reloads graph when settings are changed', async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputConfidenceMasks: true, outputCategoryMask: false});
|
||||
expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputConfidenceMasks: false, outputCategoryMask: true});
|
||||
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
|
||||
});
|
||||
|
||||
it('can use custom models', async () => {
|
||||
|
@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => {
|
|||
]);
|
||||
});
|
||||
|
||||
|
||||
describe('setOptions()', () => {
|
||||
const fieldPath = ['segmenterOptions', 'outputType'];
|
||||
|
||||
it(`can set outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
});
|
||||
|
||||
it(`can clear outputType`, async () => {
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 2]);
|
||||
await interactiveSegmenter.setOptions({outputType: undefined});
|
||||
verifyGraph(interactiveSegmenter, [fieldPath, 1]);
|
||||
});
|
||||
});
|
||||
|
||||
it('doesn\'t support region of interest', () => {
|
||||
expect(() => {
|
||||
interactiveSegmenter.segment(
|
||||
|
@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => {
|
|||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
|
||||
});
|
||||
|
||||
it('supports category masks', (done) => {
|
||||
it('supports category mask', async () => {
|
||||
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputCategoryMask: true, outputConfidenceMasks: false});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
[
|
||||
{data: mask, width: 2, height: 2},
|
||||
],
|
||||
/* timestamp= */ 1337);
|
||||
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
|
||||
interactiveSegmenter.categoryMaskListener!
|
||||
({data: mask, width: 2, height: 2},
|
||||
/* timestamp= */ 1337);
|
||||
});
|
||||
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(1);
|
||||
expect(masks[0]).toEqual(mask);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
done();
|
||||
});
|
||||
return new Promise<void>(resolve => {
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(result.categoryMask).toEqual(mask);
|
||||
expect(result.confidenceMasks).not.toBeDefined();
|
||||
expect(result.width).toEqual(2);
|
||||
expect(result.height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('supports confidence masks', async () => {
|
||||
const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
|
||||
const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);
|
||||
|
||||
await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputCategoryMask: false, outputConfidenceMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
verifyListenersRegistered(interactiveSegmenter);
|
||||
interactiveSegmenter.imageVectorListener!(
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
interactiveSegmenter.confidenceMasksListener!(
|
||||
[
|
||||
{data: mask1, width: 2, height: 2},
|
||||
{data: mask2, width: 2, height: 2},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(result.categoryMask).not.toBeDefined();
|
||||
expect(result.confidenceMasks).toEqual([mask1, mask2]);
|
||||
expect(result.width).toEqual(2);
|
||||
expect(result.height).toEqual(2);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('supports combined category and confidence masks', async () => {
|
||||
const categoryMask = new Uint8ClampedArray([1, 0]);
|
||||
const confidenceMask1 = new Float32Array([0.0, 1.0]);
|
||||
const confidenceMask2 = new Float32Array([1.0, 0.0]);
|
||||
|
||||
await interactiveSegmenter.setOptions(
|
||||
{outputCategoryMask: true, outputConfidenceMasks: true});
|
||||
|
||||
// Pass the test data to our listener
|
||||
interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||
expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
|
||||
expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
|
||||
interactiveSegmenter.categoryMaskListener!
|
||||
({data: categoryMask, width: 1, height: 1}, 1337);
|
||||
interactiveSegmenter.confidenceMasksListener!(
|
||||
[
|
||||
{data: confidenceMask1, width: 1, height: 1},
|
||||
{data: confidenceMask2, width: 1, height: 1},
|
||||
],
|
||||
1337);
|
||||
});
|
||||
|
||||
return new Promise<void>(resolve => {
|
||||
// Invoke the image segmenter
|
||||
interactiveSegmenter.segment(
|
||||
{} as HTMLImageElement, ROI, (masks, width, height) => {
|
||||
{} as HTMLImageElement, ROI, result => {
|
||||
expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
|
||||
.toHaveBeenCalled();
|
||||
expect(masks).toHaveSize(2);
|
||||
expect(masks[0]).toEqual(mask1);
|
||||
expect(masks[1]).toEqual(mask2);
|
||||
expect(width).toEqual(2);
|
||||
expect(height).toEqual(2);
|
||||
expect(result.categoryMask).toEqual(categoryMask);
|
||||
expect(result.confidenceMasks).toEqual([
|
||||
confidenceMask1, confidenceMask2
|
||||
]);
|
||||
expect(result.width).toEqual(1);
|
||||
expect(result.height).toEqual(1);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
|
|
|
@ -56,8 +56,8 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y,
|
|||
VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0";
|
||||
}
|
||||
|
||||
*x_px = static_cast<int32>(round(normalized_x * image_width));
|
||||
*y_px = static_cast<int32>(round(normalized_y * image_height));
|
||||
*x_px = static_cast<int32_t>(round(normalized_x * image_width));
|
||||
*y_px = static_cast<int32_t>(round(normalized_y * image_height));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -43,7 +43,7 @@ ABSL_FLAG(std::string, system_cpu_max_freq_file,
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
constexpr uint32 kBufferLength = 64;
|
||||
constexpr uint32_t kBufferLength = 64;
|
||||
|
||||
absl::StatusOr<std::string> GetFilePath(int cpu) {
|
||||
if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
|
||||
|
@ -54,7 +54,7 @@ absl::StatusOr<std::string> GetFilePath(int cpu) {
|
|||
return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu);
|
||||
}
|
||||
|
||||
absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
|
||||
absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
|
||||
auto path_or_status = GetFilePath(cpu);
|
||||
if (!path_or_status.ok()) {
|
||||
return path_or_status.status();
|
||||
|
@ -65,7 +65,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
|
|||
char buffer[kBufferLength];
|
||||
file.getline(buffer, kBufferLength);
|
||||
file.close();
|
||||
uint64 frequency;
|
||||
uint64_t frequency;
|
||||
if (absl::SimpleAtoi(buffer, &frequency)) {
|
||||
return frequency;
|
||||
} else {
|
||||
|
@ -79,7 +79,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
|
|||
}
|
||||
|
||||
std::set<int> InferLowerOrHigherCoreIds(bool lower) {
|
||||
std::vector<std::pair<int, uint64>> cpu_freq_pairs;
|
||||
std::vector<std::pair<int, uint64_t>> cpu_freq_pairs;
|
||||
for (int cpu = 0; cpu < NumCPUCores(); ++cpu) {
|
||||
auto freq_or_status = GetCpuMaxFrequency(cpu);
|
||||
if (freq_or_status.ok()) {
|
||||
|
@ -90,12 +90,12 @@ std::set<int> InferLowerOrHigherCoreIds(bool lower) {
|
|||
return {};
|
||||
}
|
||||
|
||||
absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64>& left,
|
||||
const std::pair<int, uint64>& right) {
|
||||
absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64_t>& left,
|
||||
const std::pair<int, uint64_t>& right) {
|
||||
return (lower && left.second < right.second) ||
|
||||
(!lower && left.second > right.second);
|
||||
});
|
||||
uint64 edge_freq = cpu_freq_pairs[0].second;
|
||||
uint64_t edge_freq = cpu_freq_pairs[0].second;
|
||||
|
||||
std::set<int> inferred_cores;
|
||||
for (const auto& cpu_freq_pair : cpu_freq_pairs) {
|
||||
|
|
|
@ -89,12 +89,12 @@ void ImageFrameToYUVImage(const ImageFrame& image_frame, YUVImage* yuv_image) {
|
|||
const int uv_stride = (uv_width + 15) & ~15;
|
||||
const int y_size = y_stride * height;
|
||||
const int uv_size = uv_stride * uv_height;
|
||||
uint8* data =
|
||||
reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size * 2, 16));
|
||||
uint8_t* data =
|
||||
reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size * 2, 16));
|
||||
std::function<void()> deallocate = [data]() { aligned_free(data); };
|
||||
uint8* y = data;
|
||||
uint8* u = y + y_size;
|
||||
uint8* v = u + uv_size;
|
||||
uint8_t* y = data;
|
||||
uint8_t* u = y + y_size;
|
||||
uint8_t* v = u + uv_size;
|
||||
yuv_image->Initialize(libyuv::FOURCC_I420, deallocate, //
|
||||
y, y_stride, //
|
||||
u, uv_stride, //
|
||||
|
@ -123,10 +123,11 @@ void ImageFrameToYUVNV12Image(const ImageFrame& image_frame,
|
|||
const int uv_stride = y_stride;
|
||||
const int uv_height = (height + 1) / 2;
|
||||
const int uv_size = uv_stride * uv_height;
|
||||
uint8* data = reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size, 16));
|
||||
uint8_t* data =
|
||||
reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size, 16));
|
||||
std::function<void()> deallocate = [data] { aligned_free(data); };
|
||||
uint8* y = data;
|
||||
uint8* uv = y + y_size;
|
||||
uint8_t* y = data;
|
||||
uint8_t* uv = y + y_size;
|
||||
yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv,
|
||||
uv_stride, nullptr, 0, width, height);
|
||||
const int rv = libyuv::I420ToNV12(
|
||||
|
@ -210,44 +211,44 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image,
|
|||
}
|
||||
}
|
||||
|
||||
void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b, //
|
||||
uint8* y, uint8* cb, uint8* cr) {
|
||||
void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b, //
|
||||
uint8_t* y, uint8_t* cb, uint8_t* cr) {
|
||||
// ITU-R BT.601 conversion from sRGB to YCbCr.
|
||||
// FastIntRound is used rather than SafeRound since the possible
|
||||
// range of values is [16,235] for Y and [16,240] for Cb and Cr and we
|
||||
// don't care about the rounding direction for values exactly between
|
||||
// two integers.
|
||||
*y = static_cast<uint8>(
|
||||
*y = static_cast<uint8_t>(
|
||||
mediapipe::MathUtil::FastIntRound(16.0 + //
|
||||
65.481 * r / 255.0 + //
|
||||
128.553 * g / 255.0 + //
|
||||
24.966 * b / 255.0));
|
||||
*cb = static_cast<uint8>(
|
||||
*cb = static_cast<uint8_t>(
|
||||
mediapipe::MathUtil::FastIntRound(128.0 + //
|
||||
-37.797 * r / 255.0 + //
|
||||
-74.203 * g / 255.0 + //
|
||||
112.0 * b / 255.0));
|
||||
*cr = static_cast<uint8>(
|
||||
*cr = static_cast<uint8_t>(
|
||||
mediapipe::MathUtil::FastIntRound(128.0 + //
|
||||
112.0 * r / 255.0 + //
|
||||
-93.786 * g / 255.0 + //
|
||||
-18.214 * b / 255.0));
|
||||
}
|
||||
|
||||
void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, //
|
||||
uint8* r, uint8* g, uint8* b) {
|
||||
void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr, //
|
||||
uint8_t* r, uint8_t* g, uint8_t* b) {
|
||||
// ITU-R BT.601 conversion from YCbCr to sRGB
|
||||
// Use SafeRound since many MPEG YCbCr values do not correspond directly
|
||||
// to an sRGB value.
|
||||
*r = mediapipe::MathUtil::SafeRound<uint8, double>( //
|
||||
255.0 / 219.0 * (y - 16.0) + //
|
||||
*r = mediapipe::MathUtil::SafeRound<uint8_t, double>( //
|
||||
255.0 / 219.0 * (y - 16.0) + //
|
||||
255.0 / 112.0 * 0.701 * (cr - 128.0));
|
||||
*g = mediapipe::MathUtil::SafeRound<uint8, double>(
|
||||
*g = mediapipe::MathUtil::SafeRound<uint8_t, double>(
|
||||
255.0 / 219.0 * (y - 16.0) - //
|
||||
255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) - //
|
||||
255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0));
|
||||
*b = mediapipe::MathUtil::SafeRound<uint8, double>( //
|
||||
255.0 / 219.0 * (y - 16.0) + //
|
||||
*b = mediapipe::MathUtil::SafeRound<uint8_t, double>( //
|
||||
255.0 / 219.0 * (y - 16.0) + //
|
||||
255.0 / 112.0 * 0.886 * (cb - 128.0));
|
||||
}
|
||||
|
||||
|
@ -260,15 +261,15 @@ void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, //
|
|||
|
||||
cv::Mat GetSrgbToLinearRgb16Lut() {
|
||||
cv::Mat lut(1, 256, CV_16UC1);
|
||||
uint16* ptr = lut.ptr<uint16>();
|
||||
uint16_t* ptr = lut.ptr<uint16_t>();
|
||||
constexpr double kUint8Max = 255.0;
|
||||
constexpr double kUint16Max = 65535.0;
|
||||
for (int i = 0; i < 256; ++i) {
|
||||
if (i < 0.04045 * kUint8Max) {
|
||||
ptr[i] = static_cast<uint16>(
|
||||
ptr[i] = static_cast<uint16_t>(
|
||||
(static_cast<double>(i) / kUint8Max / 12.92) * kUint16Max + .5);
|
||||
} else {
|
||||
ptr[i] = static_cast<uint16>(
|
||||
ptr[i] = static_cast<uint16_t>(
|
||||
pow((static_cast<double>(i) / kUint8Max + 0.055) / 1.055, 2.4) *
|
||||
kUint16Max +
|
||||
.5);
|
||||
|
@ -279,15 +280,15 @@ cv::Mat GetSrgbToLinearRgb16Lut() {
|
|||
|
||||
cv::Mat GetLinearRgb16ToSrgbLut() {
|
||||
cv::Mat lut(1, 65536, CV_8UC1);
|
||||
uint8* ptr = lut.ptr<uint8>();
|
||||
uint8_t* ptr = lut.ptr<uint8_t>();
|
||||
constexpr double kUint8Max = 255.0;
|
||||
constexpr double kUint16Max = 65535.0;
|
||||
for (int i = 0; i < 65536; ++i) {
|
||||
if (i < 0.0031308 * kUint16Max) {
|
||||
ptr[i] = static_cast<uint8>(
|
||||
ptr[i] = static_cast<uint8_t>(
|
||||
(static_cast<double>(i) / kUint16Max * 12.92) * kUint8Max + .5);
|
||||
} else {
|
||||
ptr[i] = static_cast<uint8>(
|
||||
ptr[i] = static_cast<uint8_t>(
|
||||
(1.055 * pow(static_cast<double>(i) / kUint16Max, 1.0 / 2.4) - .055) *
|
||||
kUint8Max +
|
||||
.5);
|
||||
|
@ -306,13 +307,13 @@ void LinearRgb16ToSrgb(const cv::Mat& source, cv::Mat* destination) {
|
|||
destination->create(source.size(), CV_8UC(source.channels()));
|
||||
|
||||
static const cv::Mat kLut = GetLinearRgb16ToSrgbLut();
|
||||
const uint8* lookup_table_ptr = kLut.ptr<uint8>();
|
||||
const uint8_t* lookup_table_ptr = kLut.ptr<uint8_t>();
|
||||
const int num_channels = source.channels();
|
||||
for (int row = 0; row < source.rows; ++row) {
|
||||
for (int col = 0; col < source.cols; ++col) {
|
||||
for (int channel = 0; channel < num_channels; ++channel) {
|
||||
uint8* ptr = destination->ptr<uint8>(row);
|
||||
const uint16* ptr16 = source.ptr<uint16>(row);
|
||||
uint8_t* ptr = destination->ptr<uint8_t>(row);
|
||||
const uint16_t* ptr16 = source.ptr<uint16_t>(row);
|
||||
ptr[col * num_channels + channel] =
|
||||
lookup_table_ptr[ptr16[col * num_channels + channel]];
|
||||
}
|
||||
|
|
|
@ -43,14 +43,14 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {
|
|||
|
||||
Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
|
||||
ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
|
||||
input.rows, input.step, input.data, [](uint8*) {});
|
||||
input.rows, input.step, input.data, [](uint8_t*) {});
|
||||
return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0));
|
||||
}
|
||||
|
||||
Packet MakeImagePacket(cv::Mat input, int timestamp) {
|
||||
mediapipe::Image input_image(std::make_shared<mediapipe::ImageFrame>(
|
||||
GetImageFormat(input.channels()), input.cols, input.rows, input.step,
|
||||
input.data, [](uint8*) {}));
|
||||
input.data, [](uint8_t*) {}));
|
||||
return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(0));
|
||||
}
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
|
||||
absl::StatusOr<proto_ns::Map<int64_t, LabelMapItem>> BuildLabelMapFromFiles(
|
||||
absl::string_view labels_file_contents,
|
||||
absl::string_view display_names_file) {
|
||||
if (labels_file_contents.empty()) {
|
||||
|
@ -68,7 +68,7 @@ absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
|
|||
label_map_items[i].set_display_name(display_names[i]);
|
||||
}
|
||||
}
|
||||
proto_ns::Map<int64, LabelMapItem> label_map;
|
||||
proto_ns::Map<int64_t, LabelMapItem> label_map;
|
||||
for (int i = 0; i < label_map_items.size(); ++i) {
|
||||
label_map[i] = label_map_items[i];
|
||||
}
|
||||
|
|
4
third_party/flatbuffers/BUILD.bazel
vendored
4
third_party/flatbuffers/BUILD.bazel
vendored
|
@ -45,12 +45,16 @@ filegroup(
|
|||
"include/flatbuffers/bfbs_generator.h",
|
||||
"include/flatbuffers/buffer.h",
|
||||
"include/flatbuffers/buffer_ref.h",
|
||||
"include/flatbuffers/code_generator.h",
|
||||
"include/flatbuffers/code_generators.h",
|
||||
"include/flatbuffers/default_allocator.h",
|
||||
"include/flatbuffers/detached_buffer.h",
|
||||
"include/flatbuffers/flatbuffer_builder.h",
|
||||
"include/flatbuffers/flatbuffers.h",
|
||||
"include/flatbuffers/flatc.h",
|
||||
"include/flatbuffers/flex_flat_util.h",
|
||||
"include/flatbuffers/flexbuffers.h",
|
||||
"include/flatbuffers/grpc.h",
|
||||
"include/flatbuffers/hash.h",
|
||||
"include/flatbuffers/idl.h",
|
||||
"include/flatbuffers/minireflect.h",
|
||||
|
|
8
third_party/flatbuffers/workspace.bzl
vendored
8
third_party/flatbuffers/workspace.bzl
vendored
|
@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
|
|||
def repo():
|
||||
third_party_http_archive(
|
||||
name = "flatbuffers",
|
||||
strip_prefix = "flatbuffers-2.0.6",
|
||||
sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9",
|
||||
strip_prefix = "flatbuffers-23.1.21",
|
||||
sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
|
||||
urls = [
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v2.0.6.tar.gz",
|
||||
"https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz",
|
||||
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
|
||||
"https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
|
||||
],
|
||||
build_file = "//third_party/flatbuffers:BUILD.bazel",
|
||||
delete = ["build_defs.bzl", "BUILD.bazel"],
|
||||
|
|
48
third_party/wasm_files.bzl
vendored
48
third_party/wasm_files.bzl
vendored
|
@ -12,72 +12,72 @@ def wasm_files():
|
|||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
|
||||
sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"],
|
||||
sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
|
||||
sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"],
|
||||
sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
|
||||
sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"],
|
||||
sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
|
||||
sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"],
|
||||
sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_text_wasm_internal_js",
|
||||
sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"],
|
||||
sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
|
||||
sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"],
|
||||
sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
|
||||
sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"],
|
||||
sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
|
||||
sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"],
|
||||
sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
|
||||
sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"],
|
||||
sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
|
||||
sha256 = "03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"],
|
||||
sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
|
||||
sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"],
|
||||
sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"],
|
||||
)
|
||||
|
||||
http_file(
|
||||
name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
|
||||
sha256 = "103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"],
|
||||
sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f",
|
||||
urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user