Merge branch 'google:master' into pose-landmarker-python

commit 39742b6641
@@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
     } else if (packet_options.has_string_value()) {
       packet.Set<std::string>();
     } else if (packet_options.has_uint64_value()) {
-      packet.Set<uint64>();
+      packet.Set<uint64_t>();
     } else if (packet_options.has_classification_list_value()) {
       packet.Set<ClassificationList>();
     } else if (packet_options.has_landmark_list_value()) {
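Note: the recurring change throughout this commit swaps MediaPipe's legacy width-ambiguous integer aliases (int32, uint64, uint8, ...) for the standard <cstdint> fixed-width types. A minimal, hedged sketch of what the swap amounts to (the commented-out legacy typedef is illustrative, not copied from the repo):

#include <cstdint>

// Legacy alias (illustrative): its width depended on the platform typedef.
// typedef long long int64;

// Standard fixed-width types used after this commit: width is guaranteed.
constexpr int64_t kTimestampValueExample = 42;  // always 64-bit signed
uint64_t packet_value_example = 0;              // always 64-bit unsigned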
@@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
     } else if (packet_options.has_string_value()) {
       packet.Set(MakePacket<std::string>(packet_options.string_value()));
     } else if (packet_options.has_uint64_value()) {
-      packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
+      packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
     } else if (packet_options.has_classification_list_value()) {
       packet.Set(MakePacket<ClassificationList>(
           packet_options.classification_list_value()));
@@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
   }

   // Use this when ALLOW/DISALLOW input is provided as a side packet.
-  void RunTimeStep(int64 timestamp, bool stream_payload) {
+  void RunTimeStep(int64_t timestamp, bool stream_payload) {
     runner_->MutableInputs()->Get("", 0).packets.push_back(
         MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
     MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
   }

   // Use this when ALLOW/DISALLOW input is provided as an input stream.
-  void RunTimeStep(int64 timestamp, const std::string& control_tag,
+  void RunTimeStep(int64_t timestamp, const std::string& control_tag,
                    bool control) {
     runner_->MutableInputs()->Get("", 0).packets.push_back(
         MakePacket<bool>(true).At(Timestamp(timestamp)));
@@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) {
     }
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) {
     }
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) {
     output_stream: "test_output"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) {
   )");
   runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true));

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) {
   )");
   runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false));

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) {
   )");
   runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false));

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) {
   )");
   runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true));

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) {
     output_stream: "test_output"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, "ALLOW", true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, "ALLOW", false);
-  constexpr int64 kTimestampValue2 = 44;
+  constexpr int64_t kTimestampValue2 = 44;
   RunTimeStep(kTimestampValue2, "ALLOW", true);
-  constexpr int64 kTimestampValue3 = 45;
+  constexpr int64_t kTimestampValue3 = 45;
   RunTimeStep(kTimestampValue3, "ALLOW", false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) {
    output_stream: "test_output"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, "DISALLOW", true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, "DISALLOW", false);
-  constexpr int64 kTimestampValue2 = 44;
+  constexpr int64_t kTimestampValue2 = 44;
   RunTimeStep(kTimestampValue2, "DISALLOW", true);
-  constexpr int64 kTimestampValue3 = 45;
+  constexpr int64_t kTimestampValue3 = 45;
   RunTimeStep(kTimestampValue3, "DISALLOW", false);

   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets;
@@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) {
     output_stream: "STATE_CHANGE:state_changed"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, "ALLOW", false);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, "ALLOW", true);
-  constexpr int64 kTimestampValue2 = 44;
+  constexpr int64_t kTimestampValue2 = 44;
   RunTimeStep(kTimestampValue2, "ALLOW", true);
-  constexpr int64 kTimestampValue3 = 45;
+  constexpr int64_t kTimestampValue3 = 45;
   RunTimeStep(kTimestampValue3, "ALLOW", false);

   const std::vector<Packet>& output =
@@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) {
     output_stream: "STATE_CHANGE:state_changed"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, "DISALLOW", true);
-  constexpr int64 kTimestampValue1 = 43;
+  constexpr int64_t kTimestampValue1 = 43;
   RunTimeStep(kTimestampValue1, "DISALLOW", false);
-  constexpr int64 kTimestampValue2 = 44;
+  constexpr int64_t kTimestampValue2 = 44;
   RunTimeStep(kTimestampValue2, "DISALLOW", false);
-  constexpr int64 kTimestampValue3 = 45;
+  constexpr int64_t kTimestampValue3 = 45;
   RunTimeStep(kTimestampValue3, "DISALLOW", true);

   const std::vector<Packet>& output =
@@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) {
     output_stream: "STATE_CHANGE:state_changed"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, "DISALLOW", false);

   const std::vector<Packet>& output =
@@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) {
     output_stream: "STATE_CHANGE:state_changed"
   )");

-  constexpr int64 kTimestampValue0 = 42;
+  constexpr int64_t kTimestampValue0 = 42;
   RunTimeStep(kTimestampValue0, "ALLOW", true);

   const std::vector<Packet>& output =
@@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
 using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
 REGISTER_CALCULATOR(StringToUintCalculator);

-using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
+using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
 REGISTER_CALCULATOR(StringToInt32Calculator);

-using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
+using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
 REGISTER_CALCULATOR(StringToUint32Calculator);

-using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
+using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
 REGISTER_CALCULATOR(StringToInt64Calculator);

-using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
+using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
 REGISTER_CALCULATOR(StringToUint64Calculator);

 }  // namespace mediapipe
@@ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
     const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
                                  frame_ptr->Height(), frame_ptr->WidthStep(),
                                  const_cast<uint8_t*>(frame_ptr->PixelData()),
-                                 [](uint8* data){});
+                                 [](uint8_t* data){});
     ASSIGN_OR_RETURN(auto result,
                      runner->Run(image_frame, matrix, size, border_mode));
     return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
@@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
       ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) {
     // Record the most recent first kept timestamp on any stream.
     for (const auto& stream : input_stream_managers_) {
-      int32 queue_size = (stream->QueueSize() >= trigger_queue_size_)
+      int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_)
                              ? target_queue_size_
                              : trigger_queue_size_ - 1;
       if (stream->QueueSize() > queue_size) {
         kept_timestamp_ = std::max(
             kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1)
@@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler {
   }

  private:
-  int32 trigger_queue_size_;
-  int32 target_queue_size_;
+  int32_t trigger_queue_size_;
+  int32_t target_queue_size_;
   bool fixed_min_size_;
   // Indicates that GetNodeReadiness has returned kReadyForProcess once, and
   // the corresponding call to FillInputSet has not yet completed.
@@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
   // TODO: Investigate this option in more detail, esp. on Safari.
   attrs.preserveDrawingBuffer = 0;

-  // Since the Emscripten canvas target finding function is visible from here,
-  // we hijack findCanvasEventTarget directly for enforcing old Module.canvas
-  // behavior if the user desires, falling back to the new DOM element CSS
-  // selector behavior next if that is specified, and finally just allowing the
-  // lookup to proceed on a null target.
-  // TODO: Ensure this works with all options (in particular,
-  // multithreading options, like the special-case combination of USE_PTHREADS
-  // and OFFSCREEN_FRAMEBUFFER)
-  // clang-format off
-  EM_ASM(
-      let init_once = true;
-      if (init_once) {
-        const cachedFindCanvasEventTarget = findCanvasEventTarget;
-
-        if (typeof cachedFindCanvasEventTarget !== 'function') {
-          if (typeof console !== 'undefined') {
-            console.error('Expected Emscripten global function '
-                + '"findCanvasEventTarget" not found. WebGL context creation '
-                + 'may fail.');
-          }
-          return;
-        }
-
-        findCanvasEventTarget = function(target) {
-          if (target == 0) {
-            if (Module && Module.canvas) {
-              return Module.canvas;
-            } else if (Module && Module.canvasCssSelector) {
-              return cachedFindCanvasEventTarget(Module.canvasCssSelector);
-            }
-            if (typeof console !== 'undefined') {
-              console.warn('Module properties canvas and canvasCssSelector not ' +
-                  'found during WebGL context creation.');
-            }
-          }
-          // We still go through with the find attempt, although for most use
-          // cases it will not succeed, just in case the user does want to fall-
-          // back.
-          return cachedFindCanvasEventTarget(target);
-        };  // NOLINT: Necessary semicolon.
-        init_once = false;
-      }
-  );
-  // clang-format on
-
+  // Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
+  // looks for our #canvas target in Module.canvas, where we expect it to be.
+  // -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
+  // event target behavior, but it was never supposed to be tapping into our
+  // canvas anyways. See b/278155946 for more background.
+  EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
   EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle =
-      emscripten_webgl_create_context(nullptr, &attrs);
+      emscripten_webgl_create_context("#canvas", &attrs);

   // Check for failure
   if (context_handle <= 0) {
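Note: the new code binds the embedder-supplied canvas to the "#canvas" selector and then creates the context against that selector rather than a null target. A minimal, hedged sketch of the same pattern outside of GlContext (the helper name is hypothetical, not part of this commit):

#include <emscripten.h>
#include <emscripten/html5.h>

// Hypothetical helper: creates a WebGL context the way the patched
// GlContext::CreateContextInternal now does.
EMSCRIPTEN_WEBGL_CONTEXT_HANDLE CreateCanvasContext() {
  EmscriptenWebGLContextAttributes attrs;
  emscripten_webgl_init_context_attributes(&attrs);
  attrs.preserveDrawingBuffer = 0;
  // Map the "#canvas" selector onto the canvas the embedder provided via
  // Module.canvas, so the non-deprecated target lookup can find it.
  EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; });
  // Create the context against the explicit selector instead of nullptr.
  return emscripten_webgl_create_context("#canvas", &attrs);
}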
@@ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create(
     int actual_ws = image_frame.WidthStep();
     int alignment = 0;
     std::unique_ptr<ImageFrame> temp;
-    const uint8* data = image_frame.PixelData();
+    const uint8_t* data = image_frame.PixelData();

     // Let's see if the pixel data is tightly aligned to one of the alignments
     // supported by OpenGL, preferring 4 if possible since it's the default.
@@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
                                                    GpuBufferFormat format) {
   libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format);
   int y_stride = std::ceil(1.0f * width / kDefaultDataAligment);
-  auto y_data = std::make_unique<uint8[]>(y_stride * height);
+  auto y_data = std::make_unique<uint8_t[]>(y_stride * height);
   switch (fourcc) {
     case libyuv::FOURCC_NV12:
     case libyuv::FOURCC_NV21: {
@@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
       int uv_width = 2 * std::ceil(0.5f * width);
       int uv_height = std::ceil(0.5f * height);
       int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
-      auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height);
+      auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
       yuv_image_ = std::make_shared<YUVImage>(
           fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride,
           nullptr, 0, width, height);
@@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height,
       int uv_width = std::ceil(0.5f * width);
       int uv_height = std::ceil(0.5f * height);
       int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment);
-      auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height);
-      auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height);
+      auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
+      auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height);
       yuv_image_ = std::make_shared<YUVImage>(
           fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride,
           std::move(v_data), uv_stride, width, height);
@@ -16,6 +16,7 @@ import csv
 import filecmp
 import os
 import tempfile
+import unittest
 from unittest import mock as unittest_mock

 import tensorflow as tf
@@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier
 from mediapipe.tasks.python.test import test_utils


+@unittest.skip('b/275624089')
 class TextClassifierTest(tf.test.TestCase):

   _AVERAGE_WORD_EMBEDDING_JSON_FILE = (
@@ -175,11 +175,7 @@ py_test(
     data = [":testdata"],
     tags = ["requires-net:external"],
     deps = [
-        ":dataset",
-        ":hyperparameters",
-        ":model_spec",
-        ":object_detector",
-        ":object_detector_options",
+        ":object_detector_import",
         "//mediapipe/tasks/python/test:test_utils",
     ],
 )
@@ -19,11 +19,7 @@ from unittest import mock as unittest_mock
 from absl.testing import parameterized
 import tensorflow as tf

-from mediapipe.model_maker.python.vision.object_detector import dataset
-from mediapipe.model_maker.python.vision.object_detector import hyperparameters
-from mediapipe.model_maker.python.vision.object_detector import model_spec as ms
-from mediapipe.model_maker.python.vision.object_detector import object_detector
-from mediapipe.model_maker.python.vision.object_detector import object_detector_options
+from mediapipe.model_maker.python.vision import object_detector
 from mediapipe.tasks.python.test import test_utils as task_test_utils


@@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     super().setUp()
     dataset_folder = task_test_utils.get_test_data_path('coco_data')
     cache_dir = self.create_tempdir()
-    self.data = dataset.Dataset.from_coco_folder(
+    self.data = object_detector.Dataset.from_coco_folder(
         dataset_folder, cache_dir=cache_dir
     )
     # Mock tempfile.gettempdir() to be unique for each test to avoid race
@@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     self.addCleanup(mock_gettempdir.stop)

   def test_object_detector(self):
-    hparams = hyperparameters.HParams(
+    hparams = object_detector.HParams(
         epochs=1,
         batch_size=2,
         learning_rate=0.9,
         shuffle=False,
         export_dir=self.create_tempdir(),
     )
-    options = object_detector_options.ObjectDetectorOptions(
-        supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams
+    options = object_detector.ObjectDetectorOptions(
+        supported_model=object_detector.SupportedModels.MOBILENET_V2,
+        hparams=hparams,
     )
     # Test `create`
     model = object_detector.ObjectDetector.create(
@@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase):
     self.assertGreater(os.path.getsize(output_metadata_file), 0)

     # Test `quantization_aware_training`
-    qat_hparams = hyperparameters.QATHParams(
+    qat_hparams = object_detector.QATHParams(
         learning_rate=0.9,
         batch_size=2,
         epochs=1,
@@ -24,8 +24,8 @@ namespace mediapipe {

 void FrameAnnotationTracker::AddDetectionResult(
     const FrameAnnotation& frame_annotation) {
-  const int64 time_us =
-      static_cast<int64>(std::round(frame_annotation.timestamp()));
+  const int64_t time_us =
+      static_cast<int64_t>(std::round(frame_annotation.timestamp()));
   for (const auto& object_annotation : frame_annotation.annotations()) {
     detected_objects_[time_us + object_annotation.object_id()] =
         object_annotation;
@@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult(
     absl::flat_hash_set<int>* cancel_object_ids) {
   CHECK(cancel_object_ids != nullptr);
   FrameAnnotation frame_annotation;
-  std::vector<int64> keys_to_be_deleted;
+  std::vector<int64_t> keys_to_be_deleted;
   for (const auto& detected_obj : detected_objects_) {
     const int object_id = detected_obj.second.object_id();
     if (cancel_object_ids->contains(object_id)) {
@@ -78,6 +78,7 @@ cc_library(
     hdrs = ["mediapipe_builtin_op_resolver.h"],
     deps = [
         "//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite",
+        "//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash",
        "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
@@ -16,6 +16,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"

 #include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
 #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h"
 #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
 #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
@@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
   AddCustom("KmeansEmbeddingLookup",
             mediapipe::tflite_operations::Register_KmeansEmbeddingLookup());
   // For the UniversalSentenceEncoder model.
+  AddCustom("TFSentencepieceTokenizeOp",
+            mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
   AddCustom("RaggedTensorToTensor",
             mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR());
 }
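Note: the resolver now registers the SentencePiece tokenizer custom op alongside the existing ones. A hedged sketch of how a resolver like this is typically handed to a TFLite interpreter (the helper name and model-loading details are illustrative, not from this commit):

#include <memory>
#include <string>

#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"

// Hypothetical helper: builds an interpreter for a model that uses the
// TFSentencepieceTokenizeOp custom op registered by the resolver.
std::unique_ptr<tflite::Interpreter> BuildInterpreter(const std::string& path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(path.c_str());
  mediapipe::tasks::core::MediaPipeBuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  return interpreter;
}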
@@ -18,6 +18,13 @@ package(default_visibility = ["//mediapipe/tasks:internal"])

 licenses(["notice"])

+filegroup(
+    name = "testdata",
+    srcs = glob([
+        "testdata/**",
+    ]),
+)
+
 filegroup(
     name = "config_fbs",
     srcs = ["config.fbs"],
@@ -80,3 +87,86 @@ cc_test(
         "//mediapipe/framework/port:gtest_main",
     ],
 )
+
+cc_library(
+    name = "sentencepiece_constants",
+    hdrs = ["sentencepiece_constants.h"],
+)
+
+cc_library(
+    name = "model_converter",
+    srcs = [
+        "model_converter.cc",
+    ],
+    hdrs = [
+        "model_converter.h",
+    ],
+    deps = [
+        ":config",
+        ":double_array_trie_builder",
+        ":encoder_config",
+        ":sentencepiece_constants",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_sentencepiece//src:sentencepiece_model_cc_proto",
+    ],
+)
+
+cc_library(
+    name = "optimized_encoder",
+    srcs = [
+        "optimized_encoder.cc",
+    ],
+    hdrs = [
+        "optimized_encoder.h",
+    ],
+    deps = [
+        ":double_array_trie",
+        ":encoder_config",
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "sentencepiece_tokenizer_tflite",
+    srcs = ["sentencepiece_tokenizer_tflite.cc"],
+    hdrs = ["sentencepiece_tokenizer_tflite.h"],
+    visibility = [
+        "//visibility:public",
+    ],
+    deps =
+        [
+            ":optimized_encoder",
+            "@flatbuffers",
+            "@org_tensorflow//tensorflow/lite:framework",
+            "@org_tensorflow//tensorflow/lite:string_util",
+            "@org_tensorflow//tensorflow/lite/c:common",
+            "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
+            "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
+            "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
+        ],
+)
+
+cc_test(
+    name = "optimized_encoder_test",
+    srcs = [
+        "optimized_encoder_test.cc",
+    ],
+    data = [
+        ":testdata",
+    ],
+    deps = [
+        ":double_array_trie_builder",
+        ":encoder_config",
+        ":model_converter",
+        ":optimized_encoder",
+        "//mediapipe/framework/deps:file_path",
+        "//mediapipe/framework/port:gtest_main",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+        "@com_google_sentencepiece//src:sentencepiece_cc_proto",
+        "@com_google_sentencepiece//src:sentencepiece_processor",
+        "@org_tensorflow//tensorflow/core:lib",
+    ],
+)
@@ -0,0 +1,131 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h"
+#include "src/sentencepiece_model.pb.h"
+
+namespace mediapipe::tflite_operations::sentencepiece {
+
+std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
+DecodePrecompiledCharsmap(
+    const ::sentencepiece::NormalizerSpec& normalizer_spec) {
+  // This function "undoes" encoding done by
+  // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
+  const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
+  const uint32_t trie_size =
+      *reinterpret_cast<const uint32_t*>(precompiled_map);
+  const uint32_t* trie_ptr =
+      reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
+  const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
+      precompiled_map + sizeof(uint32_t) + trie_size);
+  const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
+                              sizeof(uint32_t) - trie_size;
+  return std::make_tuple(
+      std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
+      std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
+}
+
+absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
+    const std::string& model_config_str, int encoding_offset) {
+  ::sentencepiece::ModelProto model_config;
+  if (!model_config.ParseFromString(model_config_str)) {
+    return absl::InvalidArgumentError(
+        "Invalid configuration, can't parse SentencePiece model config " +
+        model_config.InitializationErrorString());
+  }
+  // Convert sentencepieces.
+  std::vector<std::string> pieces;
+  pieces.reserve(model_config.pieces_size());
+  std::vector<float> scores;
+  scores.reserve(model_config.pieces_size());
+  std::vector<int> ids;
+  ids.reserve(model_config.pieces_size());
+  float min_score = 0.0;
+  int index = 0;
+  for (const auto& piece : model_config.pieces()) {
+    switch (piece.type()) {
+      case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
+      case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
+        pieces.push_back(piece.piece());
+        ids.push_back(index);
+        if (piece.score() < min_score) {
+          min_score = piece.score();
+        }
+        break;
+      case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
+      case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
+        // Ignore unknown and control codes.
+        break;
+      default:
+        return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
+                                          piece.piece());
+    }
+    scores.push_back(piece.score());
+    ++index;
+  }
+  flatbuffers::FlatBufferBuilder builder(1024);
+  const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
+  const auto pieces_score_vector = builder.CreateVector(scores);
+  TrieBuilder pieces_trie_builder(builder);
+  pieces_trie_builder.add_nodes(pieces_trie_vector);
+  const auto pieces_trie_fbs = pieces_trie_builder.Finish();
+
+  // Converting normalization.
+  const auto normalization =
+      DecodePrecompiledCharsmap(model_config.normalizer_spec());
+  const auto normalization_trie = std::get<0>(normalization);
+  const auto normalization_strings = std::get<1>(normalization);
+  const auto normalization_trie_vector =
+      builder.CreateVector(normalization_trie);
+  TrieBuilder normalization_trie_builder(builder);
+  normalization_trie_builder.add_nodes(normalization_trie_vector);
+  const auto normalization_trie_fbs = normalization_trie_builder.Finish();
+  const auto normalization_strings_fbs =
+      builder.CreateVector(normalization_strings);
+
+  EncoderConfigBuilder ecb(builder);
+  ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
+  ecb.add_start_code(model_config.trainer_spec().bos_id());
+  ecb.add_end_code(model_config.trainer_spec().eos_id());
+  ecb.add_unknown_code(model_config.trainer_spec().unk_id());
+  ecb.add_unknown_penalty(min_score - kUnkPenalty);
+  ecb.add_encoding_offset(encoding_offset);
+  ecb.add_pieces(pieces_trie_fbs);
+  ecb.add_pieces_scores(pieces_score_vector);
+  ecb.add_remove_extra_whitespaces(
+      model_config.normalizer_spec().remove_extra_whitespaces());
+  ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
+  ecb.add_escape_whitespaces(
+      model_config.normalizer_spec().escape_whitespaces());
+  ecb.add_normalized_prefixes(normalization_trie_fbs);
+  ecb.add_normalized_replacements(normalization_strings_fbs);
+  FinishEncoderConfigBuffer(builder, ecb.Finish());
+  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+                     builder.GetSize());
+}
+
+std::string ConvertSentencepieceModel(const std::string& model_string) {
+  const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
+  assert(result.status().ok());
+  return result.value();
+}
+
+}  // namespace mediapipe::tflite_operations::sentencepiece
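Note: a hedged sketch of how the converter above might be driven from a serialized SentencePiece model on disk; the helper name and file handling are illustrative only, not part of this commit.

#include <fstream>
#include <sstream>
#include <string>

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"

// Hypothetical helper: reads a serialized sentencepiece::ModelProto and
// converts it into the flatbuffer config consumed by the tokenizer op.
absl::StatusOr<std::string> ConvertModelFile(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  std::ostringstream contents;
  contents << in.rdbuf();  // raw bytes of the .model file
  return mediapipe::tflite_operations::sentencepiece::
      ConvertSentencepieceModelToFlatBuffer(contents.str());
}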
@@ -0,0 +1,33 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
+
+#include <string>
+
+#include "absl/status/statusor.h"
+
+namespace mediapipe::tflite_operations::sentencepiece {
+
+// Converts Sentencepiece configuration to flatbuffer format.
+// encoding_offset is used by some encoders that combine different encodings.
+absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
+    const std::string& model_config_str, int encoding_offset = 0);
+std::string ConvertSentencepieceModel(const std::string& model_string);
+
+}  // namespace mediapipe::tflite_operations::sentencepiece
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
@@ -0,0 +1,236 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
+
+#include <algorithm>
+#include <tuple>
+
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h"
+
+namespace mediapipe::tflite_operations::sentencepiece {
+namespace {
+
+const char kSpaceSymbol[] = "\xe2\x96\x81";
+
+template <typename processing_callback>
+std::tuple<std::string, std::vector<int>> process_string(
+    const std::string& input, const std::vector<int>& offsets,
+    const processing_callback& pc) {
+  std::string result_string;
+  result_string.reserve(input.size());
+  std::vector<int> result_offsets;
+  result_offsets.reserve(offsets.size());
+  for (int i = 0, j = 0; i < input.size();) {
+    auto result = pc(input.data() + i, input.size() - i);
+    auto consumed = std::get<0>(result);
+    auto new_string = std::get<1>(result);
+    if (consumed == 0) {
+      // Skip the current byte and move forward.
+      result_string.push_back(input[i]);
+      result_offsets.push_back(offsets[j]);
+      i++;
+      j++;
+      continue;
+    }
+    result_string.append(new_string.data(), new_string.length());
+    for (int i = 0; i < new_string.length(); ++i) {
+      result_offsets.push_back(offsets[j]);
+    }
+    j += consumed;
+    i += consumed;
+  }
+  return std::make_tuple(result_string, result_offsets);
+}
+
+inline char is_whitespace(char c) {
+  return c == ' ' || c == '\t' || c == '\r' || c == '\n';
+}
+
+std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data,
+                                                             int len) {
+  if (len == 0 || !is_whitespace(*data)) {
+    return std::make_tuple(0, utils::string_view(nullptr, 0));
+  }
+  int num_consumed = 1;
+  for (; num_consumed < len && is_whitespace(data[num_consumed]);
+       ++num_consumed) {
+  }
+  return num_consumed > 1
+             ? std::make_tuple(num_consumed, utils::string_view(" ", 1))
+             : std::make_tuple(0, utils::string_view(nullptr, 0));
+}
+
+std::tuple<int, utils::string_view> find_replacement(
+    const char* data, int len, const DoubleArrayTrie& dat,
+    const flatbuffers::Vector<int8_t>& replacements) {
+  const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len));
+  if (!max_match.empty()) {
+    // Because flatbuffer byte is signed char which is not the same as char,
+    // there is the reinterpret_cast here.
+    const char* replaced_string_ptr =
+        reinterpret_cast<const char*>(replacements.data() + max_match.id);
+    return std::make_tuple(max_match.match_length,
+                           utils::string_view(replaced_string_ptr));
+  }
+  return std::make_tuple(0, utils::string_view(nullptr, 0));
+}
+}  // namespace
+
+std::tuple<std::string, std::vector<int>> NormalizeString(
+    const std::string& in_string, const EncoderConfig& config) {
+  std::vector<int> output_offsets;
+  std::string result = in_string;
+  output_offsets.reserve(in_string.length());
+  for (int i = 0; i < in_string.length(); ++i) {
+    output_offsets.push_back(i);
+  }
+  if (in_string.empty()) {
+    return std::make_tuple(result, output_offsets);
+  }
+  if (config.add_dummy_prefix()) {
+    result.insert(result.begin(), ' ');
+    output_offsets.insert(output_offsets.begin(), 0);
+  }
+  // Greedely replace normalized_prefixes with normalized_replacements
+  if (config.normalized_prefixes() != nullptr &&
+      config.normalized_replacements() != nullptr) {
+    const DoubleArrayTrie normalized_prefixes_matcher(
+        config.normalized_prefixes()->nodes());
+    const auto norm_replace = [&config, &normalized_prefixes_matcher](
+                                  const char* data, int len) {
+      return find_replacement(data, len, normalized_prefixes_matcher,
+                              *config.normalized_replacements());
+    };
+    std::tie(result, output_offsets) =
+        process_string(result, output_offsets, norm_replace);
+  }
+  if (config.remove_extra_whitespaces()) {
+    std::tie(result, output_offsets) =
+        process_string(result, output_offsets, remove_extra_whitespaces);
+    if (!result.empty() && is_whitespace(result.back())) {
+      result.pop_back();
+      output_offsets.pop_back();
+    }
+  }
+  if (config.escape_whitespaces()) {
+    const auto replace_whitespaces = [](const char* data, int len) {
+      if (len > 0 && is_whitespace(*data)) {
+        return std::make_tuple(1, utils::string_view(kSpaceSymbol));
+      }
+      return std::make_tuple(0, utils::string_view(nullptr, 0));
+    };
+    std::tie(result, output_offsets) =
+        process_string(result, output_offsets, replace_whitespaces);
+  }
+
+  return std::make_tuple(result, output_offsets);
+}
+
+EncoderResult EncodeNormalizedString(const std::string& str,
+                                     const std::vector<int>& offsets,
+                                     const EncoderConfig& config, bool add_bos,
+                                     bool add_eos, bool reverse) {
+  const DoubleArrayTrie piece_matcher(config.pieces()->nodes());
+  const flatbuffers::Vector<float>* piece_scores = config.pieces_scores();
+  const int unknown_code = config.unknown_code();
+  const float unknown_penalty = config.unknown_penalty();
+  struct LatticeElement {
+    float score = 0;
+    int code = -1;
+    int prev_position = -1;
+    LatticeElement(float score_, int code_, int prev_position_)
+        : score(score_), code(code_), prev_position(prev_position_) {}
+    LatticeElement() {}
+  };
+  const int length = str.length();
+  std::vector<LatticeElement> lattice(length + 1);
+  for (int i = 0; i < length; ++i) {
+    if (i > 0 && lattice[i].prev_position < 0) {
+      // This state is unreachable.
+      continue;
+    }
+    if (unknown_code >= 0) {
+      // Put unknown code.
+      const float penalized_score = lattice[i].score + unknown_penalty;
+      const int pos = i + 1;
+      LatticeElement& current_element = lattice[pos];
+      if (current_element.prev_position < 0 ||
+          current_element.score < penalized_score) {
+        current_element = LatticeElement(
+            penalized_score, unknown_code,
+            // If the current state is already reached by unknown code, merge
+            // states.
+            lattice[i].code == unknown_code ? lattice[i].prev_position : i);
+      }
+    }
+    auto lattice_update = [&lattice, i,
+                           piece_scores](const DoubleArrayTrie::Match& m) {
+      LatticeElement& target_element = lattice[i + m.match_length];
+      const float score = lattice[i].score + (*piece_scores)[m.id];
+      if (target_element.prev_position < 0 || target_element.score < score) {
+        target_element = LatticeElement(score, m.id, i);
+      }
+    };
+    piece_matcher.IteratePrefixMatches(
+        utils::string_view(str.data() + i, length - i), lattice_update);
+  }
+
+  EncoderResult result;
+  if (add_eos) {
+    result.codes.push_back(config.end_code());
+    result.offsets.push_back(length);
+  }
+  if (lattice[length].prev_position >= 0) {
+    for (int pos = length; pos > 0;) {
+      auto code = lattice[pos].code;
+      if (code != config.unknown_code()) {
+        code += config.encoding_offset();
+      }
+      result.codes.push_back(code);
+      pos = lattice[pos].prev_position;
+      result.offsets.push_back(offsets[pos]);
+    }
+  }
+  if (add_bos) {
+    result.codes.push_back(config.start_code());
+    result.offsets.push_back(0);
+  }
+  if (!reverse) {
+    std::reverse(result.codes.begin(), result.codes.end());
+    std::reverse(result.offsets.begin(), result.offsets.end());
+  }
+  return result;
+}
+
+EncoderResult EncodeString(const std::string& string, const void* config_buffer,
+                           bool add_bos, bool add_eos, bool reverse) {
+  // Get the config from the buffer.
+  const EncoderConfig* config = GetEncoderConfig(config_buffer);
+  if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) {
+    EncoderResult result;
+    result.type = EncoderResultType::WRONG_CONFIG;
+    return result;
+  }
+  std::string normalized_string;
+  std::vector<int> offsets;
+  std::tie(normalized_string, offsets) = NormalizeString(string, *config);
+  return EncodeNormalizedString(normalized_string, offsets, *config, add_bos,
+                                add_eos, reverse);
+}
+
+}  // namespace mediapipe::tflite_operations::sentencepiece
@@ -0,0 +1,46 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
+#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
+
+// Sentencepiece encoder optimized with memmapped model.
+
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
+
+namespace mediapipe::tflite_operations::sentencepiece {
+
+enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
+
+struct EncoderResult {
+  EncoderResultType type = EncoderResultType::SUCCESS;
+  std::vector<int> codes;
+  std::vector<int> offsets;
+};
+std::tuple<std::string, std::vector<int>> NormalizeString(
+    const std::string& in_string, const EncoderConfig& config);
+
+// Encodes one string and returns ids and offsets. Takes the configuration as a
+// type-erased buffer.
+EncoderResult EncodeString(const std::string& string, const void* config_buffer,
+                           bool add_bos, bool add_eos, bool reverse);
+
+}  // namespace mediapipe::tflite_operations::sentencepiece
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
@@ -0,0 +1,171 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"

#include <fstream>

#include "absl/flags/flag.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "src/sentencepiece.pb.h"
#include "src/sentencepiece_processor.h"
#include "tensorflow/core/platform/env.h"

namespace mediapipe::tflite_operations::sentencepiece {

namespace internal {

tensorflow::Status TFReadFileToString(const std::string& filepath,
                                      std::string* data) {
  return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
                                      data);
}

absl::Status StdReadFileToString(const std::string& filepath,
                                 std::string* data) {
  std::ifstream infile(filepath);
  if (!infile.is_open()) {
    return absl::NotFoundError(
        absl::StrFormat("Error when opening %s", filepath));
  }
  std::string contents((std::istreambuf_iterator<char>(infile)),
                       (std::istreambuf_iterator<char>()));
  data->append(contents);
  infile.close();
  return absl::OkStatus();
}
}  // namespace internal

namespace {

using ::mediapipe::file::JoinPath;

static char kConfigFilePath[] =
    "/mediapipe/tasks/cc/text/custom_ops/"
    "sentencepiece/testdata/sentencepiece.model";

TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
  flatbuffers::FlatBufferBuilder builder(1024);
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(true);
  ecb.add_add_dummy_prefix(true);
  ecb.add_escape_whitespaces(true);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    const auto result = NormalizeString("x y", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
    EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3));
  }
  {
    const auto result = NormalizeString("\tx y\n", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y");
    EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4));
  }
}

TEST(OptimizedEncoder, NormalizeStringReplacement) {
  flatbuffers::FlatBufferBuilder builder(1024);
  const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"};
  const char norm_replacements[] = "A1\0A2\0A3\0A4";
  const auto trie_vector =
      builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9}));
  const auto norm_r = builder.CreateVector<int8_t>(
      reinterpret_cast<const signed char*>(norm_replacements),
      sizeof(norm_replacements));
  TrieBuilder trie_builder(builder);
  trie_builder.add_nodes(trie_vector);
  const auto norm_p = trie_builder.Finish();
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(false);
  ecb.add_normalized_prefixes(norm_p);
  ecb.add_normalized_replacements(norm_r);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    const auto result = NormalizeString("ABAABAAABAAAA", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    EXPECT_EQ(res_string, "A1BA2BA3BA4");
    EXPECT_THAT(offsets,
                ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9));
  }
}

TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) {
  flatbuffers::FlatBufferBuilder builder(1024);
  const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA",
                                                  "X"};
  const char norm_replacements[] = "A1\0A2\0A3\0A4\0 ";
  const auto trie_vector =
      builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12}));
  const auto norm_r = builder.CreateVector<int8_t>(
      reinterpret_cast<const signed char*>(norm_replacements),
      sizeof(norm_replacements));
  TrieBuilder trie_builder(builder);
  trie_builder.add_nodes(trie_vector);
  const auto norm_p = trie_builder.Finish();
  EncoderConfigBuilder ecb(builder);
  ecb.add_remove_extra_whitespaces(true);
  ecb.add_normalized_prefixes(norm_p);
  ecb.add_normalized_replacements(norm_r);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer());
  {
    const auto result = NormalizeString("XXABAABAAABAAAA", *config);
    const auto res_string = std::get<0>(result);
    const auto offsets = std::get<1>(result);
    EXPECT_EQ(res_string, " A1BA2BA3BA4");
    EXPECT_THAT(offsets,
                ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11));
  }
}

TEST(OptimizedEncoder, ConfigConverter) {
  std::string config;
  auto status =
      internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config);
  ASSERT_TRUE(status.ok());

  ::sentencepiece::SentencePieceProcessor processor;
  ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok());
  const auto converted_model = ConvertSentencepieceModel(config);
  const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95");
  const auto encoded =
      EncodeString(test_string, converted_model.data(), false, false, false);
  ASSERT_EQ(encoded.codes.size(), encoded.offsets.size());

  ::sentencepiece::SentencePieceText reference_encoded;
  ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok());
  EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size());
  for (int i = 0; i < encoded.codes.size(); ++i) {
    EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id());
    EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin());
  }
}

}  // namespace
}  // namespace mediapipe::tflite_operations::sentencepiece
@@ -0,0 +1,38 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_

namespace mediapipe::tflite_operations::sentencepiece {

// The constant is copied from
// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
constexpr float kUnkPenalty = 10.0;

// These constants are copied from
// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
//
// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
constexpr char kSpaceSymbol[] = "\xe2\x96\x81";

// Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
// since this character can be useful both for user and
// developer. We can easily figure out that <unk> is emitted.
constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 ";

}  // namespace mediapipe::tflite_operations::sentencepiece

#endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
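As a hedged illustration of what kSpaceSymbol is for (not part of the change), the sketch below reproduces the whitespace escaping that the NormalizeString tests above assert: with a dummy prefix and escaped whitespace, "x y" becomes the U+2581-delimited string. The real normalizer also honors the remove_extra_whitespaces and replacement-trie options in EncoderConfig.

#include <iostream>
#include <string>

#include "absl/strings/str_replace.h"

int main() {
  constexpr char kSpaceSymbol[] = "\xe2\x96\x81";  // U+2581
  // Dummy prefix plus escaped spaces, mirroring the test expectation above.
  const std::string escaped =
      std::string(kSpaceSymbol) +
      absl::StrReplaceAll("x y", {{" ", kSpaceSymbol}});
  std::cout << escaped << "\n";  // prints "▁x▁y"
  return 0;
}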
@@ -0,0 +1,129 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"

#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"

namespace mediapipe::tflite_operations {
namespace sentencepiece::tokenizer {
namespace {

using ::tflite::SetTensorToDynamic;

constexpr int kSPModelIndex = 0;
constexpr int kInputIndex = 1;
constexpr int kAddBOSInput = 4;
constexpr int kAddEOSInput = 5;
constexpr int kReverseInput = 6;

constexpr int kOutputValuesInd = 0;
constexpr int kOutputSplitsInd = 1;

TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
  TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
  int index = 0;
  for (const int size : sizes) {
    array_size->data[index++] = size;
  }
  return array_size;
}
}  // namespace

// Initializes text encoder object from serialized parameters.
void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/,
                 size_t /*length*/) {
  return nullptr;
}
void Free(TfLiteContext* /*context*/, void* /*buffer*/) {}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // TODO: Add checks for input and output tensors.
  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  SetTensorToDynamic(&output_values);

  TfLiteTensor& output_splits =
      context->tensors[node->outputs->data[kOutputSplitsInd]];
  SetTensorToDynamic(&output_splits);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor& model_tensor =
      context->tensors[node->inputs->data[kSPModelIndex]];
  const auto model_buffer_data = model_tensor.data.data;
  const TfLiteTensor& input_text =
      context->tensors[node->inputs->data[kInputIndex]];

  const TfLiteTensor add_bos_tensor =
      context->tensors[node->inputs->data[kAddBOSInput]];
  const bool add_bos = add_bos_tensor.data.b[0];
  const TfLiteTensor add_eos_tensor =
      context->tensors[node->inputs->data[kAddEOSInput]];
  const bool add_eos = add_eos_tensor.data.b[0];
  const TfLiteTensor reverse_tensor =
      context->tensors[node->inputs->data[kReverseInput]];
  const bool reverse = reverse_tensor.data.b[0];

  std::vector<int32> encoded;
  std::vector<int32> splits;
  const int num_strings = tflite::GetStringCount(&input_text);
  for (int i = 0; i < num_strings; ++i) {
    const auto strref = tflite::GetString(&input_text, i);
    const auto res = EncodeString(std::string(strref.str, strref.len),
                                  model_buffer_data, add_bos, add_eos, reverse);
    TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS,
                       "Sentencepiece conversion failed");
    std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded));
    splits.emplace_back(encoded.size());
  }

  TfLiteTensor& output_values =
      context->tensors[node->outputs->data[kOutputValuesInd]];
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(
                        context, &output_values,
                        CreateSizeArray({static_cast<int>(encoded.size())})));
  int32_t* output_values_flat = output_values.data.i32;
  std::copy(encoded.begin(), encoded.end(), output_values_flat);
  TfLiteTensor& output_splits =
      context->tensors[node->outputs->data[kOutputSplitsInd]];
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(
                   context, &output_splits,
                   CreateSizeArray({static_cast<int>(splits.size() + 1)})));
  int32_t* output_splits_flat = output_splits.data.i32;
  *output_splits_flat = 0;
  std::copy(splits.begin(), splits.end(), output_splits_flat + 1);
  return kTfLiteOk;
}
}  // namespace sentencepiece::tokenizer

TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() {
  static TfLiteRegistration r = {
      sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free,
      sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval};
  return &r;
}

}  // namespace mediapipe::tflite_operations
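To actually use this kernel, the registration above has to be added to a TFLite op resolver before the interpreter is built. The sketch below is only an example under assumptions: the custom op name string is not specified in this diff and follows the TensorFlow Text convention, so the real model may register the op under a different name.

#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "tensorflow/lite/kernels/register.h"

// Builds a resolver that knows both the builtin ops and this custom tokenizer.
// "TFSentencepieceTokenizeOp" is an assumed op name; check the model's
// custom-op table for the actual string.
tflite::ops::builtin::BuiltinOpResolver BuildResolverWithSentencepiece() {
  tflite::ops::builtin::BuiltinOpResolver resolver;
  resolver.AddCustom(
      "TFSentencepieceTokenizeOp",
      mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
  return resolver;
}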
@@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_

#include "tensorflow/lite/kernels/register.h"

namespace mediapipe::tflite_operations {

TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();

}  // namespace mediapipe::tflite_operations

#endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
BIN  mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model (new, vendored). Binary file not shown.
@@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite";
 // Embedding model with regex preprocessing.
 constexpr char kRegexOneEmbeddingModel[] =
     "regex_one_embedding_with_metadata.tflite";
+constexpr char kUniversalSentenceEncoderModel[] =
+    "universal_sentence_encoder_qa_with_metadata.tflite";

 // Tolerance for embedding vector coordinate values.
 constexpr float kEpsilon = 1e-4;
@@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) {
   MP_ASSERT_OK(text_embedder->Close());
 }

+TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) {
+  auto options = std::make_unique<TextEmbedderOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
+                          TextEmbedder::Create(std::move(options)));
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto result0,
+      text_embedder->Embed("it's a charming and often affecting journey"));
+  ASSERT_EQ(result0.embeddings.size(), 1);
+  ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100);
+  ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon);
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto result1, text_embedder->Embed("what a great and fantastic trip"));
+  ASSERT_EQ(result1.embeddings.size(), 1);
+  ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100);
+  ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon);
+
+  // Check cosine similarity.
+  MP_ASSERT_OK_AND_ASSIGN(
+      double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
+                                                        result1.embeddings[0]));
+  ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy);
+
+  MP_ASSERT_OK(text_embedder->Close());
+}
+
 TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
   auto options = std::make_unique<TextEmbedderOptions>();
   options->base_options.model_asset_path =
@@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) {
   MP_ASSERT_OK(text_embedder->Close());
 }

+TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) {
+  auto options = std::make_unique<TextEmbedderOptions>();
+  options->base_options.model_asset_path =
+      JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel);
+  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder,
+                          TextEmbedder::Create(std::move(options)));
+
+  MP_ASSERT_OK_AND_ASSIGN(
+      TextEmbedderResult result0,
+      text_embedder->Embed("When you go to this restaurant, they hold the "
+                           "pancake upside-down before they hand it "
+                           "to you. It's a great gimmick."));
+  MP_ASSERT_OK_AND_ASSIGN(
+      TextEmbedderResult result1,
+      text_embedder->Embed(
+          "Let's make a plan to steal the declaration of independence."));
+
+  // Check cosine similarity.
+  MP_ASSERT_OK_AND_ASSIGN(
+      double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0],
+                                                        result1.embeddings[0]));
+  EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy);
+
+  MP_ASSERT_OK(text_embedder->Close());
+}
+
 }  // namespace
 }  // namespace mediapipe::tasks::text::text_embedder
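For reference, the cosine similarity asserted in the new tests is the dot product of the two embeddings divided by the product of their L2 norms. The helper below only illustrates that definition; it is not the TextEmbedder::CosineSimilarity implementation.

#include <cmath>
#include <vector>

// Illustrative: cosine similarity over two float embeddings, the quantity the
// tests above compare against 0.851961 and 0.780334.
double CosineSimilarity(const std::vector<float>& u,
                        const std::vector<float>& v) {
  double dot = 0.0, norm_u = 0.0, norm_v = 0.0;
  for (size_t i = 0; i < u.size() && i < v.size(); ++i) {
    dot += u[i] * v[i];
    norm_u += u[i] * u[i];
    norm_v += v[i] * v[i];
  }
  if (norm_u == 0.0 || norm_v == 0.0) return 0.0;  // undefined for zero vectors
  return dot / (std::sqrt(norm_u) * std::sqrt(norm_v));
}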
@@ -23,18 +23,12 @@ cc_library(
     srcs = ["face_stylizer_graph.cc"],
     deps = [
         "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
-        "//mediapipe/calculators/image:image_cropping_calculator",
-        "//mediapipe/calculators/image:image_cropping_calculator_cc_proto",
-        "//mediapipe/calculators/image:warp_affine_calculator",
-        "//mediapipe/calculators/image:warp_affine_calculator_cc_proto",
+        "//mediapipe/calculators/image:image_clone_calculator_cc_proto",
         "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto",
         "//mediapipe/calculators/tensor:inference_calculator",
         "//mediapipe/calculators/util:detections_to_rects_calculator",
         "//mediapipe/calculators/util:face_to_rect_calculator",
-        "//mediapipe/calculators/util:from_image_calculator",
-        "//mediapipe/calculators/util:inverse_matrix_calculator",
         "//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto",
-        "//mediapipe/calculators/util:to_image_calculator",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:image",
@@ -53,7 +47,6 @@ cc_library(
         "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
-        "//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator",
         "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator",
         "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto",
         "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
@@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
   // The input image can be of any size with format RGB or RGBA.
   // When no face is detected on the input image, the method returns a
   // std::nullopt. Otherwise, returns the stylized image of the most visible
-  // face. To ensure that the output image has reasonable quality, the stylized
-  // output image size is the smaller of the model output size and the size of
-  // the 'region_of_interest' specified in 'image_processing_options'.
+  // face. The stylized output image size is the same as the model output size.
   absl::StatusOr<std::optional<mediapipe::Image>> Stylize(
       mediapipe::Image image,
       std::optional<core::ImageProcessingOptions> image_processing_options =
@@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
   // must be monotonically increasing.
   // When no face is detected on the input image, the method returns a
   // std::nullopt. Otherwise, returns the stylized image of the most visible
-  // face. To ensure that the output image has reasonable quality, the stylized
-  // output image size is the smaller of the model output size and the size of
-  // the 'region_of_interest' specified in 'image_processing_options'.
+  // face. The stylized output image size is the same as the model output size.
   absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo(
       mediapipe::Image image, int64_t timestamp_ms,
       std::optional<core::ImageProcessingOptions> image_processing_options =
@@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi {
   // The "result_callback" provides:
   //  - When no face is detected on the input image, the method returns a
   //    std::nullopt. Otherwise, returns the stylized image of the most visible
-  //    face. To ensure that the output image has reasonable quality, the
-  //    stylized output image size is the smaller of the model output size and
-  //    the size of the 'region_of_interest' specified in
-  //    'image_processing_options'.
+  //    face. The stylized output image size is the same as the model output
+  //    size.
   //  - The input timestamp in milliseconds.
   absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms,
                             std::optional<core::ImageProcessingOptions>
@@ -19,8 +19,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "absl/status/statusor.h"
 #include "mediapipe/calculators/core/split_vector_calculator.pb.h"
-#include "mediapipe/calculators/image/image_cropping_calculator.pb.h"
-#include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
+#include "mediapipe/calculators/image/image_clone_calculator.pb.h"
 #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h"
 #include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h"
 #include "mediapipe/framework/api2/builder.h"
@@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     image_in >> preprocessing.In(kImageTag);
     face_rect >> preprocessing.In(kNormRectTag);
     auto preprocessed_tensors = preprocessing.Out(kTensorsTag);
-    auto transform_matrix = preprocessing.Out(kMatrixTag);

     // Adds inference subgraph and connects its input stream to the output
     // tensors produced by the ImageToTensorCalculator.
@@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph {
     model_output_tensors >> tensors_to_image.In(kTensorsTag);
     auto tensor_image = tensors_to_image.Out(kImageTag);

-    auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator");
-    transform_matrix >> inverse_matrix.In(kMatrixTag);
-    auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag);
-
-    auto& warp_affine = graph.AddNode("WarpAffineCalculator");
-    auto& warp_affine_options =
-        warp_affine.GetOptions<WarpAffineCalculatorOptions>();
-    warp_affine_options.set_border_mode(
-        WarpAffineCalculatorOptions::BORDER_ZERO);
-    warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT);
-    tensor_image >> warp_affine.In(kImageTag);
-    inverse_transform_matrix >> warp_affine.In(kMatrixTag);
-    image_size >> warp_affine.In(kOutputSizeTag);
-    auto image_to_crop = warp_affine.Out(kImageTag);
-
-    // The following calculators are for cropping and resizing the output image
-    // based on the roi and the model output size. As the WarpAffineCalculator
-    // rotates the image based on the transform matrix, the rotation info in the
-    // rect proto is stripped to prevent the ImageCroppingCalculator from
-    // performing extra rotation.
-    auto& strip_rotation =
-        graph.AddNode("mediapipe.tasks.StripRotationCalculator");
-    face_rect >> strip_rotation.In(kNormRectTag);
-    auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag);
-    auto& from_image = graph.AddNode("FromImageCalculator");
-    image_to_crop >> from_image.In(kImageTag);
-    auto& image_cropping = graph.AddNode("ImageCroppingCalculator");
-    auto& image_cropping_opts =
-        image_cropping.GetOptions<ImageCroppingCalculatorOptions>();
-    image_cropping_opts.set_output_max_width(
-        image_to_tensor_options.output_tensor_width());
-    image_cropping_opts.set_output_max_height(
-        image_to_tensor_options.output_tensor_height());
-    norm_rect_no_rotation >> image_cropping.In(kNormRectTag);
-    auto& to_image = graph.AddNode("ToImageCalculator");
-    // ImageCroppingCalculator currently doesn't support mediapipe::Image, the
-    // graph selects its cpu or gpu path based on the image preprocessing
-    // backend.
-    if (use_gpu) {
-      from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag);
-      image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag);
-    } else {
-      from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag);
-      image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag);
-    }
-
-    return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(),
+    auto& image_converter = graph.AddNode("ImageCloneCalculator");
+    image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>()
+        .set_output_on_gpu(false);
+    tensor_image >> image_converter.In("");
+
+    return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(),
              /*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}};
   }
 };
@@ -100,6 +100,7 @@ cc_library(
         "//mediapipe/util:graph_builder_utils",
         "@com_google_absl//absl/status:statusor",
     ],
+    alwayslink = 1,
 )

 cc_library(
@@ -90,7 +90,7 @@ NS_SWIFT_NAME(ClassificationResult)
 * amount of data to process might exceed the maximum size that the model can process: to solve
 * this, the input data is split into multiple chunks starting at different timestamps.
 */
-@property(nonatomic, readonly) NSInteger timestampMs;
+@property(nonatomic, readonly) NSInteger timestampInMilliseconds;

 /**
 * Initializes a new `MPPClassificationResult` with the given array of classifications and time
@@ -98,14 +98,15 @@ NS_SWIFT_NAME(ClassificationResult)
 *
 * @param classifications An Array of `MPPClassifications` objects containing the predicted
 * categories for each head of the model.
-* @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data
+* @param timestampInMilliseconds The timestamp (in milliseconds) of the start of the chunk of data
 * corresponding to these results.
 *
 * @return An instance of `MPPClassificationResult` initialized with the given array of
-* classifications and timestampMs.
+* classifications and timestamp (in milliseconds).
 */
 - (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
-                            timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
+                timestampInMilliseconds:(NSInteger)timestampInMilliseconds
+    NS_DESIGNATED_INITIALIZER;

 - (instancetype)init NS_UNAVAILABLE;

@@ -38,11 +38,11 @@
 @implementation MPPClassificationResult

 - (instancetype)initWithClassifications:(NSArray<MPPClassifications *> *)classifications
-                            timestampMs:(NSInteger)timestampMs {
+                timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
   self = [super init];
   if (self) {
     _classifications = classifications;
-    _timestampMs = timestampMs;
+    _timestampInMilliseconds = timestampInMilliseconds;
   }

   return self;
@@ -33,7 +33,7 @@ NS_SWIFT_NAME(EmbeddingResult)
 * cases, the amount of data to process might exceed the maximum size that the model can process. To
 * solve this, the input data is split into multiple chunks starting at different timestamps.
 */
-@property(nonatomic, readonly) NSInteger timestampMs;
+@property(nonatomic, readonly) NSInteger timestampInMilliseconds;

 /**
 * Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in
@@ -41,14 +41,14 @@ NS_SWIFT_NAME(EmbeddingResult)
 *
 * @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each
 * head of the model.
-* @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data
-* corresponding to these results. Pass `0` if timestamp is absent.
+* @param timestampInMilliseconds The optional timestamp (in milliseconds) of the start of the chunk
+* of data corresponding to these results. Pass `0` if timestamp is absent.
 *
 * @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and
-* timestampMs.
+* timestamp (in milliseconds).
 */
 - (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
-                       timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
+           timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER;

 - (instancetype)init NS_UNAVAILABLE;

@@ -17,11 +17,11 @@
 @implementation MPPEmbeddingResult

 - (instancetype)initWithEmbeddings:(NSArray<MPPEmbedding *> *)embeddings
-                       timestampMs:(NSInteger)timestampMs {
+           timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
   self = [super init];
   if (self) {
     _embeddings = embeddings;
-    _timestampMs = timestampMs;
+    _timestampInMilliseconds = timestampInMilliseconds;
   }

   return self;
@@ -55,13 +55,13 @@ using ClassificationResultProto =
     [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]];
   }

-  NSInteger timestampMs = 0;
+  NSInteger timestampInMilliseconds = 0;
   if (classificationResultProto.has_timestamp_ms()) {
-    timestampMs = (NSInteger)classificationResultProto.timestamp_ms();
+    timestampInMilliseconds = (NSInteger)classificationResultProto.timestamp_ms();
   }

   return [[MPPClassificationResult alloc] initWithClassifications:classifications
-                                                       timestampMs:timestampMs];
+                                           timestampInMilliseconds:timestampInMilliseconds];
   ;
 }

@@ -31,12 +31,13 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto::
     [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]];
   }

-  NSInteger timestampMs = 0;
+  NSInteger timestampInMilliseconds = 0;
   if (embeddingResultProto.has_timestamp_ms()) {
-    timestampMs = (NSInteger)embeddingResultProto.timestamp_ms();
+    timestampInMilliseconds = (NSInteger)embeddingResultProto.timestamp_ms();
   }

-  return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs];
+  return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings
+                                timestampInMilliseconds:timestampInMilliseconds];
 }

 @end

@@ -26,11 +26,12 @@ NS_SWIFT_NAME(TaskResult)
 /**
 * Timestamp that is associated with the task result object.
 */
-@property(nonatomic, assign, readonly) NSInteger timestampMs;
+@property(nonatomic, assign, readonly) NSInteger timestampInMilliseconds;

 - (instancetype)init NS_UNAVAILABLE;

-- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER;
+- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds
+    NS_DESIGNATED_INITIALIZER;

 @end

@@ -16,16 +16,16 @@

 @implementation MPPTaskResult

-- (instancetype)initWithTimestampMs:(NSInteger)timestampMs {
+- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds {
   self = [super init];
   if (self) {
-    _timestampMs = timestampMs;
+    _timestampInMilliseconds = timestampInMilliseconds;
   }
   return self;
 }

 - (id)copyWithZone:(NSZone *)zone {
-  return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs];
+  return [[MPPTaskResult alloc] initWithTimestampInMilliseconds:self.timestampInMilliseconds];
 }

 @end
@ -487,7 +487,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
|
||||||
NSError *liveStreamApiCallError;
|
NSError *liveStreamApiCallError;
|
||||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image
|
XCTAssertFalse([imageClassifier classifyAsyncImage:image
|
||||||
timestampMs:0
|
timestampInMilliseconds:0
|
||||||
error:&liveStreamApiCallError]);
|
error:&liveStreamApiCallError]);
|
||||||
|
|
||||||
NSError *expectedLiveStreamApiCallError =
|
NSError *expectedLiveStreamApiCallError =
|
||||||
|
@ -501,7 +501,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError);
|
||||||
|
|
||||||
NSError *videoApiCallError;
|
NSError *videoApiCallError;
|
||||||
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]);
|
XCTAssertFalse([imageClassifier classifyVideoFrame:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&videoApiCallError]);
|
||||||
|
|
||||||
NSError *expectedVideoApiCallError =
|
NSError *expectedVideoApiCallError =
|
||||||
[NSError errorWithDomain:kExpectedErrorDomain
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
@ -524,7 +526,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
|
||||||
NSError *liveStreamApiCallError;
|
NSError *liveStreamApiCallError;
|
||||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image
|
XCTAssertFalse([imageClassifier classifyAsyncImage:image
|
||||||
timestampMs:0
|
timestampInMilliseconds:0
|
||||||
error:&liveStreamApiCallError]);
|
error:&liveStreamApiCallError]);
|
||||||
|
|
||||||
NSError *expectedLiveStreamApiCallError =
|
NSError *expectedLiveStreamApiCallError =
|
||||||
|
@ -575,7 +577,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
AssertEqualErrors(imageApiCallError, expectedImageApiCallError);
|
||||||
|
|
||||||
NSError *videoApiCallError;
|
NSError *videoApiCallError;
|
||||||
XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]);
|
XCTAssertFalse([imageClassifier classifyVideoFrame:image
|
||||||
|
timestampInMilliseconds:0
|
||||||
|
error:&videoApiCallError]);
|
||||||
|
|
||||||
NSError *expectedVideoApiCallError =
|
NSError *expectedVideoApiCallError =
|
||||||
[NSError errorWithDomain:kExpectedErrorDomain
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
@ -601,7 +605,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image
|
MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image
|
||||||
timestampMs:i
|
timestampInMilliseconds:i
|
||||||
error:nil];
|
error:nil];
|
||||||
[self assertImageClassifierResult:imageClassifierResult
|
[self assertImageClassifierResult:imageClassifierResult
|
||||||
hasExpectedCategoriesCount:maxResults
|
hasExpectedCategoriesCount:maxResults
|
||||||
|
@ -630,10 +634,10 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
|
|
||||||
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
|
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
|
||||||
|
|
||||||
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:1 error:nil]);
|
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]);
|
||||||
|
|
||||||
NSError *error;
|
NSError *error;
|
||||||
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampMs:0 error:&error]);
|
XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampInMilliseconds:0 error:&error]);
|
||||||
|
|
||||||
NSError *expectedError =
|
NSError *expectedError =
|
||||||
[NSError errorWithDomain:kExpectedErrorDomain
|
[NSError errorWithDomain:kExpectedErrorDomain
|
||||||
|
@ -668,7 +672,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks";
|
||||||
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
|
MPPImage *image = [self imageWithFileInfo:kBurgerImage];
|
||||||
|
|
||||||
for (int i = 0; i < 3; i++) {
|
for (int i = 0; i < 3; i++) {
|
||||||
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]);
|
XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextClassifierResult)
|
||||||
*
|
*
|
||||||
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
|
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
|
||||||
* per classifier head.
|
* per classifier head.
|
||||||
* @param timestampMs The timestamp for this result.
|
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||||
*
|
*
|
||||||
* @return An instance of `MPPTextClassifierResult` initialized with the given
|
* @return An instance of `MPPTextClassifierResult` initialized with the given
|
||||||
* `MPPClassificationResult` and timestamp (in milliseconds).
|
* `MPPClassificationResult` and timestamp (in milliseconds).
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||||
timestampMs:(NSInteger)timestampMs;
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
@implementation MPPTextClassifierResult
|
@implementation MPPTextClassifierResult
|
||||||
|
|
||||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||||
timestampMs:(NSInteger)timestampMs {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||||
self = [super initWithTimestampMs:timestampMs];
|
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||||
if (self) {
|
if (self) {
|
||||||
_classificationResult = classificationResult;
|
_classificationResult = classificationResult;
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,7 +35,7 @@ using ::mediapipe::Packet;
|
||||||
|
|
||||||
return [[MPPTextClassifierResult alloc]
|
return [[MPPTextClassifierResult alloc]
|
||||||
initWithClassificationResult:classificationResult
|
initWithClassificationResult:classificationResult
|
||||||
timestampMs:(NSInteger)(packet.Timestamp().Value() /
|
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond)];
|
kMicroSecondsPerMilliSecond)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextEmbedderResult)
|
||||||
*
|
*
|
||||||
* @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results
|
* @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results
|
||||||
* per classifier head.
|
* per classifier head.
|
||||||
* @param timestampMs The timestamp for this result.
|
* @param timestampInMilliseconds The timestamp (in millisecondss) for this result.
|
||||||
*
|
*
|
||||||
* @return An instance of `MPPTextEmbedderResult` initialized with the given
|
* @return An instance of `MPPTextEmbedderResult` initialized with the given
|
||||||
* `MPPEmbeddingResult` and timestamp (in milliseconds).
|
* `MPPEmbeddingResult` and timestamp (in milliseconds).
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
|
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
|
||||||
timestampMs:(NSInteger)timestampMs;
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
@implementation MPPTextEmbedderResult
|
@implementation MPPTextEmbedderResult
|
||||||
|
|
||||||
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
|
- (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult
|
||||||
timestampMs:(NSInteger)timestampMs {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||||
self = [super initWithTimestampMs:timestampMs];
|
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||||
if (self) {
|
if (self) {
|
||||||
_embeddingResult = embeddingResult;
|
_embeddingResult = embeddingResult;
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ using ::mediapipe::Packet;
|
||||||
|
|
||||||
return [[MPPTextEmbedderResult alloc]
|
return [[MPPTextEmbedderResult alloc]
|
||||||
initWithEmbeddingResult:embeddingResult
|
initWithEmbeddingResult:embeddingResult
|
||||||
timestampMs:(NSInteger)(packet.Timestamp().Value() /
|
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond)];
|
kMicroSecondsPerMilliSecond)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -41,7 +41,7 @@
|
||||||
* timestamp.
|
* timestamp.
|
||||||
*
|
*
|
||||||
* @param image The image to send to the MediaPipe graph.
|
* @param image The image to send to the MediaPipe graph.
|
||||||
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
|
* @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
|
||||||
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
|
* @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no
|
||||||
* error will be saved.
|
* error will be saved.
|
||||||
*
|
*
|
||||||
|
@ -49,7 +49,7 @@
|
||||||
* occurred during the conversion.
|
* occurred during the conversion.
|
||||||
*/
|
*/
|
||||||
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image
|
+ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error;
|
error:(NSError **)error;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -66,11 +66,11 @@
|
||||||
* specified timestamp.
|
* specified timestamp.
|
||||||
*
|
*
|
||||||
* @param image The `NormalizedRect` to send to the MediaPipe graph.
|
* @param image The `NormalizedRect` to send to the MediaPipe graph.
|
||||||
* @param timestampMs The timestamp (in milliseconds) to assign to the packet.
|
* @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet.
|
||||||
*
|
*
|
||||||
* @return The MediaPipe packet containing the normalized rect.
|
* @return The MediaPipe packet containing the normalized rect.
|
||||||
*/
|
*/
|
||||||
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect
|
+ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect
|
||||||
timestampMs:(NSInteger)timestampMs;
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -42,7 +42,7 @@ using ::mediapipe::Timestamp;
|
||||||
}
|
}
|
||||||
|
|
||||||
+ (Packet)createPacketWithMPPImage:(MPPImage *)image
|
+ (Packet)createPacketWithMPPImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error];
|
std::unique_ptr<ImageFrame> imageFrame = [image imageFrameWithError:error];
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ using ::mediapipe::Timestamp;
|
||||||
}
|
}
|
||||||
|
|
||||||
return MakePacket<Image>(std::move(imageFrame))
|
return MakePacket<Image>(std::move(imageFrame))
|
||||||
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
|
.At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
|
||||||
}
|
}
|
||||||
|
|
||||||
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect {
|
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect {
|
||||||
|
@ -59,9 +59,9 @@ using ::mediapipe::Timestamp;
|
||||||
}
|
}
|
||||||
|
|
||||||
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect
|
+ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect
|
||||||
timestampMs:(NSInteger)timestampMs {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||||
return MakePacket<NormalizedRect>(std::move(normalizedRect))
|
return MakePacket<NormalizedRect>(std::move(normalizedRect))
|
||||||
.At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond)));
|
.At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
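The packet helpers above convert between the caller-facing millisecond timestamps and MediaPipe's microsecond packet timestamps via kMicroSecondsPerMilliSecond. A minimal sketch of that round trip, assuming the constant equals 1000 (plain Java; names are illustrative only, not from this diff):

    // Round trip between API milliseconds and MediaPipe packet microseconds.
    // kMicroSecondsPerMilliSecond is assumed to be 1000.
    final long microSecondsPerMilliSecond = 1000;
    long timestampInMilliseconds = 42;
    // Multiplied on the way in, e.g. .At(Timestamp(...)) in the packet creator above.
    long packetTimestampMicroseconds = timestampInMilliseconds * microSecondsPerMilliSecond; // 42000
    // Divided on the way out, e.g. when building result objects from a packet timestamp.
    long recoveredMilliseconds = packetTimestampMicroseconds / microSecondsPerMilliSecond;   // 42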
@ -21,7 +21,7 @@
|
||||||
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
|
handedness:(NSArray<NSArray<MPPCategory *> *> *)handedness
|
||||||
gestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
|
gestures:(NSArray<NSArray<MPPCategory *> *> *)gestures
|
||||||
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||||
self = [super initWithTimestampMs:timestampInMilliseconds];
|
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||||
if (self) {
|
if (self) {
|
||||||
_landmarks = landmarks;
|
_landmarks = landmarks;
|
||||||
_worldLandmarks = worldLandmarks;
|
_worldLandmarks = worldLandmarks;
|
||||||
|
|
|
@ -122,17 +122,17 @@ NS_SWIFT_NAME(ImageClassifier)
|
||||||
* `MPPRunningModeVideo`.
|
* `MPPRunningModeVideo`.
|
||||||
*
|
*
|
||||||
* @param image The `MPPImage` on which image classification is to be performed.
|
* @param image The `MPPImage` on which image classification is to be performed.
|
||||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||||
* monotonically increasing.
|
* timestamps must be monotonically increasing.
|
||||||
* @param error An optional error parameter populated when there is an error in performing image
|
* @param error An optional error parameter populated when there is an error in performing image
|
||||||
* classification on the input video frame.
|
* classification on the input video frame.
|
||||||
*
|
*
|
||||||
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
|
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
|
||||||
*/
|
*/
|
||||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(classify(videoFrame:timestampMs:));
|
NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs image classification on the provided video frame of type `MPPImage` cropped to the
|
* Performs image classification on the provided video frame of type `MPPImage` cropped to the
|
||||||
|
@ -145,8 +145,8 @@ NS_SWIFT_NAME(ImageClassifier)
|
||||||
*
|
*
|
||||||
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
||||||
* performed.
|
* performed.
|
||||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||||
* monotonically increasing.
|
* timestamps must be monotonically increasing.
|
||||||
* @param roi A `CGRect` specifying the region of interest within the video frame of type
|
* @param roi A `CGRect` specifying the region of interest within the video frame of type
|
||||||
* `MPPImage`, on which image classification should be performed.
|
* `MPPImage`, on which image classification should be performed.
|
||||||
* @param error An optional error parameter populated when there is an error in performing image
|
* @param error An optional error parameter populated when there is an error in performing image
|
||||||
|
@ -155,10 +155,10 @@ NS_SWIFT_NAME(ImageClassifier)
|
||||||
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
|
* @return An `MPPImageClassifierResult` object that contains a list of image classifications.
|
||||||
*/
|
*/
|
||||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:));
|
NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:regionOfInterest:));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends live stream image data of type `MPPImage` to perform image classification using the whole
|
* Sends live stream image data of type `MPPImage` to perform image classification using the whole
|
||||||
|
@ -172,16 +172,17 @@ NS_SWIFT_NAME(ImageClassifier)
|
||||||
*
|
*
|
||||||
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
||||||
* performed.
|
* performed.
|
||||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||||
* to the image classifier. The input timestamps must be monotonically increasing.
|
* image is sent to the image classifier. The input timestamps must be monotonically increasing.
|
||||||
* @param error An optional error parameter populated when there is an error in performing image
|
* @param error An optional error parameter populated when there is an error in performing image
|
||||||
* classification on the input live stream image data.
|
* classification on the input live stream image data.
|
||||||
*
|
*
|
||||||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||||
*/
|
*/
|
||||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:));
|
error:(NSError **)error
|
||||||
|
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
|
* Sends live stream image data of type `MPPImage` to perform image classification, cropped to the
|
||||||
|
@ -195,8 +196,8 @@ NS_SWIFT_NAME(ImageClassifier)
|
||||||
*
|
*
|
||||||
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
* @param image A live stream image data of type `MPPImage` on which image classification is to be
|
||||||
* performed.
|
* performed.
|
||||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||||
* to the image classifier. The input timestamps must be monotonically increasing.
|
* image is sent to the image classifier. The input timestamps must be monotonically increasing.
|
||||||
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
|
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
|
||||||
* of type `MPPImage`, on which image classification should be performed.
|
* of type `MPPImage`, on which image classification should be performed.
|
||||||
* @param error An optional error parameter populated when there is an error in performing image
|
* @param error An optional error parameter populated when there is an error in performing image
|
||||||
|
@ -205,10 +206,10 @@ NS_SWIFT_NAME(ImageClassifier)
|
||||||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||||
*/
|
*/
|
||||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:));
|
NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:regionOfInterest:));
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
|
|
@ -149,7 +149,7 @@ static NSString *const kTaskGraphName =
|
||||||
}
|
}
|
||||||
|
|
||||||
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
|
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::optional<NormalizedRect> rect =
|
std::optional<NormalizedRect> rect =
|
||||||
|
@ -162,14 +162,15 @@ static NSString *const kTaskGraphName =
|
||||||
}
|
}
|
||||||
|
|
||||||
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
|
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
error:error];
|
error:error];
|
||||||
if (imagePacket.IsEmpty()) {
|
if (imagePacket.IsEmpty()) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
Packet normalizedRectPacket =
|
||||||
timestampMs:timestampMs];
|
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
||||||
|
timestampInMilliseconds:timestampInMilliseconds];
|
||||||
|
|
||||||
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
|
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
|
||||||
return inputPacketMap;
|
return inputPacketMap;
|
||||||
|
@ -180,11 +181,11 @@ static NSString *const kTaskGraphName =
|
||||||
}
|
}
|
||||||
|
|
||||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:roi
|
regionOfInterest:roi
|
||||||
error:error];
|
error:error];
|
||||||
if (!inputPacketMap.has_value()) {
|
if (!inputPacketMap.has_value()) {
|
||||||
|
@ -204,20 +205,20 @@ static NSString *const kTaskGraphName =
|
||||||
}
|
}
|
||||||
|
|
||||||
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
return [self classifyVideoFrame:image
|
return [self classifyVideoFrame:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:CGRectZero
|
regionOfInterest:CGRectZero
|
||||||
error:error];
|
error:error];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:roi
|
regionOfInterest:roi
|
||||||
error:error];
|
error:error];
|
||||||
if (!inputPacketMap.has_value()) {
|
if (!inputPacketMap.has_value()) {
|
||||||
|
@ -228,10 +229,10 @@ static NSString *const kTaskGraphName =
|
||||||
}
|
}
|
||||||
|
|
||||||
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
- (BOOL)classifyAsyncImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
return [self classifyAsyncImage:image
|
return [self classifyAsyncImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:CGRectZero
|
regionOfInterest:CGRectZero
|
||||||
error:error];
|
error:error];
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,13 +31,13 @@ NS_SWIFT_NAME(ImageClassifierResult)
|
||||||
*
|
*
|
||||||
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
|
* @param classificationResult The `MPPClassificationResult` instance containing one set of results
|
||||||
* per classifier head.
|
* per classifier head.
|
||||||
* @param timestampMs The timestamp for this result.
|
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||||
*
|
*
|
||||||
* @return An instance of `MPPImageClassifierResult` initialized with the given
|
* @return An instance of `MPPImageClassifierResult` initialized with the given
|
||||||
* `MPPClassificationResult` and timestamp (in milliseconds).
|
* `MPPClassificationResult` and timestamp (in milliseconds).
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||||
timestampMs:(NSInteger)timestampMs;
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
@implementation MPPImageClassifierResult
|
@implementation MPPImageClassifierResult
|
||||||
|
|
||||||
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
- (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult
|
||||||
timestampMs:(NSInteger)timestampMs {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||||
self = [super initWithTimestampMs:timestampMs];
|
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||||
if (self) {
|
if (self) {
|
||||||
_classificationResult = classificationResult;
|
_classificationResult = classificationResult;
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,7 +34,7 @@ using ::mediapipe::Packet;
|
||||||
|
|
||||||
return [[MPPImageClassifierResult alloc]
|
return [[MPPImageClassifierResult alloc]
|
||||||
initWithClassificationResult:classificationResult
|
initWithClassificationResult:classificationResult
|
||||||
timestampMs:(NSInteger)(packet.Timestamp().Value() /
|
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||||
kMicroSecondsPerMilliSecond)];
|
kMicroSecondsPerMilliSecond)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,13 +36,13 @@ NS_SWIFT_NAME(ObjectDetectionResult)
|
||||||
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
|
* @param detections An array of `MPPDetection` objects each of which has a bounding box that is
|
||||||
* expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width)
|
* expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width)
|
||||||
* x [0,image_height)`, which are the dimensions of the underlying image data.
|
* x [0,image_height)`, which are the dimensions of the underlying image data.
|
||||||
* @param timestampMs The timestamp for this result.
|
* @param timestampInMilliseconds The timestamp (in milliseconds) for this result.
|
||||||
*
|
*
|
||||||
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
|
* @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections
|
||||||
* and timestamp (in milliseconds).
|
* and timestamp (in milliseconds).
|
||||||
*/
|
*/
|
||||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
||||||
timestampMs:(NSInteger)timestampMs;
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds;
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
@implementation MPPObjectDetectionResult
|
@implementation MPPObjectDetectionResult
|
||||||
|
|
||||||
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
- (instancetype)initWithDetections:(NSArray<MPPDetection *> *)detections
|
||||||
timestampMs:(NSInteger)timestampMs {
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds {
|
||||||
self = [super initWithTimestampMs:timestampMs];
|
self = [super initWithTimestampInMilliseconds:timestampInMilliseconds];
|
||||||
if (self) {
|
if (self) {
|
||||||
_detections = detections;
|
_detections = detections;
|
||||||
}
|
}
|
||||||
|
|
|
@ -138,8 +138,8 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
* `MPPRunningModeVideo`.
|
* `MPPRunningModeVideo`.
|
||||||
*
|
*
|
||||||
* @param image The `MPPImage` on which object detection is to be performed.
|
* @param image The `MPPImage` on which object detection is to be performed.
|
||||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||||
* monotonically increasing.
|
* timestamps must be monotonically increasing.
|
||||||
* @param error An optional error parameter populated when there is an error in performing object
|
* @param error An optional error parameter populated when there is an error in performing object
|
||||||
* detection on the input image.
|
* detection on the input image.
|
||||||
*
|
*
|
||||||
|
@ -149,9 +149,9 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
* image data.
|
* image data.
|
||||||
*/
|
*/
|
||||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(detect(videoFrame:timestampMs:));
|
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Performs object detection on the provided video frame of type `MPPImage` cropped to the
|
* Performs object detection on the provided video frame of type `MPPImage` cropped to the
|
||||||
|
@ -164,8 +164,8 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
*
|
*
|
||||||
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
||||||
* performed.
|
* performed.
|
||||||
* @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be
|
* @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input
|
||||||
* monotonically increasing.
|
* timestamps must be monotonically increasing.
|
||||||
* @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
|
* @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which
|
||||||
* object detection should be performed.
|
* object detection should be performed.
|
||||||
*
|
*
|
||||||
|
@ -178,10 +178,10 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
* image data.
|
* image data.
|
||||||
*/
|
*/
|
||||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(detect(videoFrame:timestampMs:regionOfInterest:));
|
NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:regionOfInterest:));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends live stream image data of type `MPPImage` to perform object detection using the whole
|
* Sends live stream image data of type `MPPImage` to perform object detection using the whole
|
||||||
|
@ -195,16 +195,17 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
*
|
*
|
||||||
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
||||||
* performed.
|
* performed.
|
||||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||||
* to the object detector. The input timestamps must be monotonically increasing.
|
* image is sent to the object detector. The input timestamps must be monotonically increasing.
|
||||||
* @param error An optional error parameter populated when there is an error in performing object
|
* @param error An optional error parameter populated when there is an error in performing object
|
||||||
* detection on the input live stream image data.
|
* detection on the input live stream image data.
|
||||||
*
|
*
|
||||||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||||
*/
|
*/
|
||||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error NS_SWIFT_NAME(detectAsync(image:timestampMs:));
|
error:(NSError **)error
|
||||||
|
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sends live stream image data of type `MPPImage` to perform object detection, cropped to the
|
* Sends live stream image data of type `MPPImage` to perform object detection, cropped to the
|
||||||
|
@ -218,8 +219,8 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
*
|
*
|
||||||
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
* @param image A live stream image data of type `MPPImage` on which object detection is to be
|
||||||
* performed.
|
* performed.
|
||||||
* @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent
|
* @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input
|
||||||
* to the object detector. The input timestamps must be monotonically increasing.
|
* image is sent to the object detector. The input timestamps must be monotonically increasing.
|
||||||
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
|
* @param roi A `CGRect` specifying the region of interest within the given live stream image data
|
||||||
 * of type `MPPImage`, on which object detection should be performed.
|
 * of type `MPPImage`, on which object detection should be performed.
|
||||||
* @param error An optional error parameter populated when there is an error in performing object
|
* @param error An optional error parameter populated when there is an error in performing object
|
||||||
|
@ -228,10 +229,10 @@ NS_SWIFT_NAME(ObjectDetector)
|
||||||
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
* @return `YES` if the image was sent to the task successfully, otherwise `NO`.
|
||||||
*/
|
*/
|
||||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error
|
error:(NSError **)error
|
||||||
NS_SWIFT_NAME(detectAsync(image:timestampMs:regionOfInterest:));
|
NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:regionOfInterest:));
|
||||||
|
|
||||||
- (instancetype)init NS_UNAVAILABLE;
|
- (instancetype)init NS_UNAVAILABLE;
|
||||||
|
|
||||||
|
|
|
@ -157,7 +157,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
||||||
}
|
}
|
||||||
|
|
||||||
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
|
- (std::optional<PacketMap>)inputPacketMapWithMPPImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::optional<NormalizedRect> rect =
|
std::optional<NormalizedRect> rect =
|
||||||
|
@ -170,14 +170,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
||||||
}
|
}
|
||||||
|
|
||||||
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
|
Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
error:error];
|
error:error];
|
||||||
if (imagePacket.IsEmpty()) {
|
if (imagePacket.IsEmpty()) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
Packet normalizedRectPacket =
|
||||||
timestampMs:timestampMs];
|
[MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()
|
||||||
|
timestampInMilliseconds:timestampInMilliseconds];
|
||||||
|
|
||||||
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
|
PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket);
|
||||||
return inputPacketMap;
|
return inputPacketMap;
|
||||||
|
@ -188,11 +189,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
||||||
}
|
}
|
||||||
|
|
||||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:roi
|
regionOfInterest:roi
|
||||||
error:error];
|
error:error];
|
||||||
if (!inputPacketMap.has_value()) {
|
if (!inputPacketMap.has_value()) {
|
||||||
|
@ -212,20 +213,20 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
||||||
}
|
}
|
||||||
|
|
||||||
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
- (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
return [self detectInVideoFrame:image
|
return [self detectInVideoFrame:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:CGRectZero
|
regionOfInterest:CGRectZero
|
||||||
error:error];
|
error:error];
|
||||||
}
|
}
|
||||||
|
|
||||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
regionOfInterest:(CGRect)roi
|
regionOfInterest:(CGRect)roi
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
std::optional<PacketMap> inputPacketMap = [self inputPacketMapWithMPPImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:roi
|
regionOfInterest:roi
|
||||||
error:error];
|
error:error];
|
||||||
if (!inputPacketMap.has_value()) {
|
if (!inputPacketMap.has_value()) {
|
||||||
|
@ -236,10 +237,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG
|
||||||
}
|
}
|
||||||
|
|
||||||
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
- (BOOL)detectAsyncInImage:(MPPImage *)image
|
||||||
timestampMs:(NSInteger)timestampMs
|
timestampInMilliseconds:(NSInteger)timestampInMilliseconds
|
||||||
error:(NSError **)error {
|
error:(NSError **)error {
|
||||||
return [self detectAsyncInImage:image
|
return [self detectAsyncInImage:image
|
||||||
timestampMs:timestampMs
|
timestampInMilliseconds:timestampInMilliseconds
|
||||||
regionOfInterest:CGRectZero
|
regionOfInterest:CGRectZero
|
||||||
error:error];
|
error:error];
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,8 +38,9 @@ using ::mediapipe::Packet;
|
||||||
}
|
}
|
||||||
|
|
||||||
return [[MPPObjectDetectionResult alloc]
|
return [[MPPObjectDetectionResult alloc]
|
||||||
initWithDetections:detections
|
initWithDetections:detections
|
||||||
timestampMs:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)];
|
timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() /
|
||||||
|
kMicroSecondsPerMilliSecond)];
|
||||||
}
|
}
|
||||||
|
|
||||||
@end
|
@end
|
||||||
|
|
|
@ -198,9 +198,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* size of the stylized output is based the model output size and can be smaller than the input
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* image.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created
|
* @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created
|
||||||
|
@ -220,9 +220,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* the stylized output image size is the smaller of the model output size and the size of the
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
@ -256,9 +256,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* size of the stylized output is based the model output size and can be smaller than the input
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* image.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
* @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a
|
||||||
|
@ -281,9 +281,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* the stylized output image size is the smaller of the model output size and the size of the
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
@ -320,9 +320,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* size of the stylized output is based the model output size and can be smaller than the input
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* image.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param timestampMs the input timestamp (in milliseconds).
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
@ -346,9 +346,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* the stylized output image size is the smaller of the model output size and the size of the
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
@ -387,9 +387,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* size of the stylized output is based the model output size and can be smaller than the input
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* image.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param timestampMs the input timestamp (in milliseconds).
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
@ -414,9 +414,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* the stylized output image size is the smaller of the model output size and the size of the
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param timestampMs the input timestamp (in milliseconds).
|
* @param timestampMs the input timestamp (in milliseconds).
|
||||||
|
@ -445,9 +445,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
*
|
*
|
||||||
* <p>{@link FaceStylizer} supports the following color space types:
|
* <p>{@link FaceStylizer} supports the following color space types:
|
||||||
*
|
*
|
||||||
* <p>The image can be of any size. To ensure that the output image has reasonable quality, the
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* size of the stylized output is based the model output * size and can be smaller than the input
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* image.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
|
@ -475,9 +475,9 @@ public final class FaceStylizer extends BaseVisionTaskApi {
|
||||||
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
* <li>{@link android.graphics.Bitmap.Config#ARGB_8888}
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <p>The input image can be of any size. To ensure that the output image has reasonable quality,
|
* <p>The input image can be of any size. The output image is the stylized image with the most
|
||||||
* the stylized output image size is the smaller of the model output size and the size of the
|
* visible face. The stylized output image size is the same as the model output size. When no face
|
||||||
* {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}.
|
* is detected on the input image, returns {@code Optional.empty()}.
|
||||||
*
|
*
|
||||||
* @param image a MediaPipe {@link MPImage} object for processing.
|
* @param image a MediaPipe {@link MPImage} object for processing.
|
||||||
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
* @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the
|
||||||
|
|
|
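With the updated contract above, a stylized image is returned only when a face is detected, and its dimensions follow the model output rather than the input. A minimal usage sketch, assuming a `faceStylizer` instance and an `MPImage image` already exist (the accessors follow the FaceStylizerTest further below):

    FaceStylizerResult result = faceStylizer.stylize(image);
    if (result.stylizedImage().isPresent()) {
      MPImage stylized = result.stylizedImage().get();
      // Dimensions match the model output size, not the input image size.
      int width = stylized.getWidth();
      int height = stylized.getHeight();
    } else {
      // No face was detected in the input image.
    }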
@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
"IMAGE:" + IMAGE_IN_STREAM_NAME,
|
||||||
"ROI:" + ROI_IN_STREAM_NAME,
|
"ROI:" + ROI_IN_STREAM_NAME,
|
||||||
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
"NORM_RECT:" + NORM_RECT_IN_STREAM_NAME));
|
||||||
private static final List<String> OUTPUT_STREAMS =
|
private static final int IMAGE_OUT_STREAM_INDEX = 0;
|
||||||
Collections.unmodifiableList(
|
|
||||||
Arrays.asList(
|
|
||||||
"GROUPED_SEGMENTATION:segmented_mask_out",
|
|
||||||
"IMAGE:image_out",
|
|
||||||
"SEGMENTATION:0:segmentation"));
|
|
||||||
private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0;
|
|
||||||
private static final int IMAGE_OUT_STREAM_INDEX = 1;
|
|
||||||
private static final int SEGMENTATION_OUT_STREAM_INDEX = 2;
|
|
||||||
private static final String TASK_GRAPH_NAME =
|
private static final String TASK_GRAPH_NAME =
|
||||||
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
"mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph";
|
||||||
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
|
||||||
|
@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
*/
|
*/
|
||||||
public static InteractiveSegmenter createFromOptions(
|
public static InteractiveSegmenter createFromOptions(
|
||||||
Context context, InteractiveSegmenterOptions segmenterOptions) {
|
Context context, InteractiveSegmenterOptions segmenterOptions) {
|
||||||
|
if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set.");
|
||||||
|
}
|
||||||
|
List<String> outputStreams = new ArrayList<>();
|
||||||
|
outputStreams.add("IMAGE:image_out");
|
||||||
|
if (segmenterOptions.outputConfidenceMasks()) {
|
||||||
|
outputStreams.add("CONFIDENCE_MASKS:confidence_masks");
|
||||||
|
}
|
||||||
|
final int confidenceMasksOutStreamIndex = outputStreams.size() - 1;
|
||||||
|
if (segmenterOptions.outputCategoryMask()) {
|
||||||
|
outputStreams.add("CATEGORY_MASK:category_mask");
|
||||||
|
}
|
||||||
|
final int categoryMaskOutStreamIndex = outputStreams.size() - 1;
|
||||||
|
|
||||||
// TODO: Consolidate OutputHandler and TaskRunner.
|
// TODO: Consolidate OutputHandler and TaskRunner.
|
||||||
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
OutputHandler<ImageSegmenterResult, MPImage> handler = new OutputHandler<>();
|
||||||
handler.setOutputPacketConverter(
|
handler.setOutputPacketConverter(
|
||||||
|
@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
@Override
|
@Override
|
||||||
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
public ImageSegmenterResult convertToTaskResult(List<Packet> packets)
|
||||||
throws MediaPipeException {
|
throws MediaPipeException {
|
||||||
if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) {
|
if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) {
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
Optional.empty(),
|
Optional.empty(),
|
||||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp());
|
packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp());
|
||||||
}
|
}
|
||||||
List<MPImage> segmentedMasks = new ArrayList<>();
|
// If resultListener is not provided, the resulted MPImage is deep copied from
|
||||||
int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
// mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet
|
||||||
int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX));
|
// memory.
|
||||||
int imageFormat =
|
boolean copyImage = !segmenterOptions.resultListener().isPresent();
|
||||||
segmenterOptions.outputType()
|
Optional<List<MPImage>> confidenceMasks = Optional.empty();
|
||||||
== InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK
|
if (segmenterOptions.outputConfidenceMasks()) {
|
||||||
? MPImage.IMAGE_FORMAT_VEC32F1
|
confidenceMasks = Optional.of(new ArrayList<>());
|
||||||
: MPImage.IMAGE_FORMAT_ALPHA;
|
int width =
|
||||||
int imageListSize =
|
PacketGetter.getImageWidthFromImageList(
|
||||||
PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX));
|
packets.get(confidenceMasksOutStreamIndex));
|
||||||
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
int height =
|
||||||
// If resultListener is not provided, the resulted MPImage is deep copied from mediapipe
|
PacketGetter.getImageHeightFromImageList(
|
||||||
// graph. If provided, the result MPImage is wrapping the mediapipe packet memory.
|
packets.get(confidenceMasksOutStreamIndex));
|
||||||
if (!segmenterOptions.resultListener().isPresent()) {
|
int imageListSize =
|
||||||
for (int i = 0; i < imageListSize; i++) {
|
PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex));
|
||||||
buffersArray[i] =
|
ByteBuffer[] buffersArray = new ByteBuffer[imageListSize];
|
||||||
ByteBuffer.allocateDirect(
|
// Confidence masks are float-type images.
|
||||||
width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1));
|
final int numBytes = 4;
|
||||||
|
if (copyImage) {
|
||||||
|
for (int i = 0; i < imageListSize; i++) {
|
||||||
|
buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!PacketGetter.getImageList(
|
||||||
|
packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) {
|
||||||
|
throw new MediaPipeException(
|
||||||
|
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||||
|
"There is an error getting confidence masks.");
|
||||||
|
}
|
||||||
|
for (ByteBuffer buffer : buffersArray) {
|
||||||
|
ByteBufferImageBuilder builder =
|
||||||
|
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1);
|
||||||
|
confidenceMasks.get().add(builder.build());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!PacketGetter.getImageList(
|
Optional<MPImage> categoryMask = Optional.empty();
|
||||||
packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX),
|
if (segmenterOptions.outputCategoryMask()) {
|
||||||
buffersArray,
|
int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex));
|
||||||
!segmenterOptions.resultListener().isPresent())) {
|
int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex));
|
||||||
throw new MediaPipeException(
|
ByteBuffer buffer;
|
||||||
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
if (copyImage) {
|
||||||
"There is an error getting segmented masks. It usually results from incorrect"
|
buffer = ByteBuffer.allocateDirect(width * height);
|
||||||
+ " options of unsupported OutputType of given model.");
|
if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) {
|
||||||
}
|
throw new MediaPipeException(
|
||||||
for (ByteBuffer buffer : buffersArray) {
|
MediaPipeException.StatusCode.INTERNAL.ordinal(),
|
||||||
|
"There is an error getting category mask.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex));
|
||||||
|
}
|
||||||
ByteBufferImageBuilder builder =
|
ByteBufferImageBuilder builder =
|
||||||
new ByteBufferImageBuilder(buffer, width, height, imageFormat);
|
new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA);
|
||||||
segmentedMasks.add(builder.build());
|
categoryMask = Optional.of(builder.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
return ImageSegmenterResult.create(
|
return ImageSegmenterResult.create(
|
||||||
Optional.of(segmentedMasks),
|
confidenceMasks,
|
||||||
Optional.empty(),
|
categoryMask,
|
||||||
BaseVisionTaskApi.generateResultTimestampMs(
|
BaseVisionTaskApi.generateResultTimestampMs(
|
||||||
RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)));
|
RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX)));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
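Confidence masks are single-channel float images, so when the result has to be deep-copied the converter allocates width x height x 4 bytes per mask before calling PacketGetter.getImageList. A condensed sketch of that allocation, assuming `width`, `height`, and `imageListSize` as computed above and java.nio.ByteBuffer imported as in the file:

    // Each confidence mask pixel is a 32-bit float (4 bytes).
    final int bytesPerFloatPixel = 4;
    ByteBuffer[] buffers = new ByteBuffer[imageListSize];
    for (int i = 0; i < imageListSize; i++) {
      buffers[i] = ByteBuffer.allocateDirect(width * height * bytesPerFloatPixel);
    }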
@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
.setTaskRunningModeName(RunningMode.IMAGE.name())
|
.setTaskRunningModeName(RunningMode.IMAGE.name())
|
||||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||||
.setInputStreams(INPUT_STREAMS)
|
.setInputStreams(INPUT_STREAMS)
|
||||||
.setOutputStreams(OUTPUT_STREAMS)
|
.setOutputStreams(outputStreams)
|
||||||
.setTaskOptions(segmenterOptions)
|
.setTaskOptions(segmenterOptions)
|
||||||
.setEnableFlowLimiting(false)
|
.setEnableFlowLimiting(false)
|
||||||
.build(),
|
.build(),
|
||||||
|
@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
/** Sets the base options for the image segmenter task. */
|
/** Sets the base options for the image segmenter task. */
|
||||||
public abstract Builder setBaseOptions(BaseOptions value);
|
public abstract Builder setBaseOptions(BaseOptions value);
|
||||||
|
|
||||||
/** The output type from image segmenter. */
|
/** Sets whether to output confidence masks. Defaults to true. */
|
||||||
public abstract Builder setOutputType(OutputType value);
|
public abstract Builder setOutputConfidenceMasks(boolean value);
|
||||||
|
|
||||||
|
/** Sets whether to output category mask. Defaults to false. */
|
||||||
|
public abstract Builder setOutputCategoryMask(boolean value);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
* Sets an optional {@link ResultListener} to receive the segmentation results when the graph
|
||||||
|
@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
|
|
||||||
abstract BaseOptions baseOptions();
|
abstract BaseOptions baseOptions();
|
||||||
|
|
||||||
abstract OutputType outputType();
|
abstract boolean outputConfidenceMasks();
|
||||||
|
|
||||||
|
abstract boolean outputCategoryMask();
|
||||||
|
|
||||||
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
abstract Optional<ResultListener<ImageSegmenterResult, MPImage>> resultListener();
|
||||||
|
|
||||||
abstract Optional<ErrorListener> errorListener();
|
abstract Optional<ErrorListener> errorListener();
|
||||||
|
|
||||||
/** The output type of segmentation results. */
|
|
||||||
public enum OutputType {
|
|
||||||
// Gives a single output mask where each pixel represents the class which
|
|
||||||
// the pixel in the original image was predicted to belong to.
|
|
||||||
CATEGORY_MASK,
|
|
||||||
// Gives a list of output masks where, for each mask, each pixel represents
|
|
||||||
// the prediction confidence, usually in the [0, 1] range.
|
|
||||||
CONFIDENCE_MASK
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Builder builder() {
|
public static Builder builder() {
|
||||||
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
|
return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder()
|
||||||
.setOutputType(OutputType.CATEGORY_MASK);
|
.setOutputConfidenceMasks(true)
|
||||||
|
.setOutputCategoryMask(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi {
|
||||||
|
|
||||||
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder =
|
||||||
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
SegmenterOptionsProto.SegmenterOptions.newBuilder();
|
||||||
if (outputType() == OutputType.CONFIDENCE_MASK) {
|
|
||||||
segmenterOptionsBuilder.setOutputType(
|
|
||||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK);
|
|
||||||
} else if (outputType() == OutputType.CATEGORY_MASK) {
|
|
||||||
segmenterOptionsBuilder.setOutputType(
|
|
||||||
SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK);
|
|
||||||
}
|
|
||||||
|
|
||||||
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder);
|
||||||
return CalculatorOptions.newBuilder()
|
return CalculatorOptions.newBuilder()
|
||||||
.setExtension(
|
.setExtension(
|
||||||
|
|
|
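The options now select outputs directly instead of a single OutputType. A minimal sketch of creating a segmenter with both outputs enabled, assuming an Android `context` and using a placeholder model asset path:

    InteractiveSegmenterOptions options =
        InteractiveSegmenterOptions.builder()
            .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
            .setOutputConfidenceMasks(true)   // defaults to true
            .setOutputCategoryMask(true)      // defaults to false
            .build();
    // createFromOptions throws IllegalArgumentException if both outputs are disabled.
    InteractiveSegmenter segmenter = InteractiveSegmenter.createFromOptions(context, options);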
@ -234,8 +234,8 @@ public class FaceStylizerTest {
|
||||||
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
|
FaceStylizerResult actualResult = faceStylizer.stylize(inputImage);
|
||||||
MPImage stylizedImage = actualResult.stylizedImage().get();
|
MPImage stylizedImage = actualResult.stylizedImage().get();
|
||||||
assertThat(stylizedImage).isNotNull();
|
assertThat(stylizedImage).isNotNull();
|
||||||
assertThat(stylizedImage.getWidth()).isEqualTo(83);
|
assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize);
|
||||||
assertThat(stylizedImage.getHeight()).isEqualTo(83);
|
assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -53,18 +53,15 @@ public class InteractiveSegmenterTest {
|
||||||
InteractiveSegmenterOptions options =
|
InteractiveSegmenterOptions options =
|
||||||
InteractiveSegmenterOptions.builder()
|
InteractiveSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK)
|
.setOutputConfidenceMasks(false)
|
||||||
|
.setOutputCategoryMask(true)
|
||||||
.build();
|
.build();
|
||||||
InteractiveSegmenter imageSegmenter =
|
InteractiveSegmenter imageSegmenter =
|
||||||
InteractiveSegmenter.createFromOptions(
|
InteractiveSegmenter.createFromOptions(
|
||||||
ApplicationProvider.getApplicationContext(), options);
|
ApplicationProvider.getApplicationContext(), options);
|
||||||
MPImage image = getImageFromAsset(inputImageName);
|
MPImage image = getImageFromAsset(inputImageName);
|
||||||
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi);
|
||||||
// TODO update to correct category mask output.
|
assertThat(actualResult.categoryMask().isPresent()).isTrue();
|
||||||
// After InteractiveSegmenter updated according to (b/276519300), update this to use
|
|
||||||
// categoryMask field instead of confidenceMasks.
|
|
||||||
List<MPImage> segmentations = actualResult.confidenceMasks().get();
|
|
||||||
assertThat(segmentations.size()).isEqualTo(1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -75,15 +72,17 @@ public class InteractiveSegmenterTest {
|
||||||
InteractiveSegmenterOptions options =
|
InteractiveSegmenterOptions options =
|
||||||
InteractiveSegmenterOptions.builder()
|
InteractiveSegmenterOptions.builder()
|
||||||
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
.setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build())
|
||||||
.setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK)
|
.setOutputConfidenceMasks(true)
|
||||||
|
.setOutputCategoryMask(false)
|
||||||
.build();
|
.build();
|
||||||
InteractiveSegmenter imageSegmenter =
|
InteractiveSegmenter imageSegmenter =
|
||||||
InteractiveSegmenter.createFromOptions(
|
InteractiveSegmenter.createFromOptions(
|
||||||
ApplicationProvider.getApplicationContext(), options);
|
ApplicationProvider.getApplicationContext(), options);
|
||||||
ImageSegmenterResult actualResult =
|
ImageSegmenterResult actualResult =
|
||||||
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
imageSegmenter.segment(getImageFromAsset(inputImageName), roi);
|
||||||
List<MPImage> segmentations = actualResult.confidenceMasks().get();
|
assertThat(actualResult.confidenceMasks().isPresent()).isTrue();
|
||||||
assertThat(segmentations.size()).isEqualTo(2);
|
List<MPImage> confidenceMasks = actualResult.confidenceMasks().get();
|
||||||
|
assertThat(confidenceMasks.size()).isEqualTo(2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -204,6 +204,11 @@ This can be useful for resetting a stateful task graph to process new data.
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: The underlying medipaipe graph fails to reset and restart.
|
RuntimeError: The underlying medipaipe graph fails to reset and restart.
|
||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
|
task_runner.def(
|
||||||
|
"get_graph_config",
|
||||||
|
[](TaskRunner* self) { return self->GetGraphConfig(); },
|
||||||
|
R"doc(Returns the canonicalized CalculatorGraphConfig of the underlying graph.)doc");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace python
|
} // namespace python
|
||||||
|
|
|
@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions
|
||||||
|
|
||||||
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
|
_BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite'
|
||||||
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
|
_REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite'
|
||||||
|
_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite'
|
||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/text'
|
||||||
# Tolerance for embedding vector coordinate values.
|
# Tolerance for embedding vector coordinate values.
|
||||||
_EPSILON = 1e-4
|
_EPSILON = 1e-4
|
||||||
|
@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
16,
|
16,
|
||||||
(0.549632, 0.552879),
|
(0.549632, 0.552879),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_NAME,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(1.422951, 1.404664),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_CONTENT,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(0.127049, 0.125416),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
|
def test_embed(self, l2_normalize, quantize, model_name, model_file_type,
|
||||||
expected_similarity, expected_size, expected_first_values):
|
expected_similarity, expected_size, expected_first_values):
|
||||||
|
@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
16,
|
16,
|
||||||
(0.549632, 0.552879),
|
(0.549632, 0.552879),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_NAME,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(1.422951, 1.404664),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
_USE_MODEL_FILE,
|
||||||
|
ModelFileType.FILE_CONTENT,
|
||||||
|
0.851961,
|
||||||
|
100,
|
||||||
|
(0.127049, 0.125416),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
def test_embed_in_context(self, l2_normalize, quantize, model_name,
|
def test_embed_in_context(self, l2_normalize, quantize, model_name,
|
||||||
model_file_type, expected_similarity, expected_size,
|
model_file_type, expected_similarity, expected_size,
|
||||||
|
@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase):
|
||||||
@parameterized.parameters(
|
@parameterized.parameters(
|
||||||
# TODO: The similarity should likely be lower
|
# TODO: The similarity should likely be lower
|
||||||
(_BERT_MODEL_FILE, 0.980880),
|
(_BERT_MODEL_FILE, 0.980880),
|
||||||
|
(_USE_MODEL_FILE, 0.780334),
|
||||||
)
|
)
|
||||||
def test_embed_with_different_themes(self, model_file, expected_similarity):
|
def test_embed_with_different_themes(self, model_file, expected_similarity):
|
||||||
# Creates embedder.
|
# Creates embedder.
|
||||||
|
|
|
@ -15,7 +15,6 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
from typing import List
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils
|
||||||
from mediapipe.tasks.python.vision import image_segmenter
|
from mediapipe.tasks.python.vision import image_segmenter
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
|
ImageSegmenterResult = image_segmenter.ImageSegmenterResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_ImageFormat = image_frame.ImageFormat
|
_ImageFormat = image_frame.ImageFormat
|
||||||
_OutputType = image_segmenter.ImageSegmenterOptions.OutputType
|
|
||||||
_Activation = image_segmenter.ImageSegmenterOptions.Activation
|
|
||||||
_ImageSegmenter = image_segmenter.ImageSegmenter
|
_ImageSegmenter = image_segmenter.ImageSegmenter
|
||||||
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
|
_ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions
|
||||||
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
_RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||||
|
@ -42,11 +40,33 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode
|
||||||
_MODEL_FILE = 'deeplabv3.tflite'
|
_MODEL_FILE = 'deeplabv3.tflite'
|
||||||
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
|
_IMAGE_FILE = 'segmentation_input_rotation0.jpg'
|
||||||
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
_SEGMENTATION_FILE = 'segmentation_golden_rotation0.png'
|
||||||
|
_CAT_IMAGE = 'cat.jpg'
|
||||||
|
_CAT_MASK = 'cat_mask.jpg'
|
||||||
_MASK_MAGNIFICATION_FACTOR = 10
|
_MASK_MAGNIFICATION_FACTOR = 10
|
||||||
_MASK_SIMILARITY_THRESHOLD = 0.98
|
_MASK_SIMILARITY_THRESHOLD = 0.98
|
||||||
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision'
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_soft_iou(m1, m2):
|
||||||
|
intersection_sum = np.sum(m1 * m2)
|
||||||
|
union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum
|
||||||
|
|
||||||
|
if union_sum > 0:
|
||||||
|
return intersection_sum / union_sum
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold):
|
||||||
|
actual_mask = actual_mask.numpy_view()
|
||||||
|
expected_mask = expected_mask.numpy_view() / 255.0
|
||||||
|
|
||||||
|
return (
|
||||||
|
actual_mask.shape == expected_mask.shape
|
||||||
|
and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
def _similar_to_uint8_mask(actual_mask, expected_mask):
|
||||||
actual_mask_pixels = actual_mask.numpy_view().flatten()
|
actual_mask_pixels = actual_mask.numpy_view().flatten()
|
||||||
expected_mask_pixels = expected_mask.numpy_view().flatten()
|
expected_mask_pixels = expected_mask.numpy_view().flatten()
|
||||||
|
@ -56,8 +76,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask):
|
||||||
|
|
||||||
for index in range(num_pixels):
|
for index in range(num_pixels):
|
||||||
consistent_pixels += (
|
consistent_pixels += (
|
||||||
actual_mask_pixels[index] *
|
actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR
|
||||||
_MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index])
|
== expected_mask_pixels[index]
|
||||||
|
)
|
||||||
|
|
||||||
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
|
return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD
|
||||||
|
|
||||||
|
@ -73,16 +94,27 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
# Load the test input image.
|
# Load the test input image.
|
||||||
self.test_image = _Image.create_from_file(
|
self.test_image = _Image.create_from_file(
|
||||||
test_utils.get_test_data_path(
|
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))
|
||||||
os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)))
|
)
|
||||||
# Loads ground truth segmentation file.
|
# Loads ground truth segmentation file.
|
||||||
gt_segmentation_data = cv2.imread(
|
gt_segmentation_data = cv2.imread(
|
||||||
test_utils.get_test_data_path(
|
test_utils.get_test_data_path(
|
||||||
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)),
|
os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)
|
||||||
cv2.IMREAD_GRAYSCALE)
|
),
|
||||||
|
cv2.IMREAD_GRAYSCALE,
|
||||||
|
)
|
||||||
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||||
self.model_path = test_utils.get_test_data_path(
|
self.model_path = test_utils.get_test_data_path(
|
||||||
os.path.join(_TEST_DATA_DIR, _MODEL_FILE))
|
os.path.join(_TEST_DATA_DIR, _MODEL_FILE)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_segmentation_mask(self, file_path: str):
|
||||||
|
# Loads ground truth segmentation file.
|
||||||
|
gt_segmentation_data = cv2.imread(
|
||||||
|
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)),
|
||||||
|
cv2.IMREAD_GRAYSCALE,
|
||||||
|
)
|
||||||
|
return _Image(_ImageFormat.GRAY8, gt_segmentation_data)
|
||||||
|
|
||||||
def test_create_from_file_succeeds_with_valid_model_path(self):
|
def test_create_from_file_succeeds_with_valid_model_path(self):
|
||||||
# Creates with default option and valid model file successfully.
|
# Creates with default option and valid model file successfully.
|
||||||
|
@ -98,9 +130,11 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
|
|
||||||
def test_create_from_options_fails_with_invalid_model_path(self):
|
def test_create_from_options_fails_with_invalid_model_path(self):
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'):
|
RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'
|
||||||
|
):
|
||||||
base_options = _BaseOptions(
|
base_options = _BaseOptions(
|
||||||
model_asset_path='/path/to/invalid/model.tflite')
|
model_asset_path='/path/to/invalid/model.tflite'
|
||||||
|
)
|
||||||
options = _ImageSegmenterOptions(base_options=base_options)
|
options = _ImageSegmenterOptions(base_options=base_options)
|
||||||
_ImageSegmenter.create_from_options(options)
|
_ImageSegmenter.create_from_options(options)
|
||||||
|
|
||||||
|
@ -112,8 +146,9 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
segmenter = _ImageSegmenter.create_from_options(options)
|
||||||
self.assertIsInstance(segmenter, _ImageSegmenter)
|
self.assertIsInstance(segmenter, _ImageSegmenter)
|
||||||
|
|
||||||
@parameterized.parameters((ModelFileType.FILE_NAME,),
|
@parameterized.parameters(
|
||||||
(ModelFileType.FILE_CONTENT,))
|
(ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,)
|
||||||
|
)
|
||||||
def test_segment_succeeds_with_category_mask(self, model_file_type):
|
def test_segment_succeeds_with_category_mask(self, model_file_type):
|
||||||
# Creates segmenter.
|
# Creates segmenter.
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
if model_file_type is ModelFileType.FILE_NAME:
|
||||||
|
@ -127,22 +162,27 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
base_options=base_options,
|
||||||
|
output_category_mask=True,
|
||||||
|
output_confidence_masks=False,
|
||||||
|
)
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
segmenter = _ImageSegmenter.create_from_options(options)
|
||||||
|
|
||||||
# Performs image segmentation on the input.
|
# Performs image segmentation on the input.
|
||||||
category_masks = segmenter.segment(self.test_image)
|
segmentation_result = segmenter.segment(self.test_image)
|
||||||
self.assertLen(category_masks, 1)
|
category_mask = segmentation_result.category_mask
|
||||||
category_mask = category_masks[0]
|
|
||||||
result_pixels = category_mask.numpy_view().flatten()
|
result_pixels = category_mask.numpy_view().flatten()
|
||||||
|
|
||||||
# Check if data type of `category_mask` is correct.
|
# Check if data type of `category_mask` is correct.
|
||||||
self.assertEqual(result_pixels.dtype, np.uint8)
|
self.assertEqual(result_pixels.dtype, np.uint8)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
|
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||||
f'Number of pixels in the candidate mask differing from that of the '
|
(
|
||||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
'Number of pixels in the candidate mask differing from that of the'
|
||||||
|
f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Closes the segmenter explicitly when the segmenter is not used in
|
# Closes the segmenter explicitly when the segmenter is not used in
|
||||||
# a context.
|
# a context.
|
||||||
|
@ -152,74 +192,46 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
# Creates segmenter.
|
# Creates segmenter.
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
base_options = _BaseOptions(model_asset_path=self.model_path)
|
||||||
|
|
||||||
# Run segmentation on the model in CATEGORY_MASK mode.
|
# Load the cat image.
|
||||||
options = _ImageSegmenterOptions(
|
test_image = _Image.create_from_file(
|
||||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
)
|
||||||
category_masks = segmenter.segment(self.test_image)
|
|
||||||
category_mask = category_masks[0].numpy_view()
|
|
||||||
|
|
||||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=base_options,
|
base_options=base_options,
|
||||||
output_type=_OutputType.CONFIDENCE_MASK,
|
output_category_mask=False,
|
||||||
activation=_Activation.SOFTMAX)
|
output_confidence_masks=True,
|
||||||
segmenter = _ImageSegmenter.create_from_options(options)
|
)
|
||||||
confidence_masks = segmenter.segment(self.test_image)
|
|
||||||
|
|
||||||
# Check if confidence mask shape is correct.
|
|
||||||
self.assertLen(
|
|
||||||
confidence_masks, 21,
|
|
||||||
'Number of confidence masks must match with number of categories.')
|
|
||||||
|
|
||||||
# Gather the confidence masks in a single array `confidence_mask_array`.
|
|
||||||
confidence_mask_array = np.array(
|
|
||||||
[confidence_mask.numpy_view() for confidence_mask in confidence_masks])
|
|
||||||
|
|
||||||
# Check if data type of `confidence_masks` are correct.
|
|
||||||
self.assertEqual(confidence_mask_array.dtype, np.float32)
|
|
||||||
|
|
||||||
# Compute the category mask from the created confidence mask.
|
|
||||||
calculated_category_mask = np.argmax(confidence_mask_array, axis=0)
|
|
||||||
self.assertListEqual(
|
|
||||||
calculated_category_mask.tolist(), category_mask.tolist(),
|
|
||||||
'Confidence mask does not match with the category mask.')
|
|
||||||
|
|
||||||
# Closes the segmenter explicitly when the segmenter is not used in
|
|
||||||
# a context.
|
|
||||||
segmenter.close()
|
|
||||||
|
|
||||||
@parameterized.parameters((ModelFileType.FILE_NAME),
|
|
||||||
(ModelFileType.FILE_CONTENT))
|
|
||||||
def test_segment_in_context(self, model_file_type):
|
|
||||||
if model_file_type is ModelFileType.FILE_NAME:
|
|
||||||
base_options = _BaseOptions(model_asset_path=self.model_path)
|
|
||||||
elif model_file_type is ModelFileType.FILE_CONTENT:
|
|
||||||
with open(self.model_path, 'rb') as f:
|
|
||||||
model_contents = f.read()
|
|
||||||
base_options = _BaseOptions(model_asset_buffer=model_contents)
|
|
||||||
else:
|
|
||||||
# Should never happen
|
|
||||||
raise ValueError('model_file_type is invalid.')
|
|
||||||
|
|
||||||
options = _ImageSegmenterOptions(
|
|
||||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK)
|
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
# Performs image segmentation on the input.
|
segmentation_result = segmenter.segment(test_image)
|
||||||
category_masks = segmenter.segment(self.test_image)
|
confidence_masks = segmentation_result.confidence_masks
|
||||||
self.assertLen(category_masks, 1)
|
|
||||||
|
# Check if confidence mask shape is correct.
|
||||||
|
self.assertLen(
|
||||||
|
confidence_masks,
|
||||||
|
21,
|
||||||
|
'Number of confidence masks must match with number of categories.',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads ground truth segmentation file.
|
||||||
|
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
|
_similar_to_float_mask(
|
||||||
f'Number of pixels in the candidate mask differing from that of the '
|
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_missing_result_callback(self):
|
def test_missing_result_callback(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM)
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
with self.assertRaisesRegex(ValueError,
|
)
|
||||||
r'result callback must be provided'):
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'result callback must be provided'
|
||||||
|
):
|
||||||
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -228,130 +240,236 @@ class ImageSegmenterTest(parameterized.TestCase):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=running_mode,
|
running_mode=running_mode,
|
||||||
result_callback=mock.MagicMock())
|
result_callback=mock.MagicMock(),
|
||||||
with self.assertRaisesRegex(ValueError,
|
)
|
||||||
r'result callback should not be provided'):
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, r'result callback should not be provided'
|
||||||
|
):
|
||||||
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
with _ImageSegmenter.create_from_options(options) as unused_segmenter:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def test_calling_segment_for_video_in_image_mode(self):
|
def test_calling_segment_for_video_in_image_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.IMAGE)
|
running_mode=_RUNNING_MODE.IMAGE,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(
|
||||||
r'not initialized with the video mode'):
|
ValueError, r'not initialized with the video mode'
|
||||||
|
):
|
||||||
segmenter.segment_for_video(self.test_image, 0)
|
segmenter.segment_for_video(self.test_image, 0)
|
||||||
|
|
||||||
def test_calling_segment_async_in_image_mode(self):
|
def test_calling_segment_async_in_image_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.IMAGE)
|
running_mode=_RUNNING_MODE.IMAGE,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(
|
||||||
r'not initialized with the live stream mode'):
|
ValueError, r'not initialized with the live stream mode'
|
||||||
|
):
|
||||||
segmenter.segment_async(self.test_image, 0)
|
segmenter.segment_async(self.test_image, 0)
|
||||||
|
|
||||||
def test_calling_segment_in_video_mode(self):
|
def test_calling_segment_in_video_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.VIDEO)
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(
|
||||||
r'not initialized with the image mode'):
|
ValueError, r'not initialized with the image mode'
|
||||||
|
):
|
||||||
segmenter.segment(self.test_image)
|
segmenter.segment(self.test_image)
|
||||||
|
|
||||||
def test_calling_segment_async_in_video_mode(self):
|
def test_calling_segment_async_in_video_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.VIDEO)
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(
|
||||||
r'not initialized with the live stream mode'):
|
ValueError, r'not initialized with the live stream mode'
|
||||||
|
):
|
||||||
segmenter.segment_async(self.test_image, 0)
|
segmenter.segment_async(self.test_image, 0)
|
||||||
|
|
||||||
def test_segment_for_video_with_out_of_order_timestamp(self):
|
def test_segment_for_video_with_out_of_order_timestamp(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.VIDEO)
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
unused_result = segmenter.segment_for_video(self.test_image, 1)
|
unused_result = segmenter.segment_for_video(self.test_image, 1)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
ValueError, r'Input timestamp must be monotonically increasing'
|
||||||
|
):
|
||||||
segmenter.segment_for_video(self.test_image, 0)
|
segmenter.segment_for_video(self.test_image, 0)
|
||||||
|
|
||||||
def test_segment_for_video(self):
|
def test_segment_for_video_in_category_mask_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
output_type=_OutputType.CATEGORY_MASK,
|
output_category_mask=True,
|
||||||
running_mode=_RUNNING_MODE.VIDEO)
|
output_confidence_masks=False,
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
category_masks = segmenter.segment_for_video(self.test_image, timestamp)
|
segmentation_result = segmenter.segment_for_video(
|
||||||
self.assertLen(category_masks, 1)
|
self.test_image, timestamp
|
||||||
|
)
|
||||||
|
category_mask = segmentation_result.category_mask
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
_similar_to_uint8_mask(category_masks[0], self.test_seg_image),
|
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||||
f'Number of pixels in the candidate mask differing from that of the '
|
(
|
||||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
'Number of pixels in the candidate mask differing from that of'
|
||||||
|
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_segment_for_video_in_confidence_mask_mode(self):
|
||||||
|
# Load the cat image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||||
|
)
|
||||||
|
|
||||||
|
options = _ImageSegmenterOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.VIDEO,
|
||||||
|
output_category_mask=False,
|
||||||
|
output_confidence_masks=True,
|
||||||
|
)
|
||||||
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
segmentation_result = segmenter.segment_for_video(test_image, timestamp)
|
||||||
|
confidence_masks = segmentation_result.confidence_masks
|
||||||
|
|
||||||
|
# Check if confidence mask shape is correct.
|
||||||
|
self.assertLen(
|
||||||
|
confidence_masks,
|
||||||
|
21,
|
||||||
|
'Number of confidence masks must match with number of categories.',
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads ground truth segmentation file.
|
||||||
|
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||||
|
self.assertTrue(
|
||||||
|
_similar_to_float_mask(
|
||||||
|
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def test_calling_segment_in_live_stream_mode(self):
|
def test_calling_segment_in_live_stream_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
result_callback=mock.MagicMock())
|
result_callback=mock.MagicMock(),
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(
|
||||||
r'not initialized with the image mode'):
|
ValueError, r'not initialized with the image mode'
|
||||||
|
):
|
||||||
segmenter.segment(self.test_image)
|
segmenter.segment(self.test_image)
|
||||||
|
|
||||||
def test_calling_segment_for_video_in_live_stream_mode(self):
|
def test_calling_segment_for_video_in_live_stream_mode(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
result_callback=mock.MagicMock())
|
result_callback=mock.MagicMock(),
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(
|
||||||
r'not initialized with the video mode'):
|
ValueError, r'not initialized with the video mode'
|
||||||
|
):
|
||||||
segmenter.segment_for_video(self.test_image, 0)
|
segmenter.segment_for_video(self.test_image, 0)
|
||||||
|
|
||||||
def test_segment_async_calls_with_illegal_timestamp(self):
|
def test_segment_async_calls_with_illegal_timestamp(self):
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
result_callback=mock.MagicMock())
|
result_callback=mock.MagicMock(),
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
segmenter.segment_async(self.test_image, 100)
|
segmenter.segment_async(self.test_image, 100)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, r'Input timestamp must be monotonically increasing'):
|
ValueError, r'Input timestamp must be monotonically increasing'
|
||||||
|
):
|
||||||
segmenter.segment_async(self.test_image, 0)
|
segmenter.segment_async(self.test_image, 0)
|
||||||
|
|
||||||
def test_segment_async_calls(self):
|
def test_segment_async_calls_in_category_mask_mode(self):
|
||||||
observed_timestamp_ms = -1
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
def check_result(result: List[image_module.Image], output_image: _Image,
|
def check_result(
|
||||||
timestamp_ms: int):
|
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
|
||||||
|
):
|
||||||
# Get the output category mask.
|
# Get the output category mask.
|
||||||
category_mask = result[0]
|
category_mask = result.category_mask
|
||||||
self.assertEqual(output_image.width, self.test_image.width)
|
self.assertEqual(output_image.width, self.test_image.width)
|
||||||
self.assertEqual(output_image.height, self.test_image.height)
|
self.assertEqual(output_image.height, self.test_image.height)
|
||||||
self.assertEqual(output_image.width, self.test_seg_image.width)
|
self.assertEqual(output_image.width, self.test_seg_image.width)
|
||||||
self.assertEqual(output_image.height, self.test_seg_image.height)
|
self.assertEqual(output_image.height, self.test_seg_image.height)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
_similar_to_uint8_mask(category_mask, self.test_seg_image),
|
||||||
f'Number of pixels in the candidate mask differing from that of the '
|
(
|
||||||
f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.')
|
'Number of pixels in the candidate mask differing from that of'
|
||||||
|
f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.'
|
||||||
|
),
|
||||||
|
)
|
||||||
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
self.observed_timestamp_ms = timestamp_ms
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
options = _ImageSegmenterOptions(
|
options = _ImageSegmenterOptions(
|
||||||
base_options=_BaseOptions(model_asset_path=self.model_path),
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
output_type=_OutputType.CATEGORY_MASK,
|
output_category_mask=True,
|
||||||
|
output_confidence_masks=False,
|
||||||
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
result_callback=check_result)
|
result_callback=check_result,
|
||||||
|
)
|
||||||
with _ImageSegmenter.create_from_options(options) as segmenter:
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
for timestamp in range(0, 300, 30):
|
for timestamp in range(0, 300, 30):
|
||||||
segmenter.segment_async(self.test_image, timestamp)
|
segmenter.segment_async(self.test_image, timestamp)
|
||||||
|
|
||||||
|
def test_segment_async_calls_in_confidence_mask_mode(self):
|
||||||
|
# Load the cat image.
|
||||||
|
test_image = _Image.create_from_file(
|
||||||
|
test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loads ground truth segmentation file.
|
||||||
|
expected_mask = self._load_segmentation_mask(_CAT_MASK)
|
||||||
|
observed_timestamp_ms = -1
|
||||||
|
|
||||||
|
def check_result(
|
||||||
|
result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int
|
||||||
|
):
|
||||||
|
# Get the output category mask.
|
||||||
|
confidence_masks = result.confidence_masks
|
||||||
|
|
||||||
|
# Check if confidence mask shape is correct.
|
||||||
|
self.assertLen(
|
||||||
|
confidence_masks,
|
||||||
|
21,
|
||||||
|
'Number of confidence masks must match with number of categories.',
|
||||||
|
)
|
||||||
|
self.assertEqual(output_image.width, test_image.width)
|
||||||
|
self.assertEqual(output_image.height, test_image.height)
|
||||||
|
self.assertTrue(
|
||||||
|
_similar_to_float_mask(
|
||||||
|
confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertLess(observed_timestamp_ms, timestamp_ms)
|
||||||
|
self.observed_timestamp_ms = timestamp_ms
|
||||||
|
|
||||||
|
options = _ImageSegmenterOptions(
|
||||||
|
base_options=_BaseOptions(model_asset_path=self.model_path),
|
||||||
|
running_mode=_RUNNING_MODE.LIVE_STREAM,
|
||||||
|
output_category_mask=False,
|
||||||
|
output_confidence_masks=True,
|
||||||
|
result_callback=check_result,
|
||||||
|
)
|
||||||
|
with _ImageSegmenter.create_from_options(options) as segmenter:
|
||||||
|
for timestamp in range(0, 300, 30):
|
||||||
|
segmenter.segment_async(test_image, timestamp)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils
|
||||||
from mediapipe.tasks.python.vision import interactive_segmenter
|
from mediapipe.tasks.python.vision import interactive_segmenter
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
|
|
||||||
|
InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_Image = image_module.Image
|
_Image = image_module.Image
|
||||||
_ImageFormat = image_frame.ImageFormat
|
_ImageFormat = image_frame.ImageFormat
|
||||||
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint
|
_NormalizedKeypoint = keypoint_module.NormalizedKeypoint
|
||||||
_Rect = rect.Rect
|
_Rect = rect.Rect
|
||||||
_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType
|
|
||||||
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
|
_InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter
|
||||||
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
|
_InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions
|
||||||
_RegionOfInterest = interactive_segmenter.RegionOfInterest
|
_RegionOfInterest = interactive_segmenter.RegionOfInterest
|
||||||
|
@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
||||||
raise ValueError('model_file_type is invalid.')
|
raise ValueError('model_file_type is invalid.')
|
||||||
|
|
||||||
options = _InteractiveSegmenterOptions(
|
options = _InteractiveSegmenterOptions(
|
||||||
base_options=base_options, output_type=_OutputType.CATEGORY_MASK
|
base_options=base_options,
|
||||||
|
output_category_mask=True,
|
||||||
|
output_confidence_masks=False,
|
||||||
)
|
)
|
||||||
segmenter = _InteractiveSegmenter.create_from_options(options)
|
segmenter = _InteractiveSegmenter.create_from_options(options)
|
||||||
|
|
||||||
# Performs image segmentation on the input.
|
# Performs image segmentation on the input.
|
||||||
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
|
roi = _RegionOfInterest(format=roi_format, keypoint=keypoint)
|
||||||
category_masks = segmenter.segment(self.test_image, roi)
|
segmentation_result = segmenter.segment(self.test_image, roi)
|
||||||
self.assertLen(category_masks, 1)
|
category_mask = segmentation_result.category_mask
|
||||||
category_mask = category_masks[0]
|
|
||||||
result_pixels = category_mask.numpy_view().flatten()
|
result_pixels = category_mask.numpy_view().flatten()
|
||||||
|
|
||||||
# Check if data type of `category_mask` is correct.
|
# Check if data type of `category_mask` is correct.
|
||||||
|
@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
_similar_to_uint8_mask(
|
_similar_to_uint8_mask(
|
||||||
category_masks[0], test_seg_image, similarity_threshold
|
category_mask, test_seg_image, similarity_threshold
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
'Number of pixels in the candidate mask differing from that of the'
|
'Number of pixels in the candidate mask differing from that of the'
|
||||||
|
@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||||
options = _InteractiveSegmenterOptions(
|
options = _InteractiveSegmenterOptions(
|
||||||
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
|
base_options=base_options,
|
||||||
|
output_category_mask=False,
|
||||||
|
output_confidence_masks=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with _InteractiveSegmenter.create_from_options(options) as segmenter:
|
with _InteractiveSegmenter.create_from_options(options) as segmenter:
|
||||||
# Perform segmentation
|
# Perform segmentation
|
||||||
confidence_masks = segmenter.segment(self.test_image, roi)
|
segmentation_result = segmenter.segment(self.test_image, roi)
|
||||||
|
confidence_masks = segmentation_result.confidence_masks
|
||||||
|
|
||||||
# Check if confidence mask shape is correct.
|
# Check if confidence mask shape is correct.
|
||||||
self.assertLen(
|
self.assertLen(
|
||||||
|
@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||||
options = _InteractiveSegmenterOptions(
|
options = _InteractiveSegmenterOptions(
|
||||||
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
|
base_options=base_options,
|
||||||
|
output_category_mask=False,
|
||||||
|
output_confidence_masks=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with _InteractiveSegmenter.create_from_options(options) as segmenter:
|
with _InteractiveSegmenter.create_from_options(options) as segmenter:
|
||||||
# Perform segmentation
|
# Perform segmentation
|
||||||
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
|
image_processing_options = _ImageProcessingOptions(rotation_degrees=-90)
|
||||||
confidence_masks = segmenter.segment(
|
segmentation_result = segmenter.segment(
|
||||||
self.test_image, roi, image_processing_options
|
self.test_image, roi, image_processing_options
|
||||||
)
|
)
|
||||||
|
confidence_masks = segmentation_result.confidence_masks
|
||||||
|
|
||||||
# Check if confidence mask shape is correct.
|
# Check if confidence mask shape is correct.
|
||||||
self.assertLen(
|
self.assertLen(
|
||||||
|
@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase):
|
||||||
|
|
||||||
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
# Run segmentation on the model in CONFIDENCE_MASK mode.
|
||||||
options = _InteractiveSegmenterOptions(
|
options = _InteractiveSegmenterOptions(
|
||||||
base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK
|
base_options=base_options,
|
||||||
|
output_category_mask=False,
|
||||||
|
output_confidence_masks=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
|
|
|
@ -32,6 +32,7 @@ FaceDetectorResult = face_detector.FaceDetectorResult
|
||||||
FaceLandmarker = face_landmarker.FaceLandmarker
|
FaceLandmarker = face_landmarker.FaceLandmarker
|
||||||
FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions
|
FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions
|
||||||
FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult
|
FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult
|
||||||
|
FaceLandmarksConnections = face_landmarker.FaceLandmarksConnections
|
||||||
FaceStylizer = face_stylizer.FaceStylizer
|
FaceStylizer = face_stylizer.FaceStylizer
|
||||||
FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
FaceStylizerOptions = face_stylizer.FaceStylizerOptions
|
||||||
GestureRecognizer = gesture_recognizer.GestureRecognizer
|
GestureRecognizer = gesture_recognizer.GestureRecognizer
|
||||||
|
|
|
@ -208,6 +208,11 @@ class BaseVisionTaskApi(object):
|
||||||
"""
|
"""
|
||||||
self._runner.close()
|
self._runner.close()
|
||||||
|
|
||||||
|
def get_graph_config(self) -> calculator_pb2.CalculatorGraphConfig:
|
||||||
|
"""Returns the canonicalized CalculatorGraphConfig of the underlying graph.
|
||||||
|
"""
|
||||||
|
return self._runner.get_graph_config()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Return `self` upon entering the runtime context."""
|
"""Return `self` upon entering the runtime context."""
|
||||||
return self
|
return self
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -176,16 +176,13 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
Only use this method when the FaceStylizer is created with the image
|
Only use this method when the FaceStylizer is created with the image
|
||||||
running mode.
|
running mode.
|
||||||
|
|
||||||
To ensure that the output image has reasonable quality, the stylized output
|
|
||||||
image size is the smaller of the model output size and the size of the
|
|
||||||
`region_of_interest` specified in `image_processing_options`.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: MediaPipe Image.
|
image: MediaPipe Image.
|
||||||
image_processing_options: Options for image processing.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The stylized image of the most visible face. None if no face is detected
|
The stylized image of the most visible face. The stylized output image
|
||||||
|
size is the same as the model output size. None if no face is detected
|
||||||
on the input image.
|
on the input image.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -217,17 +214,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
milliseconds) along with the video frame. The input timestamps should be
|
milliseconds) along with the video frame. The input timestamps should be
|
||||||
monotonically increasing for adjacent calls of this method.
|
monotonically increasing for adjacent calls of this method.
|
||||||
|
|
||||||
To ensure that the output image has reasonable quality, the stylized output
|
|
||||||
image size is the smaller of the model output size and the size of the
|
|
||||||
`region_of_interest` specified in `image_processing_options`.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image: MediaPipe Image.
|
image: MediaPipe Image.
|
||||||
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
timestamp_ms: The timestamp of the input video frame in milliseconds.
|
||||||
image_processing_options: Options for image processing.
|
image_processing_options: Options for image processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The stylized image of the most visible face. None if no face is detected
|
The stylized image of the most visible face. The stylized output image
|
||||||
|
size is the same as the model output size. None if no face is detected
|
||||||
on the input image.
|
on the input image.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -266,12 +260,9 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi):
|
||||||
images if needed. In other words, it's not guaranteed to have output per
|
images if needed. In other words, it's not guaranteed to have output per
|
||||||
input image.
|
input image.
|
||||||
|
|
||||||
To ensure that the stylized image has reasonable quality, the stylized
|
|
||||||
output image size is the smaller of the model output size and the size of
|
|
||||||
the `region_of_interest` specified in `image_processing_options`.
|
|
||||||
|
|
||||||
The `result_callback` provides:
|
The `result_callback` provides:
|
||||||
- The stylized image of the most visible face. None if no face is detected
|
- The stylized image of the most visible face. The stylized output image
|
||||||
|
size is the same as the model output size. None if no face is detected
|
||||||
on the input image.
|
on the input image.
|
||||||
- The input image that the face stylizer runs on.
|
- The input image that the face stylizer runs on.
|
||||||
- The input timestamp in milliseconds.
|
- The input timestamp in milliseconds.
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
"""MediaPipe image segmenter task."""
|
"""MediaPipe image segmenter task."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
|
||||||
from typing import Callable, List, Mapping, Optional
|
from typing import Callable, List, Mapping, Optional
|
||||||
|
|
||||||
from mediapipe.python import packet_creator
|
from mediapipe.python import packet_creator
|
||||||
|
@ -31,7 +30,6 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api
|
||||||
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module
|
||||||
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
from mediapipe.tasks.python.vision.core import vision_task_running_mode
|
||||||
|
|
||||||
ImageSegmenterResult = List[image_module.Image]
|
|
||||||
_NormalizedRect = rect.NormalizedRect
|
_NormalizedRect = rect.NormalizedRect
|
||||||
_BaseOptions = base_options_module.BaseOptions
|
_BaseOptions = base_options_module.BaseOptions
|
||||||
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
|
_SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions
|
||||||
|
@ -42,8 +40,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
|
||||||
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
|
||||||
_TaskInfo = task_info_module.TaskInfo
|
_TaskInfo = task_info_module.TaskInfo
|
||||||
|
|
||||||
_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
|
_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
|
||||||
_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
|
_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
|
||||||
|
_CATEGORY_MASK_STREAM_NAME = 'category_mask'
|
||||||
|
_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
|
||||||
_IMAGE_IN_STREAM_NAME = 'image_in'
|
_IMAGE_IN_STREAM_NAME = 'image_in'
|
||||||
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
_IMAGE_OUT_STREAM_NAME = 'image_out'
|
||||||
_IMAGE_TAG = 'IMAGE'
|
_IMAGE_TAG = 'IMAGE'
|
||||||
|
@ -53,6 +53,21 @@ _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'
|
||||||
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
_MICRO_SECONDS_PER_MILLISECOND = 1000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ImageSegmenterResult:
|
||||||
|
"""Output result of ImageSegmenter.
|
||||||
|
|
||||||
|
confidence_masks: multiple masks of float image where, for each mask, each
|
||||||
|
pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||||
|
|
||||||
|
category_mask: a category mask of uint8 image where each pixel represents the
|
||||||
|
class which the pixel in the original image was predicted to belong to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
confidence_masks: Optional[List[image_module.Image]] = None
|
||||||
|
category_mask: Optional[image_module.Image] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class ImageSegmenterOptions:
|
class ImageSegmenterOptions:
|
||||||
"""Options for the image segmenter task.
|
"""Options for the image segmenter task.
|
||||||
|
@ -64,28 +79,17 @@ class ImageSegmenterOptions:
|
||||||
objects on single image inputs. 2) The video mode for segmenting objects
|
objects on single image inputs. 2) The video mode for segmenting objects
|
||||||
on the decoded frames of a video. 3) The live stream mode for segmenting
|
on the decoded frames of a video. 3) The live stream mode for segmenting
|
||||||
objects on a live stream of input data, such as from camera.
|
objects on a live stream of input data, such as from camera.
|
||||||
output_type: The output mask type allows specifying the type of
|
output_confidence_masks: Whether to output confidence masks.
|
||||||
post-processing to perform on the raw model results.
|
+    output_category_mask: Whether to output category mask.
-    activation: Activation function to apply to input tensor.
     result_callback: The user-defined result callback for processing live stream
       data. The result callback should only be specified when the running mode
       is set to the live stream mode.
   """

-  class OutputType(enum.Enum):
-    UNSPECIFIED = 0
-    CATEGORY_MASK = 1
-    CONFIDENCE_MASK = 2
-
-  class Activation(enum.Enum):
-    NONE = 0
-    SIGMOID = 1
-    SOFTMAX = 2
-
   base_options: _BaseOptions
   running_mode: _RunningMode = _RunningMode.IMAGE
-  output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
-  activation: Optional[Activation] = Activation.NONE
+  output_confidence_masks: bool = True
+  output_category_mask: bool = False
   result_callback: Optional[
       Callable[[ImageSegmenterResult, image_module.Image, int], None]
   ] = None

@@ -97,9 +101,7 @@ class ImageSegmenterOptions:
     base_options_proto.use_stream_mode = (
         False if self.running_mode == _RunningMode.IMAGE else True
     )
-    segmenter_options_proto = _SegmenterOptionsProto(
-        output_type=self.output_type.value, activation=self.activation.value
-    )
+    segmenter_options_proto = _SegmenterOptionsProto()
     return _ImageSegmenterGraphOptionsProto(
         base_options=base_options_proto,
         segmenter_options=segmenter_options_proto,

@@ -177,27 +179,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
     def packets_callback(output_packets: Mapping[str, packet.Packet]):
       if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty():
         return
-      segmentation_result = packet_getter.get_image_list(
-          output_packets[_SEGMENTATION_OUT_STREAM_NAME]
-      )
+      segmentation_result = ImageSegmenterResult()
+
+      if options.output_confidence_masks:
+        segmentation_result.confidence_masks = packet_getter.get_image_list(
+            output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
+        )
+
+      if options.output_category_mask:
+        segmentation_result.category_mask = packet_getter.get_image(
+            output_packets[_CATEGORY_MASK_STREAM_NAME]
+        )
+
       image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME])
-      timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp
+      timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp
       options.result_callback(
           segmentation_result,
           image,
           timestamp.value // _MICRO_SECONDS_PER_MILLISECOND,
       )

+    output_streams = [
+        ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
+    ]
+
+    if options.output_confidence_masks:
+      output_streams.append(
+          ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
+      )
+
+    if options.output_category_mask:
+      output_streams.append(
+          ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
+      )
+
     task_info = _TaskInfo(
         task_graph=_TASK_GRAPH_NAME,
         input_streams=[
             ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]),
             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
         ],
-        output_streams=[
-            ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
-            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
-        ],
+        output_streams=output_streams,
         task_options=options,
     )
     return cls(

@@ -240,9 +263,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
             normalized_rect.to_pb2()
         ),
     })
-    segmentation_result = packet_getter.get_image_list(
-        output_packets[_SEGMENTATION_OUT_STREAM_NAME]
-    )
+    segmentation_result = ImageSegmenterResult()
+
+    if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
+      segmentation_result.confidence_masks = packet_getter.get_image_list(
+          output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
+      )
+
+    if _CATEGORY_MASK_STREAM_NAME in output_packets:
+      segmentation_result.category_mask = packet_getter.get_image(
+          output_packets[_CATEGORY_MASK_STREAM_NAME]
+      )
+
     return segmentation_result

   def segment_for_video(

@@ -285,9 +317,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi):
             normalized_rect.to_pb2()
         ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND),
     })
-    segmentation_result = packet_getter.get_image_list(
-        output_packets[_SEGMENTATION_OUT_STREAM_NAME]
-    )
+    segmentation_result = ImageSegmenterResult()
+
+    if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
+      segmentation_result.confidence_masks = packet_getter.get_image_list(
+          output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
+      )
+
+    if _CATEGORY_MASK_STREAM_NAME in output_packets:
+      segmentation_result.category_mask = packet_getter.get_image(
+          output_packets[_CATEGORY_MASK_STREAM_NAME]
+      )
+
     return segmentation_result

   def segment_async(
@@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode
 _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions
 _TaskInfo = task_info_module.TaskInfo

-_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out'
-_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION'
+_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks'
+_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS'
+_CATEGORY_MASK_STREAM_NAME = 'category_mask'
+_CATEGORY_MASK_TAG = 'CATEGORY_MASK'
 _IMAGE_IN_STREAM_NAME = 'image_in'
 _IMAGE_OUT_STREAM_NAME = 'image_out'
 _ROI_STREAM_NAME = 'roi_in'

@@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = (
 )


+@dataclasses.dataclass
+class InteractiveSegmenterResult:
+  """Output result of InteractiveSegmenter.
+
+  confidence_masks: multiple masks of float image where, for each mask, each
+    pixel represents the prediction confidence, usually in the [0, 1] range.
+
+  category_mask: a category mask of uint8 image where each pixel represents the
+    class which the pixel in the original image was predicted to belong to.
+  """
+
+  confidence_masks: Optional[List[image_module.Image]] = None
+  category_mask: Optional[image_module.Image] = None
+
+
 @dataclasses.dataclass
 class InteractiveSegmenterOptions:
   """Options for the interactive segmenter task.

   Attributes:
     base_options: Base options for the interactive segmenter task.
-    output_type: The output mask type allows specifying the type of
-      post-processing to perform on the raw model results.
+    output_confidence_masks: Whether to output confidence masks.
+    output_category_mask: Whether to output category mask.
   """

-  class OutputType(enum.Enum):
-    UNSPECIFIED = 0
-    CATEGORY_MASK = 1
-    CONFIDENCE_MASK = 2
-
   base_options: _BaseOptions
-  output_type: Optional[OutputType] = OutputType.CATEGORY_MASK
+  output_confidence_masks: bool = True
+  output_category_mask: bool = False

   @doc_controls.do_not_generate_docs
   def to_pb2(self) -> _ImageSegmenterGraphOptionsProto:
     """Generates an InteractiveSegmenterOptions protobuf object."""
     base_options_proto = self.base_options.to_pb2()
     base_options_proto.use_stream_mode = False
-    segmenter_options_proto = _SegmenterOptionsProto(
-        output_type=self.output_type.value
-    )
+    segmenter_options_proto = _SegmenterOptionsProto()
     return _ImageSegmenterGraphOptionsProto(
         base_options=base_options_proto,
         segmenter_options=segmenter_options_proto,

@@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
       RuntimeError: If other types of error occurred.
     """

+    output_streams = [
+        ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
+    ]
+
+    if options.output_confidence_masks:
+      output_streams.append(
+          ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME])
+      )
+
+    if options.output_category_mask:
+      output_streams.append(
+          ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME])
+      )
+
     task_info = _TaskInfo(
         task_graph=_TASK_GRAPH_NAME,
         input_streams=[

@@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
             ':'.join([_ROI_TAG, _ROI_STREAM_NAME]),
             ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]),
         ],
-        output_streams=[
-            ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]),
-            ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]),
-        ],
+        output_streams=output_streams,
         task_options=options,
     )
     return cls(

@@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
       image: image_module.Image,
       roi: RegionOfInterest,
       image_processing_options: Optional[_ImageProcessingOptions] = None,
-  ) -> List[image_module.Image]:
+  ) -> InteractiveSegmenterResult:
     """Performs the actual segmentation task on the provided MediaPipe Image.

     The image can be of any size with format RGB.

@@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi):
             normalized_rect.to_pb2()
         ),
     })
-    segmentation_result = packet_getter.get_image_list(
-        output_packets[_SEGMENTATION_OUT_STREAM_NAME]
-    )
+    segmentation_result = InteractiveSegmenterResult()
+
+    if _CONFIDENCE_MASKS_STREAM_NAME in output_packets:
+      segmentation_result.confidence_masks = packet_getter.get_image_list(
+          output_packets[_CONFIDENCE_MASKS_STREAM_NAME]
+      )
+
+    if _CATEGORY_MASK_STREAM_NAME in output_packets:
+      segmentation_result.category_mask = packet_getter.get_image(
+          output_packets[_CATEGORY_MASK_STREAM_NAME]
+      )
+
     return segmentation_result
@@ -59,13 +59,12 @@ export function drawCategoryMask(
   const isFloatArray = image instanceof Float32Array;
   for (let i = 0; i < image.length; i++) {
     const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
-    const color = COLOR_MAP[colorIndex];
+    let color = COLOR_MAP[colorIndex % COLOR_MAP.length];

-    // When we're given a confidence mask by accident, we just log and return.
-    // TODO: We should fix this.
     if (!color) {
+      // TODO: We should fix this.
       console.warn('No color for ', colorIndex);
-      return;
+      color = COLOR_MAP[colorIndex % COLOR_MAP.length];
     }

     rgbaArray[4 * i] = color[0];
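The change above replaces the direct `COLOR_MAP[colorIndex]` lookup with a modulo-wrapped index, so category values beyond the palette reuse colors instead of hitting an undefined entry. A minimal standalone sketch of that wrap-around lookup; the three-entry palette here is illustrative and not MediaPipe's actual COLOR_MAP:

// Illustrative palette; the real COLOR_MAP is larger.
const COLOR_MAP: Array<[number, number, number, number]> = [
  [255, 0, 0, 255],  // red
  [0, 255, 0, 255],  // green
  [0, 0, 255, 255],  // blue
];

// Wraps out-of-range category indices back into the palette.
function colorForCategory(categoryIndex: number): [number, number, number, number] {
  return COLOR_MAP[categoryIndex % COLOR_MAP.length];
}

console.log(colorForCategory(5));  // 5 % 3 === 2 -> blue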
mediapipe/tasks/web/vision/core/types.d.ts
@@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke
  */
 export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;

-/**
- * A callback that receives the computed masks from the segmentation tasks. The
- * callback either receives a single element array with a category mask (as a
- * `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`).
- * The returned data is only valid for the duration of the callback. If
- * asynchronous processing is needed, all data needs to be copied before the
- * callback returns.
- */
-export type SegmentationMaskCallback =
-    (masks: SegmentationMask[], width: number, height: number) => void;
-
 /**
  * A callback that receives an `ImageData` object from a Vision task. The
  * lifetime of the underlying data is limited to the duration of the callback.
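SegmentationMaskCallback is removed here, but the SegmentationMask union itself stays, so downstream code still has to narrow a mask before touching pixels. A small type-guard sketch under that assumption (names are illustrative, not part of this change):

type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture;

// True when the mask lives on the CPU and can be indexed directly.
function isCpuMask(mask: SegmentationMask): mask is Uint8ClampedArray|Float32Array {
  return !(mask instanceof WebGLTexture);
}

function maskLength(mask: SegmentationMask): number {
  // WebGL textures would need a GL readback before their pixels are visible.
  return isCpuMask(mask) ? mask.length : 0;
}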
@@ -19,7 +19,7 @@ import {Connection} from '../../../../tasks/web/vision/core/types';
 // tslint:disable:class-as-namespace Using for easier import by 3P users

 /**
- * A class containing the Pairs of landmark indices to be rendered with
+ * A class containing the pairs of landmark indices to be rendered with
  * connections.
  */
 export class FaceLandmarksConnections {
@@ -129,10 +129,6 @@ export class FaceStylizer extends VisionTaskRunner {
    * synchronously once the callback returns. Only use this method when the
    * FaceStylizer is created with the image running mode.
    *
-   * The input image can be of any size. To ensure that the output image has
-   * reasonable quality, the stylized output image size is determined by the
-   * model output size.
-   *
    * @param image An image to process.
    * @param callback The callback that is invoked with the stylized image. The
    *    lifetime of the returned data is only guaranteed for the duration of the

@@ -153,11 +149,6 @@ export class FaceStylizer extends VisionTaskRunner {
    * If both are specified, the crop around the region-of-interest is extracted
    * first, then the specified rotation is applied to the crop.
    *
-   * The input image can be of any size. To ensure that the output image has
-   * reasonable quality, the stylized output image size is the smaller of the
-   * model output size and the size of the 'regionOfInterest' specified in
-   * 'imageProcessingOptions'.
-   *
    * @param image An image to process.
    * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
    *    to process the input image before running inference.

@@ -192,9 +183,6 @@ export class FaceStylizer extends VisionTaskRunner {
    * frame's timestamp (in milliseconds). The input timestamps must be
    * monotonically increasing.
    *
-   * To ensure that the output image has reasonable quality, the stylized
-   * output image size is determined by the model output size.
-   *
    * @param videoFrame A video frame to process.
    * @param timestamp The timestamp of the current frame, in ms.
    * @param callback The callback that is invoked with the stylized image. The

@@ -221,10 +209,6 @@ export class FaceStylizer extends VisionTaskRunner {
    * frame's timestamp (in milliseconds). The input timestamps must be
    * monotonically increasing.
    *
-   * To ensure that the output image has reasonable quality, the stylized
-   * output image size is the smaller of the model output size and the size of
-   * the 'regionOfInterest' specified in 'imageProcessingOptions'.
-   *
    * @param videoFrame A video frame to process.
    * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
    *    to process the input image before running inference.

@@ -278,8 +262,12 @@ export class FaceStylizer extends VisionTaskRunner {

     this.graphRunner.attachImageListener(
         STYLIZED_IMAGE_STREAM, (image, timestamp) => {
-          const imageData = this.convertToImageData(image);
-          this.userCallback(imageData, image.width, image.height);
+          if (image.data instanceof WebGLTexture) {
+            this.userCallback(image.data, image.width, image.height);
+          } else {
+            const imageData = this.convertToImageData(image);
+            this.userCallback(imageData, image.width, image.height);
+          }
           this.setLatestOutputTimestamp(timestamp);
         });
     this.graphRunner.attachEmptyPacketListener(
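With the listener change above, the stylizer now hands a WebGLTexture straight to the user callback when the output stays on the GPU and an ImageData otherwise, so user code may receive either type. A hedged sketch of a callback that handles both cases; the (data, width, height) shape follows the calls above, and the canvas handling is purely illustrative:

function onStylizedImage(
    data: ImageData|WebGLTexture, width: number, height: number): void {
  if (data instanceof WebGLTexture) {
    // GPU path: consume the texture before the callback returns, since its
    // lifetime is limited to the duration of the callback.
    console.log(`Received a ${width}x${height} WebGLTexture`);
  } else {
    // CPU path: copy the pixels onto a canvas for display.
    const canvas = document.createElement('canvas');
    canvas.width = width;
    canvas.height = height;
    canvas.getContext('2d')!.putImageData(data, 0, 0);
    document.body.appendChild(canvas);
  }
}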
@@ -34,6 +34,7 @@ mediapipe_ts_library(
         "//mediapipe/tasks/web/core:classifier_options",
         "//mediapipe/tasks/web/vision/core:image_processing_options",
         "//mediapipe/tasks/web/vision/core:vision_task_runner",
+        "//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections",
         "//mediapipe/web/graph_runner:graph_runner_ts",
     ],
 )
@@ -31,6 +31,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

@@ -72,6 +73,12 @@ export class GestureRecognizer extends VisionTaskRunner {
   private readonly handGestureRecognizerGraphOptions:
       HandGestureRecognizerGraphOptions;

+  /**
+   * An array containing the pairs of hand landmark indices to be rendered with
+   * connections.
+   */
+  static HAND_CONNECTIONS = HAND_CONNECTIONS;
+
   /**
    * Initializes the Wasm runtime and creates a new gesture recognizer from the
    * provided options.
@@ -16,6 +16,7 @@ mediapipe_ts_library(
     visibility = ["//visibility:public"],
     deps = [
         ":hand_landmarker_types",
+        ":hand_landmarks_connections",
         "//mediapipe/framework:calculator_jspb_proto",
         "//mediapipe/framework:calculator_options_jspb_proto",
         "//mediapipe/framework/formats:classification_jspb_proto",

@@ -72,3 +73,9 @@ jasmine_node_test(
     tags = ["nomsan"],
     deps = [":hand_landmarker_test_lib"],
 )
+
+mediapipe_ts_library(
+    name = "hand_landmarks_connections",
+    srcs = ["hand_landmarks_connections.ts"],
+    deps = ["//mediapipe/tasks/web/vision/core:types"],
+)
@@ -27,6 +27,7 @@ import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/con
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
+import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

@@ -63,6 +64,12 @@ export class HandLandmarker extends VisionTaskRunner {
       HandLandmarksDetectorGraphOptions;
   private readonly handDetectorGraphOptions: HandDetectorGraphOptions;

+  /**
+   * An array containing the pairs of hand landmark indices to be rendered with
+   * connections.
+   */
+  static HAND_CONNECTIONS = HAND_CONNECTIONS;
+
   /**
    * Initializes the Wasm runtime and creates a new `HandLandmarker` from the
    * provided options.
@@ -0,0 +1,31 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import {Connection} from '../../../../tasks/web/vision/core/types';
+
+/**
+ * An array containing the pairs of hand landmark indices to be rendered with
+ * connections.
+ */
+export const HAND_CONNECTIONS: Connection[] = [
+  {start: 0, end: 1}, {start: 1, end: 2}, {start: 2, end: 3},
+  {start: 3, end: 4}, {start: 0, end: 5}, {start: 5, end: 6},
+  {start: 6, end: 7}, {start: 7, end: 8}, {start: 5, end: 9},
+  {start: 9, end: 10}, {start: 10, end: 11}, {start: 11, end: 12},
+  {start: 9, end: 13}, {start: 13, end: 14}, {start: 14, end: 15},
+  {start: 15, end: 16}, {start: 13, end: 17}, {start: 0, end: 17},
+  {start: 17, end: 18}, {start: 18, end: 19}, {start: 19, end: 20}
+];
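Since HAND_CONNECTIONS is also re-exported as a static on HandLandmarker and GestureRecognizer above, rendering code can walk the index pairs directly. A minimal drawing sketch; the import path assumes the published @mediapipe/tasks-vision bundle, and `landmarks` is assumed to be one hand's normalized landmarks with x/y in [0, 1]:

// Import path is an assumption based on the published web package.
import {HandLandmarker} from '@mediapipe/tasks-vision';

interface NormalizedLandmark { x: number; y: number; }

// Draws one hand's landmark connections onto a 2D canvas (illustrative only).
function drawHandConnections(
    ctx: CanvasRenderingContext2D, landmarks: NormalizedLandmark[]): void {
  ctx.strokeStyle = 'lime';
  ctx.lineWidth = 2;
  for (const {start, end} of HandLandmarker.HAND_CONNECTIONS) {
    ctx.beginPath();
    ctx.moveTo(landmarks[start].x * ctx.canvas.width,
               landmarks[start].y * ctx.canvas.height);
    ctx.lineTo(landmarks[end].x * ctx.canvas.width,
               landmarks[end].y * ctx.canvas.height);
    ctx.stroke();
  }
}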
@@ -29,7 +29,10 @@ mediapipe_ts_library(

 mediapipe_ts_declaration(
     name = "image_segmenter_types",
-    srcs = ["image_segmenter_options.d.ts"],
+    srcs = [
+        "image_segmenter_options.d.ts",
+        "image_segmenter_result.d.ts",
+    ],
     deps = [
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",
@@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
-import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
+import {SegmentationMask} from '../../../../tasks/web/vision/core/types';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {LabelMapItem} from '../../../../util/label_map_pb';
 import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
 // Placeholder for internal dependency on trusted resource url

 import {ImageSegmenterOptions} from './image_segmenter_options';
+import {ImageSegmenterResult} from './image_segmenter_result';

 export * from './image_segmenter_options';
-export {SegmentationMask, SegmentationMaskCallback};
+export * from './image_segmenter_result';
+export {SegmentationMask};
 export {ImageSource};  // Used in the public API

 const IMAGE_STREAM = 'image_in';
 const NORM_RECT_STREAM = 'norm_rect';
-const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
+const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
+const CATEGORY_MASK_STREAM = 'category_mask';
 const IMAGE_SEGMENTER_GRAPH =
     'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph';
 const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME =
     'mediapipe.tasks.TensorsToSegmentationCalculator';
+const DEFAULT_OUTPUT_CATEGORY_MASK = false;
+const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;

 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern

+/**
+ * A callback that receives the computed masks from the image segmenter. The
+ * returned data is only valid for the duration of the callback. If
+ * asynchronous processing is needed, all data needs to be copied before the
+ * callback returns.
+ */
+export type ImageSegmenterCallack = (result: ImageSegmenterResult) => void;
+
 /** Performs image segmentation on images. */
 export class ImageSegmenter extends VisionTaskRunner {
-  private userCallback: SegmentationMaskCallback = () => {};
+  private result: ImageSegmenterResult = {width: 0, height: 0};
   private labels: string[] = [];
+  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
+  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
   private readonly options: ImageSegmenterGraphOptionsProto;
   private readonly segmenterOptions: SegmenterOptionsProto;

@@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner {
     this.options.setBaseOptions(new BaseOptionsProto());
   }

-
   protected override get baseOptions(): BaseOptionsProto {
     return this.options.getBaseOptions()!;
   }

@@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner {
       this.options.clearDisplayNamesLocale();
     }

-    if (options.outputType === 'CONFIDENCE_MASK') {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
-    } else {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    if ('outputCategoryMask' in options) {
+      this.outputCategoryMask =
+          options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
+    }
+
+    if ('outputConfidenceMasks' in options) {
+      this.outputConfidenceMasks =
+          options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
     }

     return super.applyOptions(options);

@@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner {
    *    lifetime of the returned data is only guaranteed for the duration of the
    *    callback.
    */
-  segment(image: ImageSource, callback: SegmentationMaskCallback): void;
+  segment(image: ImageSource, callback: ImageSegmenterCallack): void;
   /**
    * Performs image segmentation on the provided single image and invokes the
    * callback with the response. The method returns synchronously once the

@@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner {
    */
   segment(
       image: ImageSource, imageProcessingOptions: ImageProcessingOptions,
-      callback: SegmentationMaskCallback): void;
+      callback: ImageSegmenterCallack): void;
   segment(
       image: ImageSource,
       imageProcessingOptionsOrCallback: ImageProcessingOptions|
-      SegmentationMaskCallback,
-      callback?: SegmentationMaskCallback): void {
+      ImageSegmenterCallack,
+      callback?: ImageSegmenterCallack): void {
     const imageProcessingOptions =
         typeof imageProcessingOptionsOrCallback !== 'function' ?
         imageProcessingOptionsOrCallback :
         {};
-    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
+    const userCallback =
+        typeof imageProcessingOptionsOrCallback === 'function' ?
         imageProcessingOptionsOrCallback :
         callback!;

+    this.reset();
     this.processImageData(image, imageProcessingOptions);
-    this.userCallback = () => {};
+    userCallback(this.result);
+  }
+
+  /**
+   * Performs image segmentation on the provided video frame and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segmentForVideo(
+      videoFrame: ImageSource, timestamp: number,
+      callback: ImageSegmenterCallack): void;
+  /**
+   * Performs image segmentation on the provided video frame and invokes the
+   * callback with the response. The method returns synchronously once the
+   * callback returns. Only use this method when the ImageSegmenter is
+   * created with running mode `video`.
+   *
+   * @param videoFrame A video frame to process.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @param timestamp The timestamp of the current frame, in ms.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segmentForVideo(
+      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
+      timestamp: number, callback: ImageSegmenterCallack): void;
+  segmentForVideo(
+      videoFrame: ImageSource,
+      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
+      timestampOrCallback: number|ImageSegmenterCallack,
+      callback?: ImageSegmenterCallack): void {
+    const imageProcessingOptions =
+        typeof timestampOrImageProcessingOptions !== 'number' ?
+        timestampOrImageProcessingOptions :
+        {};
+    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
+        timestampOrImageProcessingOptions :
+        timestampOrCallback as number;
+    const userCallback = typeof timestampOrCallback === 'function' ?
+        timestampOrCallback :
+        callback!;
+
+    this.reset();
+    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
+    userCallback(this.result);
   }

   /**

@@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner {
     return this.labels;
   }

-  /**
-   * Performs image segmentation on the provided video frame and invokes the
-   * callback with the response. The method returns synchronously once the
-   * callback returns. Only use this method when the ImageSegmenter is
-   * created with running mode `video`.
-   *
-   * @param videoFrame A video frame to process.
-   * @param timestamp The timestamp of the current frame, in ms.
-   * @param callback The callback that is invoked with the segmented masks. The
-   *    lifetime of the returned data is only guaranteed for the duration of the
-   *    callback.
-   */
-  segmentForVideo(
-      videoFrame: ImageSource, timestamp: number,
-      callback: SegmentationMaskCallback): void;
-  /**
-   * Performs image segmentation on the provided video frame and invokes the
-   * callback with the response. The method returns synchronously once the
-   * callback returns. Only use this method when the ImageSegmenter is
-   * created with running mode `video`.
-   *
-   * @param videoFrame A video frame to process.
-   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
-   *    to process the input image before running inference.
-   * @param timestamp The timestamp of the current frame, in ms.
-   * @param callback The callback that is invoked with the segmented masks. The
-   *    lifetime of the returned data is only guaranteed for the duration of the
-   *    callback.
-   */
-  segmentForVideo(
-      videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions,
-      timestamp: number, callback: SegmentationMaskCallback): void;
-  segmentForVideo(
-      videoFrame: ImageSource,
-      timestampOrImageProcessingOptions: number|ImageProcessingOptions,
-      timestampOrCallback: number|SegmentationMaskCallback,
-      callback?: SegmentationMaskCallback): void {
-    const imageProcessingOptions =
-        typeof timestampOrImageProcessingOptions !== 'number' ?
-        timestampOrImageProcessingOptions :
-        {};
-    const timestamp = typeof timestampOrImageProcessingOptions === 'number' ?
-        timestampOrImageProcessingOptions :
-        timestampOrCallback as number;
-
-    this.userCallback = typeof timestampOrCallback === 'function' ?
-        timestampOrCallback :
-        callback!;
-    this.processVideoData(videoFrame, imageProcessingOptions, timestamp);
-    this.userCallback = () => {};
+  private reset(): void {
+    this.result = {width: 0, height: 0};
   }

   /** Updates the MediaPipe graph configuration. */

@@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner {
     const graphConfig = new CalculatorGraphConfig();
     graphConfig.addInputStream(IMAGE_STREAM);
     graphConfig.addInputStream(NORM_RECT_STREAM);
-    graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM);

     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(

@@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner {
     segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
     segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
     segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
-    segmenterNode.addOutputStream(
-        'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM);
     segmenterNode.setOptions(calculatorOptions);

     graphConfig.addNode(segmenterNode);

-    this.graphRunner.attachImageVectorListener(
-        GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => {
-          if (masks.length === 0) {
-            this.userCallback([], 0, 0);
-          } else {
-            this.userCallback(
-                masks.map(m => m.data), masks[0].width, masks[0].height);
-          }
-          this.setLatestOutputTimestamp(timestamp);
-        });
-    this.graphRunner.attachEmptyPacketListener(
-        GROUPED_SEGMENTATIONS_STREAM, timestamp => {
-          this.setLatestOutputTimestamp(timestamp);
-        });
+    if (this.outputConfidenceMasks) {
+      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
+      segmenterNode.addOutputStream(
+          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
+
+      this.graphRunner.attachImageVectorListener(
+          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
+            this.result.confidenceMasks = masks.map(m => m.data);
+            if (masks.length >= 0) {
+              this.result.width = masks[0].width;
+              this.result.height = masks[0].height;
+            }
+
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CONFIDENCE_MASKS_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }
+
+    if (this.outputCategoryMask) {
+      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
+      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
+
+      this.graphRunner.attachImageListener(
+          CATEGORY_MASK_STREAM, (mask, timestamp) => {
+            this.result.categoryMask = mask.data;
+            this.result.width = mask.width;
+            this.result.height = mask.height;
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CATEGORY_MASK_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }

     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);
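Taken together, the web ImageSegmenter now selects its outputs with two booleans and delivers everything through a single ImageSegmenterResult. A hedged usage sketch, assuming `segmenter` is an already-constructed ImageSegmenter in image mode and that the import path is the published bundle:

import {ImageSegmenter, ImageSegmenterResult} from '@mediapipe/tasks-vision';

async function segmentOnce(
    segmenter: ImageSegmenter, image: HTMLImageElement): Promise<void> {
  // Request both mask types; confidence masks default to true, category mask to false.
  await segmenter.setOptions(
      {outputCategoryMask: true, outputConfidenceMasks: true});

  segmenter.segment(image, (result: ImageSegmenterResult) => {
    // The result data is only valid for the duration of the callback.
    console.log(`masks are ${result.width}x${result.height}`);
    console.log('confidence masks:', result.confidenceMasks?.length ?? 0);
    if (result.categoryMask instanceof Uint8ClampedArray) {
      console.log('category mask bytes:', result.categoryMask.length);
    }
  });
}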
@@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions {
    */
   displayNamesLocale?: string|undefined;

-  /**
-   * The output type of segmentation results.
-   *
-   * The two supported modes are:
-   * - Category Mask:   Gives a single output mask where each pixel represents
-   *                    the class which the pixel in the original image was
-   *                    predicted to belong to.
-   * - Confidence Mask: Gives a list of output masks (one for each class). For
-   *                    each mask, the pixel represents the prediction
-   *                    confidence, usually in the [0.0, 0.1] range.
-   *
-   * Defaults to `CATEGORY_MASK`.
-   */
-  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
+  /** Whether to output confidence masks. Defaults to true. */
+  outputConfidenceMasks?: boolean|undefined;
+
+  /** Whether to output the category masks. Defaults to false. */
+  outputCategoryMask?: boolean|undefined;
 }
|
|
37
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
vendored
Normal file
37
mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts
vendored
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/** The output result of ImageSegmenter. */
|
||||||
|
export declare interface ImageSegmenterResult {
|
||||||
|
/**
|
||||||
|
* Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each
|
||||||
|
* pixel represents the prediction confidence, usually in the [0, 1] range.
|
||||||
|
*/
|
||||||
|
confidenceMasks?: Float32Array[]|WebGLTexture[];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A category mask as a Uint8ClampedArray or WebGLTexture where each
|
||||||
|
* pixel represents the class which the pixel in the original image was
|
||||||
|
* predicted to belong to.
|
||||||
|
*/
|
||||||
|
categoryMask?: Uint8ClampedArray|WebGLTexture;
|
||||||
|
|
||||||
|
/** The width of the masks. */
|
||||||
|
width: number;
|
||||||
|
|
||||||
|
/** The height of the masks. */
|
||||||
|
height: number;
|
||||||
|
}
|
|
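A small sketch of consuming the new result on the CPU path, turning one confidence mask into a grayscale ImageData; it assumes the mask arrived as a Float32Array (the WebGLTexture variant would need a GL readback first):

// Converts one confidence mask (values in [0, 1]) into a grayscale ImageData.
function confidenceMaskToImageData(
    mask: Float32Array, width: number, height: number): ImageData {
  const rgba = new Uint8ClampedArray(width * height * 4);
  for (let i = 0; i < mask.length; i++) {
    const v = Math.round(mask[i] * 255);
    rgba[4 * i] = v;        // R
    rgba[4 * i + 1] = v;    // G
    rgba[4 * i + 2] = v;    // B
    rgba[4 * i + 3] = 255;  // A (fully opaque)
  }
  return new ImageData(rgba, width, height);
}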
@ -18,7 +18,7 @@ import 'jasmine';
|
||||||
|
|
||||||
// Placeholder for internal dependency on encodeByteArray
|
// Placeholder for internal dependency on encodeByteArray
|
||||||
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
|
||||||
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
|
import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
|
||||||
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
|
||||||
|
|
||||||
import {ImageSegmenter} from './image_segmenter';
|
import {ImageSegmenter} from './image_segmenter';
|
||||||
|
@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
graph: CalculatorGraphConfig|undefined;
|
graph: CalculatorGraphConfig|undefined;
|
||||||
|
|
||||||
fakeWasmModule: SpyWasmModule;
|
fakeWasmModule: SpyWasmModule;
|
||||||
imageVectorListener:
|
categoryMaskListener:
|
||||||
|
((images: WasmImage, timestamp: number) => void)|undefined;
|
||||||
|
confidenceMasksListener:
|
||||||
((images: WasmImage[], timestamp: number) => void)|undefined;
|
((images: WasmImage[], timestamp: number) => void)|undefined;
|
||||||
|
|
||||||
constructor() {
|
constructor() {
|
||||||
|
@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake {
|
||||||
this.fakeWasmModule =
|
this.fakeWasmModule =
|
||||||
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
this.graphRunner.wasmModule as unknown as SpyWasmModule;
|
||||||
|
|
||||||
this.attachListenerSpies[0] =
|
this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
|
||||||
|
.and.callFake((stream, listener) => {
|
||||||
|
expect(stream).toEqual('category_mask');
|
||||||
|
this.categoryMaskListener = listener;
|
||||||
|
});
|
||||||
|
this.attachListenerSpies[1] =
|
||||||
spyOn(this.graphRunner, 'attachImageVectorListener')
|
spyOn(this.graphRunner, 'attachImageVectorListener')
|
||||||
.and.callFake((stream, listener) => {
|
.and.callFake((stream, listener) => {
|
||||||
expect(stream).toEqual('segmented_masks');
|
expect(stream).toEqual('confidence_masks');
|
||||||
this.imageVectorListener = listener;
|
this.confidenceMasksListener = listener;
|
||||||
});
|
});
|
||||||
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
|
||||||
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
|
||||||
|
@ -63,17 +70,18 @@ describe('ImageSegmenter', () => {
|
||||||
|
|
||||||
it('initializes graph', async () => {
|
it('initializes graph', async () => {
|
||||||
verifyGraph(imageSegmenter);
|
verifyGraph(imageSegmenter);
|
||||||
verifyListenersRegistered(imageSegmenter);
|
|
||||||
|
// Verify default options
|
||||||
|
expect(imageSegmenter.categoryMaskListener).not.toBeDefined();
|
||||||
|
expect(imageSegmenter.confidenceMasksListener).toBeDefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('reloads graph when settings are changed', async () => {
|
it('reloads graph when settings are changed', async () => {
|
||||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||||
verifyListenersRegistered(imageSegmenter);
|
|
||||||
|
|
||||||
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
await imageSegmenter.setOptions({displayNamesLocale: 'de'});
|
||||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']);
|
||||||
verifyListenersRegistered(imageSegmenter);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('can use custom models', async () => {
|
it('can use custom models', async () => {
|
||||||
|
@ -100,9 +108,11 @@ describe('ImageSegmenter', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
it('merges options', async () => {
|
it('merges options', async () => {
|
||||||
await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
|
await imageSegmenter.setOptions(
|
||||||
|
{baseOptions: {modelAssetBuffer: new Uint8Array([])}});
|
||||||
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
await imageSegmenter.setOptions({displayNamesLocale: 'en'});
|
||||||
verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]);
|
verifyGraph(
|
||||||
|
imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']);
|
||||||
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -115,22 +125,13 @@ describe('ImageSegmenter', () => {
|
||||||
defaultValue: unknown;
|
defaultValue: unknown;
|
||||||
}
|
}
|
||||||
|
|
||||||
const testCases: TestCase[] = [
|
const testCases: TestCase[] = [{
|
||||||
{
|
optionName: 'displayNamesLocale',
|
||||||
optionName: 'displayNamesLocale',
|
fieldPath: ['displayNamesLocale'],
|
||||||
fieldPath: ['displayNamesLocale'],
|
userValue: 'en',
|
||||||
userValue: 'en',
|
graphValue: 'en',
|
||||||
graphValue: 'en',
|
defaultValue: 'en'
|
||||||
defaultValue: 'en'
|
}];
|
||||||
},
|
|
||||||
{
|
|
||||||
optionName: 'outputType',
|
|
||||||
fieldPath: ['segmenterOptions', 'outputType'],
|
|
||||||
userValue: 'CONFIDENCE_MASK',
|
|
||||||
graphValue: 2,
|
|
||||||
defaultValue: 1
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
for (const testCase of testCases) {
|
for (const testCase of testCases) {
|
||||||
it(`can set ${testCase.optionName}`, async () => {
|
it(`can set ${testCase.optionName}`, async () => {
|
||||||
|
@ -158,27 +159,31 @@ describe('ImageSegmenter', () => {
|
||||||
}).toThrowError('This task doesn\'t support region-of-interest.');
|
}).toThrowError('This task doesn\'t support region-of-interest.');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('supports category masks', (done) => {
|
it('supports category mask', async () => {
|
||||||
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
const mask = new Uint8ClampedArray([1, 2, 3, 4]);
|
||||||
|
|
||||||
|
await imageSegmenter.setOptions(
|
||||||
|
{outputCategoryMask: true, outputConfidenceMasks: false});
|
||||||
|
|
||||||
// Pass the test data to our listener
|
// Pass the test data to our listener
|
||||||
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
|
||||||
verifyListenersRegistered(imageSegmenter);
|
expect(imageSegmenter.categoryMaskListener).toBeDefined();
|
||||||
imageSegmenter.imageVectorListener!(
|
imageSegmenter.categoryMaskListener!
|
||||||
[
|
({data: mask, width: 2, height: 2},
|
||||||
{data: mask, width: 2, height: 2},
|
/* timestamp= */ 1337);
|
||||||
],
|
|
||||||
/* timestamp= */ 1337);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Invoke the image segmenter
|
// Invoke the image segmenter
|
||||||
imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
|
|
||||||
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
return new Promise<void>(resolve => {
|
||||||
expect(masks).toHaveSize(1);
|
imageSegmenter.segment({} as HTMLImageElement, result => {
|
||||||
expect(masks[0]).toEqual(mask);
|
expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
|
||||||
expect(width).toEqual(2);
|
expect(result.categoryMask).toEqual(mask);
|
||||||
expect(height).toEqual(2);
|
expect(result.confidenceMasks).not.toBeDefined();
|
||||||
done();
|
expect(result.width).toEqual(2);
|
||||||
|
expect(result.height).toEqual(2);
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@@ -186,12 +191,13 @@ describe('ImageSegmenter', () => {
     const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
     const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);

-    await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: false, outputConfidenceMasks: true});

     // Pass the test data to our listener
     imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(imageSegmenter);
-      imageSegmenter.imageVectorListener!(
+      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
+      imageSegmenter.confidenceMasksListener!(
           [
             {data: mask1, width: 2, height: 2},
             {data: mask2, width: 2, height: 2},
@@ -201,13 +207,49 @@ describe('ImageSegmenter', () => {

     return new Promise<void>(resolve => {
       // Invoke the image segmenter
-      imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => {
+      imageSegmenter.segment({} as HTMLImageElement, result => {
         expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
-        expect(masks).toHaveSize(2);
-        expect(masks[0]).toEqual(mask1);
-        expect(masks[1]).toEqual(mask2);
-        expect(width).toEqual(2);
-        expect(height).toEqual(2);
+        expect(result.categoryMask).not.toBeDefined();
+        expect(result.confidenceMasks).toEqual([mask1, mask2]);
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
+        resolve();
+      });
+    });
+  });
+
+  it('supports combined category and confidence masks', async () => {
+    const categoryMask = new Uint8ClampedArray([1, 0]);
+    const confidenceMask1 = new Float32Array([0.0, 1.0]);
+    const confidenceMask2 = new Float32Array([1.0, 0.0]);
+
+    await imageSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: true});
+
+    // Pass the test data to our listener
+    imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      expect(imageSegmenter.categoryMaskListener).toBeDefined();
+      expect(imageSegmenter.confidenceMasksListener).toBeDefined();
+      imageSegmenter.categoryMaskListener!
+          ({data: categoryMask, width: 1, height: 1}, 1337);
+      imageSegmenter.confidenceMasksListener!(
+          [
+            {data: confidenceMask1, width: 1, height: 1},
+            {data: confidenceMask2, width: 1, height: 1},
+          ],
+          1337);
+    });
+
+    return new Promise<void>(resolve => {
+      // Invoke the image segmenter
+      imageSegmenter.segment({} as HTMLImageElement, result => {
+        expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled();
+        expect(result.categoryMask).toEqual(categoryMask);
+        expect(result.confidenceMasks).toEqual([
+          confidenceMask1, confidenceMask2
+        ]);
+        expect(result.width).toEqual(1);
+        expect(result.height).toEqual(1);
         resolve();
       });
     });

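
Note: for orientation, a minimal usage sketch of the reworked ImageSegmenter callback API that these tests exercise. It assumes the published '@mediapipe/tasks-vision' entry point and an already-created segmenter; the variable names and logging are illustrative only and not part of this change.

import {ImageSegmenter} from '@mediapipe/tasks-vision';

declare const imageSegmenter: ImageSegmenter;  // constructed elsewhere
declare const image: HTMLImageElement;         // any supported ImageSource

async function runSegmentation(): Promise<void> {
  // Both outputs are now independent booleans rather than a single outputType.
  await imageSegmenter.setOptions(
      {outputCategoryMask: true, outputConfidenceMasks: true});

  imageSegmenter.segment(image, result => {
    // The callback receives one result object instead of the old
    // (masks, width, height) positional arguments.
    console.log(`masks are ${result.width}x${result.height}`);
    if (result.categoryMask instanceof Uint8ClampedArray) {
      console.log(`category mask holds ${result.categoryMask.length} pixels`);
    }
    console.log(`confidence masks: ${result.confidenceMasks?.length ?? 0}`);
  });
}
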
@@ -30,7 +30,10 @@ mediapipe_ts_library(

 mediapipe_ts_declaration(
     name = "interactive_segmenter_types",
-    srcs = ["interactive_segmenter_options.d.ts"],
+    srcs = [
+        "interactive_segmenter_options.d.ts",
+        "interactive_segmenter_result.d.ts",
+    ],
     deps = [
         "//mediapipe/tasks/web/core",
         "//mediapipe/tasks/web/core:classifier_options",

@@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
 import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
 import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
 import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
-import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
+import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types';
 import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
 import {Color as ColorProto} from '../../../../util/color_pb';
 import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb';

@@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
 // Placeholder for internal dependency on trusted resource url

 import {InteractiveSegmenterOptions} from './interactive_segmenter_options';
+import {InteractiveSegmenterResult} from './interactive_segmenter_result';

 export * from './interactive_segmenter_options';
-export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest};
+export * from './interactive_segmenter_result';
+export {SegmentationMask, RegionOfInterest};
 export {ImageSource};

 const IMAGE_IN_STREAM = 'image_in';
 const NORM_RECT_IN_STREAM = 'norm_rect_in';
 const ROI_IN_STREAM = 'roi_in';
-const IMAGE_OUT_STREAM = 'image_out';
+const CONFIDENCE_MASKS_STREAM = 'confidence_masks';
+const CATEGORY_MASK_STREAM = 'category_mask';
 const IMAGEA_SEGMENTER_GRAPH =
     'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
+const DEFAULT_OUTPUT_CATEGORY_MASK = false;
+const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true;

 // The OSS JS API does not support the builder pattern.
 // tslint:disable:jspb-use-builder-pattern

+/**
+ * A callback that receives the computed masks from the interactive segmenter.
+ * The returned data is only valid for the duration of the callback. If
+ * asynchronous processing is needed, all data needs to be copied before the
+ * callback returns.
+ */
+export type InteractiveSegmenterCallack =
+    (result: InteractiveSegmenterResult) => void;
+
 /**
  * Performs interactive segmentation on images.
  *

@@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH =
  * - batch is always 1
  */
 export class InteractiveSegmenter extends VisionTaskRunner {
-  private userCallback: SegmentationMaskCallback = () => {};
+  private result: InteractiveSegmenterResult = {width: 0, height: 0};
+  private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK;
+  private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS;
   private readonly options: ImageSegmenterGraphOptionsProto;
   private readonly segmenterOptions: SegmenterOptionsProto;

@@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner {
    * @return A Promise that resolves when the settings have been applied.
    */
   override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
-    if (options.outputType === 'CONFIDENCE_MASK') {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
-    } else {
-      this.segmenterOptions.setOutputType(
-          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    if ('outputCategoryMask' in options) {
+      this.outputCategoryMask =
+          options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK;
+    }
+
+    if ('outputConfidenceMasks' in options) {
+      this.outputConfidenceMasks =
+          options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS;
     }

     return super.applyOptions(options);

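
Note: a small sketch of the option semantics this hunk implements, assuming an already-constructed InteractiveSegmenter. A key is applied only when it is present on the options object, and an explicit undefined resets it to its default.

import {InteractiveSegmenter} from './interactive_segmenter';

declare const interactiveSegmenter: InteractiveSegmenter;  // constructed elsewhere

async function configure(): Promise<void> {
  // Only keys present on the options object are applied.
  await interactiveSegmenter.setOptions({outputCategoryMask: true});
  // outputCategoryMask is now true; outputConfidenceMasks keeps its previous value.

  // An explicit `undefined` resets the option to its default (false here).
  await interactiveSegmenter.setOptions({outputCategoryMask: undefined});

  // An empty options object changes nothing.
  await interactiveSegmenter.setOptions({});
}
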
@@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner {
    */
   segment(
       image: ImageSource, roi: RegionOfInterest,
-      callback: SegmentationMaskCallback): void;
+      callback: InteractiveSegmenterCallack): void;
   /**
    * Performs interactive segmentation on the provided single image and invokes
    * the callback with the response. The `roi` parameter is used to represent a

@@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner {
   segment(
       image: ImageSource, roi: RegionOfInterest,
       imageProcessingOptions: ImageProcessingOptions,
-      callback: SegmentationMaskCallback): void;
+      callback: InteractiveSegmenterCallack): void;
   segment(
       image: ImageSource, roi: RegionOfInterest,
       imageProcessingOptionsOrCallback: ImageProcessingOptions|
-      SegmentationMaskCallback,
-      callback?: SegmentationMaskCallback): void {
+      InteractiveSegmenterCallack,
+      callback?: InteractiveSegmenterCallack): void {
     const imageProcessingOptions =
         typeof imageProcessingOptionsOrCallback !== 'function' ?
             imageProcessingOptionsOrCallback :
             {};
-    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
-        imageProcessingOptionsOrCallback :
-        callback!;
+    const userCallback =
+        typeof imageProcessingOptionsOrCallback === 'function' ?
+            imageProcessingOptionsOrCallback :
+            callback!;

+    this.reset();
     this.processRenderData(roi, this.getSynctheticTimestamp());
     this.processImageData(image, imageProcessingOptions);
-    this.userCallback = () => {};
+    userCallback(this.result);
+  }
+
+  private reset(): void {
+    this.result = {width: 0, height: 0};
   }

   /** Updates the MediaPipe graph configuration. */

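
Note: a minimal sketch of calling the updated segment() overloads. The segmenter, image, and region of interest are assumed to exist already, and the logging is illustrative only.

import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';

declare const interactiveSegmenter: InteractiveSegmenter;  // constructed elsewhere
declare const image: HTMLImageElement;                     // any supported ImageSource
declare const roi: RegionOfInterest;                       // e.g. built from a user click

interactiveSegmenter.segment(image, roi, result => {
  // The result object is only valid inside this callback; copy anything that
  // must outlive it.
  const masks = result.confidenceMasks ?? [];
  console.log(`got ${masks.length} confidence masks at ` +
              `${result.width}x${result.height}`);
});
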
@@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner {
     graphConfig.addInputStream(IMAGE_IN_STREAM);
     graphConfig.addInputStream(ROI_IN_STREAM);
     graphConfig.addInputStream(NORM_RECT_IN_STREAM);
-    graphConfig.addOutputStream(IMAGE_OUT_STREAM);

     const calculatorOptions = new CalculatorOptions();
     calculatorOptions.setExtension(

@@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner {
     segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM);
     segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM);
     segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM);
-    segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM);
     segmenterNode.setOptions(calculatorOptions);

     graphConfig.addNode(segmenterNode);

-    this.graphRunner.attachImageVectorListener(
-        IMAGE_OUT_STREAM, (masks, timestamp) => {
-          if (masks.length === 0) {
-            this.userCallback([], 0, 0);
-          } else {
-            this.userCallback(
-                masks.map(m => m.data), masks[0].width, masks[0].height);
-          }
-          this.setLatestOutputTimestamp(timestamp);
-        });
-    this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => {
-      this.setLatestOutputTimestamp(timestamp);
-    });
+    if (this.outputConfidenceMasks) {
+      graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM);
+      segmenterNode.addOutputStream(
+          'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM);
+
+      this.graphRunner.attachImageVectorListener(
+          CONFIDENCE_MASKS_STREAM, (masks, timestamp) => {
+            this.result.confidenceMasks = masks.map(m => m.data);
+            if (masks.length >= 0) {
+              this.result.width = masks[0].width;
+              this.result.height = masks[0].height;
+            }
+
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CONFIDENCE_MASKS_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }
+
+    if (this.outputCategoryMask) {
+      graphConfig.addOutputStream(CATEGORY_MASK_STREAM);
+      segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM);
+
+      this.graphRunner.attachImageListener(
+          CATEGORY_MASK_STREAM, (mask, timestamp) => {
+            this.result.categoryMask = mask.data;
+            this.result.width = mask.width;
+            this.result.height = mask.height;
+            this.setLatestOutputTimestamp(timestamp);
+          });
+      this.graphRunner.attachEmptyPacketListener(
+          CATEGORY_MASK_STREAM, timestamp => {
+            this.setLatestOutputTimestamp(timestamp);
+          });
+    }

     const binaryGraph = graphConfig.serializeBinary();
     this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true);

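
Note: the wiring above follows an accumulate-then-deliver pattern: each enabled output stream writes its piece into this.result, and segment() hands the assembled object to the caller once the synchronous graph run returns. A stripped-down, hypothetical sketch of that pattern, independent of the MediaPipe graph runner (none of these names exist in MediaPipe):

interface PartialResult {
  categoryMask?: Uint8ClampedArray;
  confidenceMasks?: Float32Array[];
  width: number;
  height: number;
}

class ResultAccumulator {
  private result: PartialResult = {width: 0, height: 0};

  // Each enabled output stream contributes its slice of the result.
  onCategoryMask(mask: Uint8ClampedArray, width: number, height: number): void {
    this.result = {...this.result, categoryMask: mask, width, height};
  }

  onConfidenceMasks(masks: Float32Array[], width: number, height: number): void {
    this.result = {...this.result, confidenceMasks: masks, width, height};
  }

  // Invoked once the synchronous run returns; hands the assembled object to
  // the caller and resets state for the next run.
  deliver(callback: (result: PartialResult) => void): void {
    callback(this.result);
    this.result = {width: 0, height: 0};
  }
}
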
@@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'

 /** Options to configure the MediaPipe Interactive Segmenter Task */
 export interface InteractiveSegmenterOptions extends TaskRunnerOptions {
-  /**
-   * The output type of segmentation results.
-   *
-   * The two supported modes are:
-   * - Category Mask:   Gives a single output mask where each pixel represents
-   *                    the class which the pixel in the original image was
-   *                    predicted to belong to.
-   * - Confidence Mask: Gives a list of output masks (one for each class). For
-   *                    each mask, the pixel represents the prediction
-   *                    confidence, usually in the [0.0, 0.1] range.
-   *
-   * Defaults to `CATEGORY_MASK`.
-   */
-  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
+  /** Whether to output confidence masks. Defaults to true. */
+  outputConfidenceMasks?: boolean|undefined;
+
+  /** Whether to output the category masks. Defaults to false. */
+  outputCategoryMask?: boolean|undefined;
 }

mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts (new file, 37 lines, vendored)

@@ -0,0 +1,37 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/** The output result of InteractiveSegmenter. */
+export declare interface InteractiveSegmenterResult {
+  /**
+   * Multiple masks as Float32Arrays or WebGLTextures where, for each mask,
+   * each pixel represents the prediction confidence, usually in the [0, 1]
+   * range.
+   */
+  confidenceMasks?: Float32Array[]|WebGLTexture[];
+
+  /**
+   * A category mask as a Uint8ClampedArray or WebGLTexture where each
+   * pixel represents the class which the pixel in the original image was
+   * predicted to belong to.
+   */
+  categoryMask?: Uint8ClampedArray|WebGLTexture;
+
+  /** The width of the masks. */
+  width: number;
+
+  /** The height of the masks. */
+  height: number;
+}

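
Note: a short sketch of consuming this result type on the CPU path, for example to paint the category mask into a canvas. The two-class color mapping and the canvas context are illustrative assumptions, not part of this change.

import {InteractiveSegmenterResult} from './interactive_segmenter_result';

function drawCategoryMask(
    result: InteractiveSegmenterResult, ctx: CanvasRenderingContext2D): void {
  const mask = result.categoryMask;
  if (!(mask instanceof Uint8ClampedArray)) {
    return;  // GPU (WebGLTexture) output is not handled in this sketch.
  }
  const {width, height} = result;
  const rgba = new Uint8ClampedArray(width * height * 4);
  for (let i = 0; i < mask.length; ++i) {
    const isForeground = mask[i] > 0;  // illustrative two-class mapping
    rgba[4 * i + 0] = isForeground ? 255 : 0;  // red channel
    rgba[4 * i + 3] = 128;                     // semi-transparent overlay
  }
  ctx.putImageData(new ImageData(rgba, width, height), 0, 0);
}
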
@@ -18,7 +18,7 @@ import 'jasmine';

 // Placeholder for internal dependency on encodeByteArray
 import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
-import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils';
 import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
 import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';

@@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
   graph: CalculatorGraphConfig|undefined;

   fakeWasmModule: SpyWasmModule;
-  imageVectorListener:
+  categoryMaskListener:
+      ((images: WasmImage, timestamp: number) => void)|undefined;
+  confidenceMasksListener:
       ((images: WasmImage[], timestamp: number) => void)|undefined;
   lastRoi?: RenderDataProto;

@@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements
     this.fakeWasmModule =
         this.graphRunner.wasmModule as unknown as SpyWasmModule;

-    this.attachListenerSpies[0] =
+    this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener')
+                                      .and.callFake((stream, listener) => {
+                                        expect(stream).toEqual('category_mask');
+                                        this.categoryMaskListener = listener;
+                                      });
+    this.attachListenerSpies[1] =
         spyOn(this.graphRunner, 'attachImageVectorListener')
             .and.callFake((stream, listener) => {
-              expect(stream).toEqual('image_out');
-              this.imageVectorListener = listener;
+              expect(stream).toEqual('confidence_masks');
+              this.confidenceMasksListener = listener;
             });
     spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
       this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);

@@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => {

   it('initializes graph', async () => {
     verifyGraph(interactiveSegmenter);
-    verifyListenersRegistered(interactiveSegmenter);
+
+    // Verify default options
+    expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
+    expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
   });

   it('reloads graph when settings are changed', async () => {
-    await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'});
-    verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]);
-    verifyListenersRegistered(interactiveSegmenter);
+    await interactiveSegmenter.setOptions(
+        {outputConfidenceMasks: true, outputCategoryMask: false});
+    expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined();
+    expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();

-    await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
-    verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]);
-    verifyListenersRegistered(interactiveSegmenter);
+    await interactiveSegmenter.setOptions(
+        {outputConfidenceMasks: false, outputCategoryMask: true});
+    expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
   });

   it('can use custom models', async () => {

@@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => {
     ]);
   });

-  describe('setOptions()', () => {
-    const fieldPath = ['segmenterOptions', 'outputType'];
-
-    it(`can set outputType`, async () => {
-      await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
-      verifyGraph(interactiveSegmenter, [fieldPath, 2]);
-    });
-
-    it(`can clear outputType`, async () => {
-      await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
-      verifyGraph(interactiveSegmenter, [fieldPath, 2]);
-      await interactiveSegmenter.setOptions({outputType: undefined});
-      verifyGraph(interactiveSegmenter, [fieldPath, 1]);
-    });
-  });
-
   it('doesn\'t support region of interest', () => {
     expect(() => {
       interactiveSegmenter.segment(

@@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => {
     interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {});
   });

-  it('supports category masks', (done) => {
+  it('supports category mask', async () => {
     const mask = new Uint8ClampedArray([1, 2, 3, 4]);

+    await interactiveSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: false});
+
     // Pass the test data to our listener
     interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(interactiveSegmenter);
-      interactiveSegmenter.imageVectorListener!(
-          [
-            {data: mask, width: 2, height: 2},
-          ],
-          /* timestamp= */ 1337);
+      expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
+      interactiveSegmenter.categoryMaskListener!
+          ({data: mask, width: 2, height: 2},
+           /* timestamp= */ 1337);
     });

     // Invoke the image segmenter
-    interactiveSegmenter.segment(
-        {} as HTMLImageElement, ROI, (masks, width, height) => {
-          expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
-              .toHaveBeenCalled();
-          expect(masks).toHaveSize(1);
-          expect(masks[0]).toEqual(mask);
-          expect(width).toEqual(2);
-          expect(height).toEqual(2);
-          done();
-        });
+    return new Promise<void>(resolve => {
+      interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
+        expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
+            .toHaveBeenCalled();
+        expect(result.categoryMask).toEqual(mask);
+        expect(result.confidenceMasks).not.toBeDefined();
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
+        resolve();
+      });
+    });
   });

   it('supports confidence masks', async () => {
     const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]);
     const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]);

-    await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'});
+    await interactiveSegmenter.setOptions(
+        {outputCategoryMask: false, outputConfidenceMasks: true});

     // Pass the test data to our listener
     interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
-      verifyListenersRegistered(interactiveSegmenter);
-      interactiveSegmenter.imageVectorListener!(
+      expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
+      interactiveSegmenter.confidenceMasksListener!(
           [
             {data: mask1, width: 2, height: 2},
            {data: mask2, width: 2, height: 2},
           ],
           1337);
     });
+    return new Promise<void>(resolve => {
+      // Invoke the image segmenter
+      interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => {
+        expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
+            .toHaveBeenCalled();
+        expect(result.categoryMask).not.toBeDefined();
+        expect(result.confidenceMasks).toEqual([mask1, mask2]);
+        expect(result.width).toEqual(2);
+        expect(result.height).toEqual(2);
+        resolve();
+      });
+    });
+  });
+
+  it('supports combined category and confidence masks', async () => {
+    const categoryMask = new Uint8ClampedArray([1, 0]);
+    const confidenceMask1 = new Float32Array([0.0, 1.0]);
+    const confidenceMask2 = new Float32Array([1.0, 0.0]);
+
+    await interactiveSegmenter.setOptions(
+        {outputCategoryMask: true, outputConfidenceMasks: true});
+
+    // Pass the test data to our listener
+    interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
+      expect(interactiveSegmenter.categoryMaskListener).toBeDefined();
+      expect(interactiveSegmenter.confidenceMasksListener).toBeDefined();
+      interactiveSegmenter.categoryMaskListener!
+          ({data: categoryMask, width: 1, height: 1}, 1337);
+      interactiveSegmenter.confidenceMasksListener!(
+          [
+            {data: confidenceMask1, width: 1, height: 1},
+            {data: confidenceMask2, width: 1, height: 1},
+          ],
+          1337);
+    });

     return new Promise<void>(resolve => {
       // Invoke the image segmenter
       interactiveSegmenter.segment(
-          {} as HTMLImageElement, ROI, (masks, width, height) => {
+          {} as HTMLImageElement, ROI, result => {
             expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle)
                 .toHaveBeenCalled();
-            expect(masks).toHaveSize(2);
-            expect(masks[0]).toEqual(mask1);
-            expect(masks[1]).toEqual(mask2);
-            expect(width).toEqual(2);
-            expect(height).toEqual(2);
+            expect(result.categoryMask).toEqual(categoryMask);
+            expect(result.confidenceMasks).toEqual([
+              confidenceMask1, confidenceMask2
+            ]);
+            expect(result.width).toEqual(1);
+            expect(result.height).toEqual(1);
             resolve();
           });
     });
   });

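
Note: these tests also show the general move from Jasmine's done callback to returning a Promise from the spec body. A generic sketch of that conversion, with runAsyncWork standing in for any callback-based API (it is hypothetical, not a MediaPipe function):

import 'jasmine';

declare function runAsyncWork(cb: (value: number) => void): void;  // hypothetical async API

// Before: completion signaled through the `done` callback.
it('finishes asynchronously (old style)', (done) => {
  runAsyncWork(value => {
    expect(value).toBe(42);
    done();
  });
});

// After: the spec returns a Promise that resolves when the callback fires.
it('finishes asynchronously (new style)', async () => {
  return new Promise<void>(resolve => {
    runAsyncWork(value => {
      expect(value).toBe(42);
      resolve();
    });
  });
});
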
@@ -56,8 +56,8 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y,
     VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0";
   }

-  *x_px = static_cast<int32>(round(normalized_x * image_width));
-  *y_px = static_cast<int32>(round(normalized_y * image_height));
+  *x_px = static_cast<int32_t>(round(normalized_x * image_width));
+  *y_px = static_cast<int32_t>(round(normalized_y * image_height));

   return true;
 }

@@ -43,7 +43,7 @@ ABSL_FLAG(std::string, system_cpu_max_freq_file,
 namespace mediapipe {
 namespace {

-constexpr uint32 kBufferLength = 64;
+constexpr uint32_t kBufferLength = 64;

 absl::StatusOr<std::string> GetFilePath(int cpu) {
   if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) {
@@ -54,7 +54,7 @@ absl::StatusOr<std::string> GetFilePath(int cpu) {
   return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu);
 }

-absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
+absl::StatusOr<uint64_t> GetCpuMaxFrequency(int cpu) {
   auto path_or_status = GetFilePath(cpu);
   if (!path_or_status.ok()) {
     return path_or_status.status();
@@ -65,7 +65,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
   char buffer[kBufferLength];
   file.getline(buffer, kBufferLength);
   file.close();
-  uint64 frequency;
+  uint64_t frequency;
   if (absl::SimpleAtoi(buffer, &frequency)) {
     return frequency;
   } else {
@@ -79,7 +79,7 @@ absl::StatusOr<uint64> GetCpuMaxFrequency(int cpu) {
 }

 std::set<int> InferLowerOrHigherCoreIds(bool lower) {
-  std::vector<std::pair<int, uint64>> cpu_freq_pairs;
+  std::vector<std::pair<int, uint64_t>> cpu_freq_pairs;
   for (int cpu = 0; cpu < NumCPUCores(); ++cpu) {
     auto freq_or_status = GetCpuMaxFrequency(cpu);
     if (freq_or_status.ok()) {
@@ -90,12 +90,12 @@ std::set<int> InferLowerOrHigherCoreIds(bool lower) {
     return {};
   }

-  absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64>& left,
-                                       const std::pair<int, uint64>& right) {
+  absl::c_sort(cpu_freq_pairs, [lower](const std::pair<int, uint64_t>& left,
+                                       const std::pair<int, uint64_t>& right) {
     return (lower && left.second < right.second) ||
            (!lower && left.second > right.second);
   });
-  uint64 edge_freq = cpu_freq_pairs[0].second;
+  uint64_t edge_freq = cpu_freq_pairs[0].second;

   std::set<int> inferred_cores;
   for (const auto& cpu_freq_pair : cpu_freq_pairs) {

@@ -89,12 +89,12 @@ void ImageFrameToYUVImage(const ImageFrame& image_frame, YUVImage* yuv_image) {
   const int uv_stride = (uv_width + 15) & ~15;
   const int y_size = y_stride * height;
   const int uv_size = uv_stride * uv_height;
-  uint8* data =
-      reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size * 2, 16));
+  uint8_t* data =
+      reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size * 2, 16));
   std::function<void()> deallocate = [data]() { aligned_free(data); };
-  uint8* y = data;
-  uint8* u = y + y_size;
-  uint8* v = u + uv_size;
+  uint8_t* y = data;
+  uint8_t* u = y + y_size;
+  uint8_t* v = u + uv_size;
   yuv_image->Initialize(libyuv::FOURCC_I420, deallocate,  //
                         y, y_stride,                      //
                         u, uv_stride,                     //
@@ -123,10 +123,11 @@ void ImageFrameToYUVNV12Image(const ImageFrame& image_frame,
   const int uv_stride = y_stride;
   const int uv_height = (height + 1) / 2;
   const int uv_size = uv_stride * uv_height;
-  uint8* data = reinterpret_cast<uint8*>(aligned_malloc(y_size + uv_size, 16));
+  uint8_t* data =
+      reinterpret_cast<uint8_t*>(aligned_malloc(y_size + uv_size, 16));
   std::function<void()> deallocate = [data] { aligned_free(data); };
-  uint8* y = data;
-  uint8* uv = y + y_size;
+  uint8_t* y = data;
+  uint8_t* uv = y + y_size;
   yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv,
                              uv_stride, nullptr, 0, width, height);
   const int rv = libyuv::I420ToNV12(
@@ -210,44 +211,44 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image,
   }
 }

-void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b,  //
-                     uint8* y, uint8* cb, uint8* cr) {
+void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b,  //
+                     uint8_t* y, uint8_t* cb, uint8_t* cr) {
   // ITU-R BT.601 conversion from sRGB to YCbCr.
   // FastIntRound is used rather than SafeRound since the possible
   // range of values is [16,235] for Y and [16,240] for Cb and Cr and we
   // don't care about the rounding direction for values exactly between
   // two integers.
-  *y = static_cast<uint8>(
+  *y = static_cast<uint8_t>(
       mediapipe::MathUtil::FastIntRound(16.0 +                 //
                                         65.481 * r / 255.0 +   //
                                         128.553 * g / 255.0 +  //
                                         24.966 * b / 255.0));
-  *cb = static_cast<uint8>(
+  *cb = static_cast<uint8_t>(
      mediapipe::MathUtil::FastIntRound(128.0 +                //
                                        -37.797 * r / 255.0 +  //
                                        -74.203 * g / 255.0 +  //
                                        112.0 * b / 255.0));
-  *cr = static_cast<uint8>(
+  *cr = static_cast<uint8_t>(
      mediapipe::MathUtil::FastIntRound(128.0 +                //
                                        112.0 * r / 255.0 +    //
                                        -93.786 * g / 255.0 +  //
                                        -18.214 * b / 255.0));
 }

-void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr,  //
-                     uint8* r, uint8* g, uint8* b) {
+void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr,  //
+                     uint8_t* r, uint8_t* g, uint8_t* b) {
   // ITU-R BT.601 conversion from YCbCr to sRGB
   // Use SafeRound since many MPEG YCbCr values do not correspond directly
   // to an sRGB value.
-  *r = mediapipe::MathUtil::SafeRound<uint8, double>(  //
+  *r = mediapipe::MathUtil::SafeRound<uint8_t, double>(  //
       255.0 / 219.0 * (y - 16.0) +                      //
       255.0 / 112.0 * 0.701 * (cr - 128.0));
-  *g = mediapipe::MathUtil::SafeRound<uint8, double>(
+  *g = mediapipe::MathUtil::SafeRound<uint8_t, double>(
       255.0 / 219.0 * (y - 16.0) -                            //
       255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) -  //
       255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0));
-  *b = mediapipe::MathUtil::SafeRound<uint8, double>(  //
+  *b = mediapipe::MathUtil::SafeRound<uint8_t, double>(  //
       255.0 / 219.0 * (y - 16.0) +                      //
       255.0 / 112.0 * 0.886 * (cb - 128.0));
 }
@@ -260,15 +261,15 @@ void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr,  //

 cv::Mat GetSrgbToLinearRgb16Lut() {
   cv::Mat lut(1, 256, CV_16UC1);
-  uint16* ptr = lut.ptr<uint16>();
+  uint16_t* ptr = lut.ptr<uint16_t>();
   constexpr double kUint8Max = 255.0;
   constexpr double kUint16Max = 65535.0;
   for (int i = 0; i < 256; ++i) {
     if (i < 0.04045 * kUint8Max) {
-      ptr[i] = static_cast<uint16>(
+      ptr[i] = static_cast<uint16_t>(
           (static_cast<double>(i) / kUint8Max / 12.92) * kUint16Max + .5);
     } else {
-      ptr[i] = static_cast<uint16>(
+      ptr[i] = static_cast<uint16_t>(
           pow((static_cast<double>(i) / kUint8Max + 0.055) / 1.055, 2.4) *
               kUint16Max +
           .5);
@@ -279,15 +280,15 @@ cv::Mat GetSrgbToLinearRgb16Lut() {

 cv::Mat GetLinearRgb16ToSrgbLut() {
   cv::Mat lut(1, 65536, CV_8UC1);
-  uint8* ptr = lut.ptr<uint8>();
+  uint8_t* ptr = lut.ptr<uint8_t>();
   constexpr double kUint8Max = 255.0;
   constexpr double kUint16Max = 65535.0;
   for (int i = 0; i < 65536; ++i) {
     if (i < 0.0031308 * kUint16Max) {
-      ptr[i] = static_cast<uint8>(
+      ptr[i] = static_cast<uint8_t>(
           (static_cast<double>(i) / kUint16Max * 12.92) * kUint8Max + .5);
     } else {
-      ptr[i] = static_cast<uint8>(
+      ptr[i] = static_cast<uint8_t>(
           (1.055 * pow(static_cast<double>(i) / kUint16Max, 1.0 / 2.4) - .055) *
               kUint8Max +
           .5);
@@ -306,13 +307,13 @@ void LinearRgb16ToSrgb(const cv::Mat& source, cv::Mat* destination) {
   destination->create(source.size(), CV_8UC(source.channels()));

   static const cv::Mat kLut = GetLinearRgb16ToSrgbLut();
-  const uint8* lookup_table_ptr = kLut.ptr<uint8>();
+  const uint8_t* lookup_table_ptr = kLut.ptr<uint8_t>();
   const int num_channels = source.channels();
   for (int row = 0; row < source.rows; ++row) {
     for (int col = 0; col < source.cols; ++col) {
       for (int channel = 0; channel < num_channels; ++channel) {
-        uint8* ptr = destination->ptr<uint8>(row);
-        const uint16* ptr16 = source.ptr<uint16>(row);
+        uint8_t* ptr = destination->ptr<uint8_t>(row);
+        const uint16_t* ptr16 = source.ptr<uint16_t>(row);
         ptr[col * num_channels + channel] =
             lookup_table_ptr[ptr16[col * num_channels + channel]];
       }

@@ -43,14 +43,14 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) {

 Packet MakeImageFramePacket(cv::Mat input, int timestamp) {
   ImageFrame input_image(GetImageFormat(input.channels()), input.cols,
-                         input.rows, input.step, input.data, [](uint8*) {});
+                         input.rows, input.step, input.data, [](uint8_t*) {});
   return MakePacket<ImageFrame>(std::move(input_image)).At(Timestamp(0));
 }

 Packet MakeImagePacket(cv::Mat input, int timestamp) {
   mediapipe::Image input_image(std::make_shared<mediapipe::ImageFrame>(
       GetImageFormat(input.channels()), input.cols, input.rows, input.step,
-      input.data, [](uint8*) {}));
+      input.data, [](uint8_t*) {}));
   return MakePacket<mediapipe::Image>(std::move(input_image)).At(Timestamp(0));
 }

@@ -25,7 +25,7 @@

 namespace mediapipe {

-absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
+absl::StatusOr<proto_ns::Map<int64_t, LabelMapItem>> BuildLabelMapFromFiles(
     absl::string_view labels_file_contents,
     absl::string_view display_names_file) {
   if (labels_file_contents.empty()) {
@@ -68,7 +68,7 @@ absl::StatusOr<proto_ns::Map<int64, LabelMapItem>> BuildLabelMapFromFiles(
       label_map_items[i].set_display_name(display_names[i]);
     }
   }
-  proto_ns::Map<int64, LabelMapItem> label_map;
+  proto_ns::Map<int64_t, LabelMapItem> label_map;
   for (int i = 0; i < label_map_items.size(); ++i) {
     label_map[i] = label_map_items[i];
   }

third_party/flatbuffers/BUILD.bazel (vendored, 4 changes)

@@ -45,12 +45,16 @@ filegroup(
     "include/flatbuffers/bfbs_generator.h",
     "include/flatbuffers/buffer.h",
     "include/flatbuffers/buffer_ref.h",
+    "include/flatbuffers/code_generator.h",
     "include/flatbuffers/code_generators.h",
     "include/flatbuffers/default_allocator.h",
     "include/flatbuffers/detached_buffer.h",
     "include/flatbuffers/flatbuffer_builder.h",
     "include/flatbuffers/flatbuffers.h",
+    "include/flatbuffers/flatc.h",
+    "include/flatbuffers/flex_flat_util.h",
     "include/flatbuffers/flexbuffers.h",
+    "include/flatbuffers/grpc.h",
     "include/flatbuffers/hash.h",
     "include/flatbuffers/idl.h",
     "include/flatbuffers/minireflect.h",

third_party/flatbuffers/workspace.bzl (vendored, 8 changes)

@@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive")
 def repo():
     third_party_http_archive(
         name = "flatbuffers",
-        strip_prefix = "flatbuffers-2.0.6",
-        sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9",
+        strip_prefix = "flatbuffers-23.1.21",
+        sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238",
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v2.0.6.tar.gz",
-            "https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
+            "https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz",
         ],
         build_file = "//third_party/flatbuffers:BUILD.bazel",
         delete = ["build_defs.bzl", "BUILD.bazel"],

third_party/wasm_files.bzl (vendored, 48 changes)

@@ -12,72 +12,72 @@ def wasm_files():

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_js",
-        sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"],
+        sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm",
-        sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"],
+        sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js",
-        sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"],
+        sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm",
-        sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"],
+        sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_internal_js",
-        sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"],
+        sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_internal_wasm",
-        sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"],
+        sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js",
-        sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"],
+        sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm",
-        sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"],
+        sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_internal_js",
-        sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"],
+        sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm",
-        sha256 = "03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"],
+        sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js",
-        sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"],
+        sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"],
     )

     http_file(
         name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm",
-        sha256 = "103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592",
-        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"],
+        sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f",
+        urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"],
     )
