diff --git a/WORKSPACE b/WORKSPACE
index 6e079f142..760898185 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -239,6 +239,16 @@ http_archive(
     repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"},
 )
 
+http_archive(
+    name = "darts_clone",
+    build_file = "@//third_party:darts_clone.BUILD",
+    sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
+    strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
+    urls = [
+        "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
+    ],
+)
+
 http_archive(
     name = "org_tensorflow_text",
     sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8",
diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.cc b/mediapipe/calculators/core/constant_side_packet_calculator.cc
index 45ff07110..509f7e9dd 100644
--- a/mediapipe/calculators/core/constant_side_packet_calculator.cc
+++ b/mediapipe/calculators/core/constant_side_packet_calculator.cc
@@ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
     } else if (packet_options.has_string_value()) {
       packet.Set<std::string>();
     } else if (packet_options.has_uint64_value()) {
-      packet.Set<uint64>();
+      packet.Set<uint64_t>();
     } else if (packet_options.has_classification_list_value()) {
       packet.Set<ClassificationList>();
     } else if (packet_options.has_landmark_list_value()) {
@@ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase {
     } else if (packet_options.has_string_value()) {
       packet.Set(MakePacket<std::string>(packet_options.string_value()));
     } else if (packet_options.has_uint64_value()) {
-      packet.Set(MakePacket<uint64>(packet_options.uint64_value()));
+      packet.Set(MakePacket<uint64_t>(packet_options.uint64_value()));
     } else if (packet_options.has_classification_list_value()) {
       packet.Set(MakePacket<ClassificationList>(
           packet_options.classification_list_value()));
diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc
index c734ddb5f..192019820 100644
--- a/mediapipe/calculators/core/gate_calculator_test.cc
+++ b/mediapipe/calculators/core/gate_calculator_test.cc
@@ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test {
   }
 
   // Use this when ALLOW/DISALLOW input is provided as a side packet.
-  void RunTimeStep(int64 timestamp, bool stream_payload) {
+  void RunTimeStep(int64_t timestamp, bool stream_payload) {
     runner_->MutableInputs()->Get("", 0).packets.push_back(
        MakePacket<bool>(stream_payload).At(Timestamp(timestamp)));
     MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed.";
   }
 
   // Use this when ALLOW/DISALLOW input is provided as an input stream.
- void RunTimeStep(int64 timestamp, const std::string& control_tag, + void RunTimeStep(int64_t timestamp, const std::string& control_tag, bool control) { runner_->MutableInputs()->Get("", 0).packets.push_back( MakePacket(true).At(Timestamp(timestamp))); @@ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) { } )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) { } )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) { output_stream: "test_output" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { )"); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) { )"); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) { )"); runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) { )"); runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true)); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) { output_stream: "test_output" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; 
RunTimeStep(kTimestampValue0, "ALLOW", true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "ALLOW", false); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "ALLOW", true); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "ALLOW", false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) { output_stream: "test_output" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "DISALLOW", true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "DISALLOW", false); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "DISALLOW", true); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "DISALLOW", false); const std::vector& output = runner()->Outputs().Get("", 0).packets; @@ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "ALLOW", false); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "ALLOW", true); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "ALLOW", true); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "ALLOW", false); const std::vector& output = @@ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "DISALLOW", true); - constexpr int64 kTimestampValue1 = 43; + constexpr int64_t kTimestampValue1 = 43; RunTimeStep(kTimestampValue1, "DISALLOW", false); - constexpr int64 kTimestampValue2 = 44; + constexpr int64_t kTimestampValue2 = 44; RunTimeStep(kTimestampValue2, "DISALLOW", false); - constexpr int64 kTimestampValue3 = 45; + constexpr int64_t kTimestampValue3 = 45; RunTimeStep(kTimestampValue3, "DISALLOW", true); const std::vector& output = @@ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "DISALLOW", false); const std::vector& output = @@ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) { output_stream: "STATE_CHANGE:state_changed" )"); - constexpr int64 kTimestampValue0 = 42; + constexpr int64_t kTimestampValue0 = 42; RunTimeStep(kTimestampValue0, "ALLOW", true); const std::vector& output = diff --git a/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc b/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc index 1f994cbed..8b4254cbc 100644 --- a/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc +++ b/mediapipe/calculators/core/matrix_to_vector_calculator_test.cc @@ -35,7 +35,7 @@ class MatrixToVectorCalculatorTest void SetUp() override { calculator_name_ = 
"MatrixToVectorCalculator"; } void AppendInput(const std::vector& column_major_data, - int64 timestamp) { + int64_t timestamp) { ASSERT_EQ(num_input_samples_ * num_input_channels_, column_major_data.size()); Eigen::Map data_map(&column_major_data[0], diff --git a/mediapipe/calculators/core/packet_resampler_calculator_test.cc b/mediapipe/calculators/core/packet_resampler_calculator_test.cc index f02da0d18..d80793da4 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator_test.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator_test.cc @@ -51,9 +51,9 @@ class SimpleRunner : public CalculatorRunner { virtual ~SimpleRunner() {} - void SetInput(const std::vector& timestamp_list) { + void SetInput(const std::vector& timestamp_list) { MutableInputs()->Index(0).packets.clear(); - for (const int64 ts : timestamp_list) { + for (const int64_t ts : timestamp_list) { MutableInputs()->Index(0).packets.push_back( Adopt(new std::string(absl::StrCat("Frame #", ts))) .At(Timestamp(ts))); @@ -72,8 +72,8 @@ class SimpleRunner : public CalculatorRunner { } void CheckOutputTimestamps( - const std::vector& expected_frames, - const std::vector& expected_timestamps) const { + const std::vector& expected_frames, + const std::vector& expected_timestamps) const { EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size()); EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size()); int count = 0; @@ -112,7 +112,7 @@ MATCHER_P2(PacketAtTimestamp, payload, timestamp, *result_listener << "at incorrect timestamp = " << arg.Timestamp().Value(); return false; } - int64 actual_payload = arg.template Get(); + int64_t actual_payload = arg.template Get(); if (actual_payload != payload) { *result_listener << "with incorrect payload = " << actual_payload; return false; @@ -137,18 +137,18 @@ class ReproducibleJitterWithReflectionStrategyForTesting // // An EXPECT will fail if sequence is less than the number requested during // processing. - static std::vector random_sequence; + static std::vector random_sequence; protected: - virtual uint64 GetNextRandom(uint64 n) { + virtual uint64_t GetNextRandom(uint64_t n) { EXPECT_LT(sequence_index_, random_sequence.size()); return random_sequence[sequence_index_++] % n; } private: - int32 sequence_index_ = 0; + int32_t sequence_index_ = 0; }; -std::vector +std::vector ReproducibleJitterWithReflectionStrategyForTesting::random_sequence; // PacketResamplerCalculator child class which injects a specified stream @@ -469,7 +469,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) { } )pb")); - for (const int64 ts : {0, 5000, 10010, 15001, 19990}) { + for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) { runner.MutableInputs()->Tag(kDataTag).packets.push_back( Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); } diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc index 563417669..d8c358909 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -43,8 +43,8 @@ constexpr char kDisallowTag[] = "DISALLOW"; // Returns the timestamp values for a vector of Packets. // TODO: puth this kind of test util in a common place. 
-std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
-  std::vector<int64> result;
+std::vector<int64_t> TimestampValues(const std::vector<Packet>& packets) {
+  std::vector<int64_t> result;
   for (const Packet& packet : packets) {
     result.push_back(packet.Timestamp().Value());
   }
@@ -371,7 +371,7 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) {
   for (int main_ts = 0; main_ts < 50; ++main_ts) {
     send_packet("in", main_ts);
     MP_EXPECT_OK(graph_.WaitUntilIdle());
-    std::vector<int64> ts_values = TimestampValues(outputs);
+    std::vector<int64_t> ts_values = TimestampValues(outputs);
     EXPECT_EQ(ts_values.size(), main_ts + 1);
     for (int j = 0; j < main_ts + 1; ++j) {
       EXPECT_EQ(ts_values[j], j);
diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc
index ed89889df..311f7d815 100644
--- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc
+++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc
@@ -121,7 +121,7 @@ absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) {
   if (cc->Outputs().HasTag(kTagAtTimestamp)) {
     RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries())
         << "For AT_TIMESTAMP tag, 2 input side packets are required.";
-    cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>();
+    cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>();
   } else {
     RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries())
         << "Same number of input side packets and output streams is required.";
@@ -178,8 +178,8 @@ absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) {
           .AddPacket(cc->InputSidePackets().Index(i).At(timestamp));
     }
   } else if (cc->Outputs().HasTag(kTagAtTimestamp)) {
-    int64 timestamp =
-        cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>();
+    int64_t timestamp =
+        cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64_t>();
     for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) {
       cc->Outputs()
           .Get(output_tag_, i)
diff --git a/mediapipe/calculators/core/string_to_int_calculator.cc b/mediapipe/calculators/core/string_to_int_calculator.cc
index ecd55afb6..fa67aa8e5 100644
--- a/mediapipe/calculators/core/string_to_int_calculator.cc
+++ b/mediapipe/calculators/core/string_to_int_calculator.cc
@@ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator);
 using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>;
 REGISTER_CALCULATOR(StringToUintCalculator);
 
-using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>;
+using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>;
 REGISTER_CALCULATOR(StringToInt32Calculator);
 
-using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>;
+using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>;
 REGISTER_CALCULATOR(StringToUint32Calculator);
 
-using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>;
+using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>;
 REGISTER_CALCULATOR(StringToInt64Calculator);
 
-using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>;
+using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>;
 REGISTER_CALCULATOR(StringToUint64Calculator);
 
 }  // namespace mediapipe
diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc
index 388701773..dcc371036 100644
--- a/mediapipe/calculators/image/warp_affine_calculator.cc
+++ b/mediapipe/calculators/image/warp_affine_calculator.cc
@@ -166,7 +166,7 @@ class WarpAffineRunnerHolder {
       const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
frame_ptr->Height(), frame_ptr->WidthStep(), const_cast(frame_ptr->PixelData()), - [](uint8* data){}); + [](uint8_t* data){}); ASSIGN_OR_RETURN(auto result, runner->Run(image_frame, matrix, size, border_mode)); return mediapipe::Image(std::make_shared(std::move(result))); diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index a574e11e1..5d52dda0f 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -401,8 +401,8 @@ cc_library_with_tflite( hdrs = ["inference_calculator.h"], tflite_deps = [ "//mediapipe/util/tflite:tflite_model_loader", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], deps = [ ":inference_calculator_cc_proto", @@ -506,7 +506,7 @@ cc_library_with_tflite( name = "tflite_delegate_ptr", hdrs = ["tflite_delegate_ptr.h"], tflite_deps = [ - "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", + "@org_tensorflow//tensorflow/lite/c:c_api_types", ], ) @@ -517,8 +517,8 @@ cc_library_with_tflite( tflite_deps = [ ":tflite_delegate_ptr", "//mediapipe/util/tflite:tflite_model_loader", - "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/c:c_api_types", ], deps = [ ":inference_runner", @@ -546,8 +546,8 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/c:c_api_types", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", ] + select({ "//conditions:default": [], diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index c53c6e3d5..2a6936eba 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) { return kSideInCustomOpResolver(cc).As(); } return PacketAdopting( - std::make_unique()); + std::make_unique< + tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>()); } } // namespace api2 diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h index b73f42053..d7e5e98cf 100644 --- a/mediapipe/calculators/tensor/inference_calculator.h +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -26,7 +26,7 @@ #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/util/tflite/tflite_model_loader.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/kernels/register.h" namespace mediapipe { namespace api2 { @@ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf { // Deprecated. Prefers to use "OP_RESOLVER" input side packet instead. // TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the // migration. 
- static constexpr SideInput:: - Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; + static constexpr SideInput::Optional + kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; static constexpr SideInput::Optional kSideInOpResolver{ "OP_RESOLVER"}; static constexpr SideInput::Optional kSideInModel{"MODEL"}; diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index 80d36ba68..0c22fc7a1 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -24,7 +24,7 @@ #include "mediapipe/calculators/tensor/inference_calculator_utils.h" #include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" #include "mediapipe/calculators/tensor/inference_runner.h" -#include "tensorflow/lite/core/shims/cc/interpreter.h" +#include "tensorflow/lite/interpreter.h" #if defined(MEDIAPIPE_ANDROID) #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #endif // ANDROID diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc index 26bf3d8f8..a2b8a9285 100644 --- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.cc @@ -22,9 +22,9 @@ #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/port/ret_check.h" -#include "tensorflow/lite/core/shims/c/c_api_types.h" -#include "tensorflow/lite/core/shims/cc/interpreter.h" -#include "tensorflow/lite/core/shims/cc/interpreter_builder.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" #include "tensorflow/lite/string_util.h" #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe @@ -33,8 +33,8 @@ namespace mediapipe { namespace { -using Interpreter = ::tflite_shims::Interpreter; -using InterpreterBuilder = ::tflite_shims::InterpreterBuilder; +using Interpreter = ::tflite::Interpreter; +using InterpreterBuilder = ::tflite::InterpreterBuilder; template void CopyTensorBufferToInterpreter(const Tensor& input_tensor, diff --git a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h index d2035c994..ca6d79851 100644 --- a/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h +++ b/mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h @@ -23,8 +23,8 @@ #include "mediapipe/calculators/tensor/tflite_delegate_ptr.h" #include "mediapipe/framework/api2/packet.h" #include "mediapipe/util/tflite/tflite_model_loader.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/c/c_api_types.h" namespace mediapipe { diff --git a/mediapipe/calculators/tensor/tflite_delegate_ptr.h b/mediapipe/calculators/tensor/tflite_delegate_ptr.h index caaa19c2f..afaf9f515 100644 --- a/mediapipe/calculators/tensor/tflite_delegate_ptr.h +++ b/mediapipe/calculators/tensor/tflite_delegate_ptr.h @@ -18,7 +18,7 @@ #include #include -#include "tensorflow/lite/core/shims/c/c_api_types.h" +#include "tensorflow/lite/c/c_api_types.h" namespace mediapipe { diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index 
4a47b7d7f..2608b1c5b 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -61,12 +61,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE"; // overload GPU/TPU/... class SimpleSemaphore { public: - explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {} + explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {} SimpleSemaphore(const SimpleSemaphore&) = delete; SimpleSemaphore(SimpleSemaphore&&) = delete; // Acquires the semaphore by certain amount. - void Acquire(uint32 amount) { + void Acquire(uint32_t amount) { mutex_.Lock(); while (count_ < amount) { cond_.Wait(&mutex_); @@ -76,7 +76,7 @@ class SimpleSemaphore { } // Releases the semaphore by certain amount. - void Release(uint32 amount) { + void Release(uint32_t amount) { mutex_.Lock(); count_ += amount; cond_.SignalAll(); @@ -84,7 +84,7 @@ class SimpleSemaphore { } private: - uint32 count_; + uint32_t count_; absl::Mutex mutex_; absl::CondVar cond_; }; @@ -488,7 +488,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { // necessary. absl::Status OutputBatch(CalculatorContext* cc, std::unique_ptr inference_state) { - const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); + const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow()); std::vector> input_tensors; for (auto& keyed_tensors : inference_state->input_tensor_batches_) { @@ -544,7 +544,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { get_session_run_throttle(options_.max_concurrent_session_runs()); session_run_throttle->Acquire(1); } - const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); + const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow()); tf::Status tf_status; { #if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) @@ -562,7 +562,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { // informative error message. RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); - const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); + const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow()); cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) ->IncrementBy(run_end_time - run_start_time); cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); @@ -611,7 +611,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { } // Get end time and report. - const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); + const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow()); cc->GetCounter(kTotalUsecsCounterSuffix) ->IncrementBy(end_time - start_time); cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) @@ -650,7 +650,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { // The static singleton semaphore to throttle concurrent session runs. 
static SimpleSemaphore* get_session_run_throttle( - int32 max_concurrent_session_runs) { + int32_t max_concurrent_session_runs) { static SimpleSemaphore* session_run_throttle = new SimpleSemaphore(max_concurrent_session_runs); return session_run_throttle; diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 0d1d4ca26..a14c6bd95 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -197,15 +197,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { // timestamp and the associated feature. This information is used in process // to output batches of packets in order. timestamps_.clear(); - int64 last_timestamp_seen = Timestamp::PreStream().Value(); + int64_t last_timestamp_seen = Timestamp::PreStream().Value(); first_timestamp_seen_ = Timestamp::OneOverPostStream().Value(); for (const auto& map_kv : sequence_->feature_lists().feature_list()) { if (absl::StrContains(map_kv.first, "/timestamp")) { LOG(INFO) << "Found feature timestamps: " << map_kv.first << " with size: " << map_kv.second.feature_size(); - int64 recent_timestamp = Timestamp::PreStream().Value(); + int64_t recent_timestamp = Timestamp::PreStream().Value(); for (int i = 0; i < map_kv.second.feature_size(); ++i) { - int64 next_timestamp = + int64_t next_timestamp = mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0); RET_CHECK_GT(next_timestamp, recent_timestamp) << "Timestamps must be sequential. If you're seeing this message " @@ -361,8 +361,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { // any particular call to Process(). At the every end, we output the // poststream packets. If we only have poststream packets, // last_timestamp_key_ will be empty. - int64 start_timestamp = 0; - int64 end_timestamp = 0; + int64_t start_timestamp = 0; + int64_t end_timestamp = 0; if (last_timestamp_key_.empty() || process_poststream_) { process_poststream_ = true; start_timestamp = Timestamp::PostStream().Value(); @@ -481,14 +481,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { // Store a map from the keys for each stream to the timestamps for each // key. This allows us to identify which packets to output for each stream // for timestamps within a given time window. - std::map> timestamps_; + std::map> timestamps_; // Store the stream with the latest timestamp in the SequenceExample. std::string last_timestamp_key_; // Store the index of the current timestamp. Will be less than // timestamps_[last_timestamp_key_].size(). int current_timestamp_index_; // Store the very first timestamp, so we output everything on the first frame. - int64 first_timestamp_seen_; + int64_t first_timestamp_seen_; // List of keypoint names. std::vector keypoint_names_; // Default keypoint location when missing. 
diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc index aadce3615..a4f98d2e9 100644 --- a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator_test.cc @@ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test { } } - const int64 time = 1234; + const int64_t time = 1234; runner_->MutableInputs()->Index(0).packets.push_back( Adopt(input.release()).At(Timestamp(time))); @@ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) { // 2^i can be represented exactly in floating point numbers if 'i' is small. input->at(i) = static_cast(1 << i); } - const int64 time = 1234; + const int64_t time = 1234; runner_->MutableInputs()->Index(0).packets.push_back( Adopt(input.release()).At(Timestamp(time))); diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc index 238bcf8be..c3c920bcf 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc @@ -28,11 +28,8 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -using mediapipe::Adopt; -using mediapipe::CalculatorBase; using mediapipe::ImageFrame; using mediapipe::PacketTypeSet; -using mediapipe::autoflip::Border; constexpr char kDetectedBorders[] = "DETECTED_BORDERS"; constexpr int kMinBorderDistance = 5; diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc index e72d54e55..431e3d161 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator_test.cc @@ -28,16 +28,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" -using mediapipe::Adopt; using mediapipe::CalculatorGraphConfig; using mediapipe::CalculatorRunner; using mediapipe::ImageFormat; using mediapipe::ImageFrame; using mediapipe::Packet; using mediapipe::PacketTypeSet; -using mediapipe::ParseTextProtoOrDie; -using mediapipe::Timestamp; -using mediapipe::autoflip::Border; namespace mediapipe { namespace autoflip { diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc index 9ea79ba44..a45a171a4 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc @@ -31,14 +31,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_matchers.h" -using mediapipe::Adopt; using mediapipe::CalculatorGraphConfig; using mediapipe::CalculatorRunner; using mediapipe::ImageFormat; using mediapipe::ImageFrame; using mediapipe::PacketTypeSet; -using mediapipe::ParseTextProtoOrDie; -using mediapipe::Timestamp; namespace mediapipe { namespace autoflip { diff --git a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc 
index 85b2d96f8..ba5064a28 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc @@ -28,8 +28,6 @@ using mediapipe::Packet; using mediapipe::PacketTypeSet; using mediapipe::autoflip::DetectionSet; -using mediapipe::autoflip::SalientRegion; -using mediapipe::autoflip::SignalType; constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY"; constexpr char kSignalInputsTag[] = "SIGNAL"; diff --git a/mediapipe/framework/api2/node_test.cc b/mediapipe/framework/api2/node_test.cc index a6c1ef7c6..152cbb0e2 100644 --- a/mediapipe/framework/api2/node_test.cc +++ b/mediapipe/framework/api2/node_test.cc @@ -19,8 +19,6 @@ namespace mediapipe { namespace api2 { namespace test { -using testing::ElementsAre; - // Returns the packet values for a vector of Packets. template std::vector PacketValues(const std::vector& packets) { diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index b59467b9f..8a6d079e3 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -310,7 +310,7 @@ class Scheduler { absl::Mutex state_mutex_; // Current state of the scheduler. - std::atomic state_ = ATOMIC_VAR_INIT(STATE_NOT_STARTED); + std::atomic state_ = STATE_NOT_STARTED; // True if all graph input streams are closed. bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false; diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc index d53acedc9..fd51a7383 100644 --- a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler.cc @@ -131,9 +131,9 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler { ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) { // Record the most recent first kept timestamp on any stream. for (const auto& stream : input_stream_managers_) { - int32 queue_size = (stream->QueueSize() >= trigger_queue_size_) - ? target_queue_size_ - : trigger_queue_size_ - 1; + int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_) + ? target_queue_size_ + : trigger_queue_size_ - 1; if (stream->QueueSize() > queue_size) { kept_timestamp_ = std::max( kept_timestamp_, stream->GetMinTimestampAmongNLatest(queue_size + 1) @@ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler { } private: - int32 trigger_queue_size_; - int32 target_queue_size_; + int32_t trigger_queue_size_; + int32_t target_queue_size_; bool fixed_min_size_; // Indicates that GetNodeReadiness has returned kReadyForProcess once, and // the corresponding call to FillInputSet has not yet completed. diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc index 4f1367a9a..186d59dfe 100644 --- a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc @@ -30,15 +30,15 @@ namespace mediapipe { namespace { -const int64 kMaxPacketId = 100; -const int64 kSlowCalculatorRate = 10; +const int64_t kMaxPacketId = 100; +const int64_t kSlowCalculatorRate = 10; // Rate limiter for TestSlowCalculator. 
 ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit);
-int64 g_source_counter ABSL_GUARDED_BY(g_source_mutex);
+int64_t g_source_counter ABSL_GUARDED_BY(g_source_mutex);
 
 // Rate limiter for TestSourceCalculator.
-int64 g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
+int64_t g_slow_counter ABSL_GUARDED_BY(g_source_mutex);
 
 // Flag that indicates that the source is done.
 bool g_source_done ABSL_GUARDED_BY(g_source_mutex);
@@ -47,7 +47,7 @@ class TestSourceCalculator : public CalculatorBase {
  public:
   TestSourceCalculator() : current_packet_id_(0) {}
   static absl::Status GetContract(CalculatorContract* cc) {
-    cc->Outputs().Index(0).Set<int64>();
+    cc->Outputs().Index(0).Set<int64_t>();
     return absl::OkStatus();
   }
   absl::Status Open(CalculatorContext* cc) override {
@@ -62,7 +62,7 @@
       g_source_done = true;
       return tool::StatusStop();
     }
-    cc->Outputs().Index(0).Add(new int64(0), Timestamp(current_packet_id_));
+    cc->Outputs().Index(0).Add(new int64_t(0), Timestamp(current_packet_id_));
     ++current_packet_id_;
     {
       absl::MutexLock lock(&g_source_mutex);
@@ -78,7 +78,7 @@
     return g_source_counter <= kSlowCalculatorRate * g_slow_counter ||
            g_source_counter <= 1;
   }
-  int64 current_packet_id_;
+  int64_t current_packet_id_;
 };
 
 REGISTER_CALCULATOR(TestSourceCalculator);
@@ -87,8 +87,8 @@ class TestSlowCalculator : public CalculatorBase {
  public:
   TestSlowCalculator() = default;
   static absl::Status GetContract(CalculatorContract* cc) {
-    cc->Inputs().Index(0).Set<int64>();
-    cc->Outputs().Index(0).Set<int64>();
+    cc->Inputs().Index(0).Set<int64_t>();
+    cc->Outputs().Index(0).Set<int64_t>();
     return absl::OkStatus();
   }
   absl::Status Open(CalculatorContext* cc) override {
@@ -97,7 +97,7 @@
     return absl::OkStatus();
   }
   absl::Status Process(CalculatorContext* cc) override {
-    cc->Outputs().Index(0).Add(new int64(0),
+    cc->Outputs().Index(0).Add(new int64_t(0),
                                cc->Inputs().Index(0).Value().Timestamp());
     {
       absl::MutexLock lock(&g_source_mutex);
@@ -118,8 +118,9 @@
 REGISTER_CALCULATOR(TestSlowCalculator);
 
 // Return the values of the timestamps of a vector of Packets.
-static std::vector<int64> TimestampValues(const std::vector<Packet>& packets) {
-  std::vector<int64> result;
+static std::vector<int64_t> TimestampValues(
+    const std::vector<Packet>& packets) {
+  std::vector<int64_t> result;
   for (const Packet& p : packets) {
     result.push_back(p.Timestamp().Value());
   }
@@ -174,7 +175,7 @@ TEST_P(FixedSizeInputStreamHandlerTest, DropsPackets) {
   // consumed. In this way, the TestSlowCalculator consumes and outputs only
   // every tenth packet.
   EXPECT_EQ(output_packets.size(), 11);
-  std::vector<int64> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
+  std::vector<int64_t> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99};
   EXPECT_THAT(TimestampValues(output_packets),
               testing::ContainerEq(expected_ts));
 }
@@ -344,18 +345,18 @@ TEST_P(FixedSizeInputStreamHandlerTest, LateArrivalDrop) {
 
   if (GetParam()) {
     EXPECT_THAT(TimestampValues(output_packets[0]),
-                testing::ContainerEq(std::vector<int64>{1, 2, 3, 4, 5, 6}));
+                testing::ContainerEq(std::vector<int64_t>{1, 2, 3, 4, 5, 6}));
     EXPECT_THAT(TimestampValues(output_packets[1]),
-                testing::ContainerEq(std::vector<int64>{3, 4, 5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{3, 4, 5, 6, 7}));
     EXPECT_THAT(TimestampValues(output_packets[2]),
-                testing::ContainerEq(std::vector<int64>{4, 5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{4, 5, 6, 7}));
   } else {
     EXPECT_THAT(TimestampValues(output_packets[0]),
-                testing::ContainerEq(std::vector<int64>{5, 6}));
+                testing::ContainerEq(std::vector<int64_t>{5, 6}));
     EXPECT_THAT(TimestampValues(output_packets[1]),
-                testing::ContainerEq(std::vector<int64>{5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
     EXPECT_THAT(TimestampValues(output_packets[2]),
-                testing::ContainerEq(std::vector<int64>{5, 6, 7}));
+                testing::ContainerEq(std::vector<int64_t>{5, 6, 7}));
   }
 }
 
diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc
index 308932d4f..248028c25 100644
--- a/mediapipe/framework/tool/options_field_util.cc
+++ b/mediapipe/framework/tool/options_field_util.cc
@@ -27,10 +27,6 @@ namespace options_field_util {
 
 using ::mediapipe::proto_ns::internal::WireFormatLite;
 using FieldType = WireFormatLite::FieldType;
-using ::mediapipe::proto_ns::io::ArrayInputStream;
-using ::mediapipe::proto_ns::io::CodedInputStream;
-using ::mediapipe::proto_ns::io::CodedOutputStream;
-using ::mediapipe::proto_ns::io::StringOutputStream;
 
 // Utility functions for OptionsFieldUtil.
 namespace {
diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h
index 4f2390404..fba0267a8 100644
--- a/mediapipe/gpu/gl_context.h
+++ b/mediapipe/gpu/gl_context.h
@@ -454,8 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
   // Number of glFinish calls completed on the GL thread.
   // Changes should be guarded by mutex_. However, we use simple atomic
   // loads for efficiency on the fast path.
-  std::atomic<int64_t> gl_finish_count_ = ATOMIC_VAR_INIT(0);
-  std::atomic<int64_t> gl_finish_count_target_ = ATOMIC_VAR_INIT(0);
+  std::atomic<int64_t> gl_finish_count_ = 0;
+  std::atomic<int64_t> gl_finish_count_target_ = 0;
 
   GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr;
 
diff --git a/mediapipe/gpu/gl_context_webgl.cc b/mediapipe/gpu/gl_context_webgl.cc
index b1f5295c9..25cbed83d 100644
--- a/mediapipe/gpu/gl_context_webgl.cc
+++ b/mediapipe/gpu/gl_context_webgl.cc
@@ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal(
   // TODO: Investigate this option in more detail, esp. on Safari.
   attrs.preserveDrawingBuffer = 0;
 
-  // Since the Emscripten canvas target finding function is visible from here,
-  // we hijack findCanvasEventTarget directly for enforcing old Module.canvas
-  // behavior if the user desires, falling back to the new DOM element CSS
-  // selector behavior next if that is specified, and finally just allowing the
-  // lookup to proceed on a null target.
- // TODO: Ensure this works with all options (in particular, - // multithreading options, like the special-case combination of USE_PTHREADS - // and OFFSCREEN_FRAMEBUFFER) - // clang-format off - EM_ASM( - let init_once = true; - if (init_once) { - const cachedFindCanvasEventTarget = findCanvasEventTarget; - - if (typeof cachedFindCanvasEventTarget !== 'function') { - if (typeof console !== 'undefined') { - console.error('Expected Emscripten global function ' - + '"findCanvasEventTarget" not found. WebGL context creation ' - + 'may fail.'); - } - return; - } - - findCanvasEventTarget = function(target) { - if (target == 0) { - if (Module && Module.canvas) { - return Module.canvas; - } else if (Module && Module.canvasCssSelector) { - return cachedFindCanvasEventTarget(Module.canvasCssSelector); - } - if (typeof console !== 'undefined') { - console.warn('Module properties canvas and canvasCssSelector not ' + - 'found during WebGL context creation.'); - } - } - // We still go through with the find attempt, although for most use - // cases it will not succeed, just in case the user does want to fall- - // back. - return cachedFindCanvasEventTarget(target); - }; // NOLINT: Necessary semicolon. - init_once = false; - } - ); - // clang-format on - + // Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also + // looks for our #canvas target in Module.canvas, where we expect it to be. + // -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new + // event target behavior, but it was never supposed to be tapping into our + // canvas anyways. See b/278155946 for more background. + EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; }); EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle = - emscripten_webgl_create_context(nullptr, &attrs); + emscripten_webgl_create_context("#canvas", &attrs); // Check for failure if (context_handle <= 0) { diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index 69b9889c7..f1497f741 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -64,7 +64,7 @@ std::unique_ptr GlTextureBuffer::Create( int actual_ws = image_frame.WidthStep(); int alignment = 0; std::unique_ptr temp; - const uint8* data = image_frame.PixelData(); + const uint8_t* data = image_frame.PixelData(); // Let's see if the pixel data is tightly aligned to one of the alignments // supported by OpenGL, preferring 4 if possible since it's the default. 
diff --git a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc index c7acd1340..4b0913b96 100644 --- a/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc +++ b/mediapipe/gpu/gpu_buffer_storage_yuv_image.cc @@ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, GpuBufferFormat format) { libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format); int y_stride = std::ceil(1.0f * width / kDefaultDataAligment); - auto y_data = std::make_unique(y_stride * height); + auto y_data = std::make_unique(y_stride * height); switch (fourcc) { case libyuv::FOURCC_NV12: case libyuv::FOURCC_NV21: { @@ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, int uv_width = 2 * std::ceil(0.5f * width); int uv_height = std::ceil(0.5f * height); int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); - auto uv_data = std::make_unique(uv_stride * uv_height); + auto uv_data = std::make_unique(uv_stride * uv_height); yuv_image_ = std::make_shared( fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride, nullptr, 0, width, height); @@ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, int uv_width = std::ceil(0.5f * width); int uv_height = std::ceil(0.5f * height); int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); - auto u_data = std::make_unique(uv_stride * uv_height); - auto v_data = std::make_unique(uv_stride * uv_height); + auto u_data = std::make_unique(uv_stride * uv_height); + auto v_data = std::make_unique(uv_stride * uv_height); yuv_image_ = std::make_shared( fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride, std::move(v_data), uv_stride, width, height); diff --git a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py index 0a7e7a0e0..6aa68a284 100644 --- a/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py +++ b/mediapipe/model_maker/python/text/text_classifier/text_classifier_test.py @@ -16,6 +16,7 @@ import csv import filecmp import os import tempfile +import unittest from unittest import mock as unittest_mock import tensorflow as tf @@ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier from mediapipe.tasks.python.test import test_utils +@unittest.skip('b/275624089') class TextClassifierTest(tf.test.TestCase): _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( diff --git a/mediapipe/model_maker/python/vision/object_detector/BUILD b/mediapipe/model_maker/python/vision/object_detector/BUILD index f3d4407d8..b97d215da 100644 --- a/mediapipe/model_maker/python/vision/object_detector/BUILD +++ b/mediapipe/model_maker/python/vision/object_detector/BUILD @@ -175,11 +175,7 @@ py_test( data = [":testdata"], tags = ["requires-net:external"], deps = [ - ":dataset", - ":hyperparameters", - ":model_spec", - ":object_detector", - ":object_detector_options", + ":object_detector_import", "//mediapipe/tasks/python/test:test_utils", ], ) diff --git a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py index df6b58a07..02f773e69 100644 --- a/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py +++ b/mediapipe/model_maker/python/vision/object_detector/object_detector_test.py @@ -19,11 +19,7 @@ from unittest import mock as unittest_mock from absl.testing import 
parameterized import tensorflow as tf -from mediapipe.model_maker.python.vision.object_detector import dataset -from mediapipe.model_maker.python.vision.object_detector import hyperparameters -from mediapipe.model_maker.python.vision.object_detector import model_spec as ms -from mediapipe.model_maker.python.vision.object_detector import object_detector -from mediapipe.model_maker.python.vision.object_detector import object_detector_options +from mediapipe.model_maker.python.vision import object_detector from mediapipe.tasks.python.test import test_utils as task_test_utils @@ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): super().setUp() dataset_folder = task_test_utils.get_test_data_path('coco_data') cache_dir = self.create_tempdir() - self.data = dataset.Dataset.from_coco_folder( + self.data = object_detector.Dataset.from_coco_folder( dataset_folder, cache_dir=cache_dir ) # Mock tempfile.gettempdir() to be unique for each test to avoid race @@ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.addCleanup(mock_gettempdir.stop) def test_object_detector(self): - hparams = hyperparameters.HParams( + hparams = object_detector.HParams( epochs=1, batch_size=2, learning_rate=0.9, shuffle=False, export_dir=self.create_tempdir(), ) - options = object_detector_options.ObjectDetectorOptions( - supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams + options = object_detector.ObjectDetectorOptions( + supported_model=object_detector.SupportedModels.MOBILENET_V2, + hparams=hparams, ) # Test `create`` model = object_detector.ObjectDetector.create( @@ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): self.assertGreater(os.path.getsize(output_metadata_file), 0) # Test `quantization_aware_training` - qat_hparams = hyperparameters.QATHParams( + qat_hparams = object_detector.QATHParams( learning_rate=0.9, batch_size=2, epochs=1, diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index 14cea526f..2e33ebf6c 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -298,6 +298,7 @@ cc_library( ":tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/memory", diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc index 74678804f..c2bc413c5 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc @@ -91,8 +91,8 @@ absl::Status FrameAnnotationToTimedBoxListCalculator::Process( TimedBoxProto* added_box = output_objects->add_box(); ComputeBoundingRect(key_points, added_box); added_box->set_id(annotation.object_id()); - const int64 time_msec = - static_cast(std::round(frame_annotation.timestamp() / 1000)); + const int64_t time_msec = + static_cast(std::round(frame_annotation.timestamp() / 1000)); added_box->set_time_msec(time_msec); } diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc b/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc index 
eebf88579..1685a4f68 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_tracker.cc @@ -24,8 +24,8 @@ namespace mediapipe { void FrameAnnotationTracker::AddDetectionResult( const FrameAnnotation& frame_annotation) { - const int64 time_us = - static_cast(std::round(frame_annotation.timestamp())); + const int64_t time_us = + static_cast(std::round(frame_annotation.timestamp())); for (const auto& object_annotation : frame_annotation.annotations()) { detected_objects_[time_us + object_annotation.object_id()] = object_annotation; @@ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult( absl::flat_hash_set* cancel_object_ids) { CHECK(cancel_object_ids != nullptr); FrameAnnotation frame_annotation; - std::vector keys_to_be_deleted; + std::vector keys_to_be_deleted; for (const auto& detected_obj : detected_objects_) { const int object_id = detected_obj.second.object_id(); if (cancel_object_ids->contains(object_id)) { diff --git a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc index e3686f65e..d74b59a25 100644 --- a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc +++ b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc @@ -76,7 +76,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase { // In a single MediaPipe session, the IDs are unique. // Also assign timestamp for the FrameAnnotation to be the input packet // timestamp. - void AssignObjectIdAndTimestamp(int64 timestamp_us, + void AssignObjectIdAndTimestamp(int64_t timestamp_us, FrameAnnotation* annotation); int num_classes_ = 0; @@ -207,7 +207,7 @@ void TfLiteTensorsToObjectsCalculator::Project3DTo2D( } void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp( - int64 timestamp_us, FrameAnnotation* annotation) { + int64_t timestamp_us, FrameAnnotation* annotation) { for (auto& ann : *annotation->mutable_annotations()) { ann.set_object_id(GetNextObjectId()); } diff --git a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc index 108374003..5f5f8da6c 100644 --- a/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc +++ b/mediapipe/tasks/cc/audio/audio_classifier/audio_classifier_test.cc @@ -37,7 +37,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/containers/category.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -157,7 +157,7 @@ void CheckStreamingModeResults(std::vector outputs) { } } -class CreateFromOptionsTest : public tflite_shims::testing::Test {}; +class CreateFromOptionsTest : public tflite::testing::Test {}; TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { auto options = std::make_unique(); @@ -270,7 +270,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); } -class ClassifyTest : public tflite_shims::testing::Test {}; +class ClassifyTest : public tflite::testing::Test {}; TEST_F(ClassifyTest, Succeeds) { auto audio_buffer = GetAudioData(k16kTestWavFilename); @@ -467,7 +467,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { } } -class ClassifyAsyncTest : public tflite_shims::testing::Test {}; +class ClassifyAsyncTest : public tflite::testing::Test {}; TEST_F(ClassifyAsyncTest, Succeeds) { constexpr int kSampleRateHz = 48000; diff --git a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc index e388423b1..81ecb1237 100644 --- a/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc +++ b/mediapipe/tasks/cc/audio/audio_embedder/audio_embedder_test.cc @@ -36,7 +36,7 @@ limitations under the License. #include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -66,7 +66,7 @@ Matrix GetAudioData(absl::string_view filename) { return matrix_mapping.matrix(); } -class CreateFromOptionsTest : public tflite_shims::testing::Test {}; +class CreateFromOptionsTest : public tflite::testing::Test {}; TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { auto audio_embedder = @@ -124,7 +124,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) { MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); } -class EmbedTest : public tflite_shims::testing::Test {}; +class EmbedTest : public tflite::testing::Test {}; TEST_F(EmbedTest, SucceedsWithSilentAudio) { auto options = std::make_unique(); @@ -187,7 +187,7 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) { MP_EXPECT_OK(audio_embedder->Close()); } -class EmbedAsyncTest : public tflite_shims::testing::Test { +class EmbedAsyncTest : public tflite::testing::Test { protected: void RunAudioEmbedderInStreamMode(std::string audio_file_name, int sample_rate_hz, diff --git a/mediapipe/tasks/cc/audio/utils/BUILD b/mediapipe/tasks/cc/audio/utils/BUILD index 1d6988008..29d88d33d 100644 --- a/mediapipe/tasks/cc/audio/utils/BUILD +++ b/mediapipe/tasks/cc/audio/utils/BUILD @@ -47,7 +47,7 @@ cc_test_with_tflite( data = ["//mediapipe/tasks/testdata/audio:test_models"], tflite_deps = [ "//mediapipe/tasks/cc/core:model_resources", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], deps = [ ":audio_tensor_specs", diff --git a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc index 
60b2bdc50..4f7a5000e 100644 --- a/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc +++ b/mediapipe/tasks/cc/audio/utils/audio_tensor_specs_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -52,7 +52,7 @@ constexpr char kModelWithMetadata[] = "yamnet_audio_classifier_with_metadata.tflite"; constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite"; -class AudioTensorSpecsTest : public tflite_shims::testing::Test {}; +class AudioTensorSpecsTest : public tflite::testing::Test {}; TEST_F(AudioTensorSpecsTest, BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) { diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index bb49bdb9d..e447f5d72 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -63,7 +63,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], ) @@ -232,6 +232,6 @@ cc_test( "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], ) diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc index 811d70544..c824919df 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace { @@ -66,8 +66,7 @@ ClassificationList MakeClassificationList(int class_index) { class_index)); } -class ClassificationAggregationCalculatorTest - : public tflite_shims::testing::Test { +class ClassificationAggregationCalculatorTest : public tflite::testing::Test { protected: absl::StatusOr BuildGraph( bool connect_timestamps = false) { diff --git a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc index f2b2fa1d5..c4b635d24 100644 --- a/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc +++ b/mediapipe/tasks/cc/components/calculators/embedding_aggregation_calculator_test.cc @@ -31,7 +31,7 @@ limitations under the License. 
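[Editor's note] The tflite_shims -> tflite test migration above repeats across many test files in this change. For reference, a migrated test file looks roughly like this (illustrative sketch, not part of the diff; the fixture and assertion are placeholders):

// Illustrative sketch of a test after the shims-to-core migration.
#include "tensorflow/lite/test_util.h"  // replaces tensorflow/lite/core/shims/cc/shims_test_util.h

#include "mediapipe/framework/port/gtest.h"

namespace {

// Fixtures now derive from tflite::testing::Test instead of
// tflite_shims::testing::Test; nothing else about the test changes.
class ExampleTest : public tflite::testing::Test {};

TEST_F(ExampleTest, StillCompilesAgainstCoreTfLite) {
  EXPECT_EQ(1 + 1, 2);  // Placeholder assertion.
}

}  // namespace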
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace { @@ -52,7 +52,7 @@ constexpr char kTimestampsName[] = "timestamps_in"; constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out"; -class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { +class EmbeddingAggregationCalculatorTest : public tflite::testing::Test { protected: absl::StatusOr BuildGraph(bool connect_timestamps) { Graph graph; diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index cfb3b02cf..9f796920c 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -66,7 +66,7 @@ using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::tflite::ProcessUnit; using ::tflite::TensorMetadata; -using LabelItems = mediapipe::proto_ns::Map; +using LabelItems = mediapipe::proto_ns::Map; using TensorsSource = mediapipe::api2::builder::Source>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index a11bad71a..a61ffa6b1 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -49,7 +49,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/util/label_map.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -101,7 +101,7 @@ absl::StatusOr> CreateModelResourcesForModel( std::move(external_file)); } -class ConfigureTest : public tflite_shims::testing::Test {}; +class ConfigureTest : public tflite::testing::Test {}; TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { MP_ASSERT_OK_AND_ASSIGN( @@ -417,7 +417,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { )pb"))); } -class PostprocessingTest : public tflite_shims::testing::Test { +class PostprocessingTest : public tflite::testing::Test { protected: absl::StatusOr BuildGraph( absl::string_view model_name, const proto::ClassifierOptions& options, @@ -520,7 +520,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { auto poller, BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); + std::vector tensor(kMobileNetNumClasses, 0); tensor[1] = 18; tensor[2] = 16; @@ -552,7 +552,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); // Build input tensors. 
- std::vector tensor(kMobileNetNumClasses, 0); + std::vector tensor(kMobileNetNumClasses, 0); tensor[1] = 12; tensor[2] = 14; tensor[3] = 16; @@ -589,7 +589,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { auto poller, BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); + std::vector tensor(kMobileNetNumClasses, 0); tensor[1] = 12; tensor[2] = 14; tensor[3] = 16; @@ -677,11 +677,11 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, /*connect_timestamps=*/true)); // Build input tensors. - std::vector tensor_0(kMobileNetNumClasses, 0); + std::vector tensor_0(kMobileNetNumClasses, 0); tensor_0[1] = 12; tensor_0[2] = 14; tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); + std::vector tensor_1(kMobileNetNumClasses, 0); tensor_1[5] = 12; tensor_1[6] = 14; tensor_1[7] = 16; diff --git a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc index 768508446..94a2a7f3f 100644 --- a/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/embedding_postprocessing_graph_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -86,7 +86,7 @@ absl::StatusOr> CreateModelResourcesForModel( std::move(external_file)); } -class ConfigureTest : public tflite_shims::testing::Test {}; +class ConfigureTest : public tflite::testing::Test {}; TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( @@ -153,7 +153,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { has_quantized_outputs: false)pb"))); } -class PostprocessingTest : public tflite_shims::testing::Test { +class PostprocessingTest : public tflite::testing::Test { protected: absl::StatusOr BuildGraph( absl::string_view model_name, const proto::EmbedderOptions& options, diff --git a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc index 6c094c6bc..c69a51a65 100644 --- a/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/image_preprocessing_graph_test.cc @@ -37,7 +37,7 @@ limitations under the License. 
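[Editor's note] The postprocessing tests above build dense score tensors with a few elevated entries so the expected top-k ordering is unambiguous. A standalone sketch of that pattern (assumption: the element type, dropped during extraction, is uint8_t to match the quantized classifier models the tests name):

// Sketch of the test input pattern; uint8_t element type is an assumption.
#include <cstdint>
#include <vector>

std::vector<uint8_t> MakeScoreTensor(int num_classes) {
  std::vector<uint8_t> tensor(num_classes, 0);  // all class scores start at zero
  tensor[1] = 12;  // a handful of classes get distinct, non-zero scores
  tensor[2] = 14;
  tensor[3] = 16;
  return tensor;
}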
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -125,7 +125,7 @@ absl::StatusOr> CreateTaskRunner( return TaskRunner::Create(graph.GetConfig()); } -class ConfigureTest : public tflite_shims::testing::Test {}; +class ConfigureTest : public tflite::testing::Test {}; TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { MP_ASSERT_OK_AND_ASSIGN( diff --git a/mediapipe/tasks/cc/core/BUILD b/mediapipe/tasks/cc/core/BUILD index 465c382bb..95cfdd15e 100644 --- a/mediapipe/tasks/cc/core/BUILD +++ b/mediapipe/tasks/cc/core/BUILD @@ -78,6 +78,7 @@ cc_library( hdrs = ["mediapipe_builtin_op_resolver.h"], deps = [ "//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite", + "//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite", "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", @@ -128,9 +129,9 @@ cc_library_with_tflite( srcs = ["model_resources.cc"], hdrs = ["model_resources.h"], tflite_deps = [ - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", - "@org_tensorflow//tensorflow/lite/core/shims:verifier", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/tools:verifier", ], deps = [ ":external_file_handler", @@ -159,9 +160,9 @@ cc_test_with_tflite( ], tflite_deps = [ ":model_resources", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite:test_util", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], deps = [ ":utils", @@ -186,7 +187,7 @@ cc_library_with_tflite( hdrs = ["model_resources_cache.h"], tflite_deps = [ ":model_resources", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], deps = [ ":model_asset_bundle_resources", @@ -233,7 +234,7 @@ cc_test_with_tflite( ":model_resources", ":model_resources_cache", ":model_resources_calculator", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], deps = [ "//mediapipe/framework/port:gtest_main", @@ -284,7 +285,7 @@ cc_test_with_tflite( ":task_runner", ":model_resources", ":model_resources_cache", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], deps = [ "//mediapipe/calculators/core:pass_through_calculator", @@ -317,6 +318,9 @@ cc_library( ":model_resources", ":task_runner", ":utils", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework/port:requires", + "//mediapipe/framework/port:status", "//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", diff --git a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc 
b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc index ae64e33ef..80097fd09 100644 --- a/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc +++ b/mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" @@ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { AddCustom("KmeansEmbeddingLookup", mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); // For the UniversalSentenceEncoder model. + AddCustom("TFSentencepieceTokenizeOp", + mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER()); AddCustom("RaggedTensorToTensor", mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR()); } diff --git a/mediapipe/tasks/cc/core/model_resources.cc b/mediapipe/tasks/cc/core/model_resources.cc index 7819f6213..76695125a 100644 --- a/mediapipe/tasks/cc/core/model_resources.cc +++ b/mediapipe/tasks/cc/core/model_resources.cc @@ -37,8 +37,8 @@ limitations under the License. #include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/model_builder.h" -#include "tensorflow/lite/core/shims/cc/tools/verifier.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/tools/verifier.h" namespace mediapipe { namespace tasks { @@ -52,7 +52,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor; bool ModelResources::Verifier::Verify(const char* data, int length, tflite::ErrorReporter* reporter) { - return tflite_shims::Verify(data, length, reporter); + return tflite::Verify(data, length, reporter); } ModelResources::ModelResources(const std::string& tag, @@ -124,7 +124,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() { // and that it uses only operators that are supported by the OpResolver // that was passed to the ModelResources constructor, and then builds // the model from the buffer. - auto model = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer( + auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( buffer_data, buffer_size, &verifier_, &error_reporter_); if (model == nullptr) { static constexpr char kInvalidFlatbufferMessage[] = @@ -151,8 +151,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() { } model_packet_ = MakePacket( - model.release(), - [](tflite_shims::FlatBufferModel* model) { delete model; }); + model.release(), [](tflite::FlatBufferModel* model) { delete model; }); ASSIGN_OR_RETURN(auto model_metadata_extractor, metadata::ModelMetadataExtractor::CreateFromModelBuffer( buffer_data, buffer_size)); diff --git a/mediapipe/tasks/cc/core/model_resources.h b/mediapipe/tasks/cc/core/model_resources.h index c2a03f1f2..1bc1b65eb 100644 --- a/mediapipe/tasks/cc/core/model_resources.h +++ b/mediapipe/tasks/cc/core/model_resources.h @@ -32,10 +32,10 @@ limitations under the License. 
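[Editor's note] As a usage note for the SentencePiece tokenizer op registered in MediaPipeBuiltinOpResolver above: any TFLite resolver can register it the same way (sketch only; the op name string must match the custom op name stored in the model's flatbuffer):

// Illustrative sketch: registering the new custom op on a plain resolver.
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h"
#include "tensorflow/lite/mutable_op_resolver.h"

void RegisterSentencepieceTokenizer(tflite::MutableOpResolver& resolver) {
  resolver.AddCustom(
      "TFSentencepieceTokenizeOp",
      mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER());
}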
#include "mediapipe/util/tflite/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/kernels/register.h" -#include "tensorflow/lite/core/shims/cc/model.h" -#include "tensorflow/lite/core/shims/cc/model_builder.h" -#include "tensorflow/lite/core/shims/cc/tools/verifier.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/tools/verifier.h" namespace mediapipe { namespace tasks { @@ -51,8 +51,8 @@ class ModelResources { public: // Represents a TfLite model as a FlatBuffer. using ModelPtr = - std::unique_ptr>; + std::unique_ptr>; // Takes the ownership of the provided ExternalFile proto and creates // ModelResources from the proto and an op resolver object. A non-empty tag @@ -61,7 +61,7 @@ class ModelResources { static absl::StatusOr> Create( const std::string& tag, std::unique_ptr model_file, std::unique_ptr op_resolver = - absl::make_unique()); + absl::make_unique()); // Takes the ownership of the provided ExternalFile proto and creates // ModelResources from the proto and an op resolver mediapipe packet. A diff --git a/mediapipe/tasks/cc/core/model_resources_calculator_test.cc b/mediapipe/tasks/cc/core/model_resources_calculator_test.cc index 83134a8c7..6ba52e521 100644 --- a/mediapipe/tasks/cc/core/model_resources_calculator_test.cc +++ b/mediapipe/tasks/cc/core/model_resources_calculator_test.cc @@ -30,7 +30,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -124,7 +124,7 @@ void RunGraphWithGraphService(std::unique_ptr model_resources, } // namespace -class ModelResourcesCalculatorTest : public tflite_shims::testing::Test {}; +class ModelResourcesCalculatorTest : public tflite::testing::Test {}; TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) { auto graph_config = ParseTextProtoOrDie( diff --git a/mediapipe/tasks/cc/core/model_resources_test.cc b/mediapipe/tasks/cc/core/model_resources_test.cc index 3bc5ff062..036d0e784 100644 --- a/mediapipe/tasks/cc/core/model_resources_test.cc +++ b/mediapipe/tasks/cc/core/model_resources_test.cc @@ -38,9 +38,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/kernels/builtin_op_kernels.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/test_util.h" namespace tflite { namespace ops { @@ -116,7 +116,7 @@ void CheckModelResourcesPackets(const ModelResources* model_resources) { } // namespace -class ModelResourcesTest : public tflite_shims::testing::Test {}; +class ModelResourcesTest : public tflite::testing::Test {}; TEST_F(ModelResourcesTest, CreateFromBinaryContent) { auto model_file = std::make_unique(); @@ -211,7 +211,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsFromFile) { static constexpr char kCustomOpName[] = "MY_CUSTOM_OP"; tflite::MutableOpResolver resolver; resolver.AddBuiltin(::tflite::BuiltinOperator_ADD, - ::tflite_shims::ops::builtin::Register_ADD()); + ::tflite::ops::builtin::Register_ADD()); resolver.AddCustom(kCustomOpName, ::tflite::ops::custom::Register_MY_CUSTOM_OP()); @@ -275,7 +275,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsPacket) { static constexpr char kCustomOpName[] = "MY_CUSTOM_OP"; tflite::MutableOpResolver resolver; resolver.AddBuiltin(::tflite::BuiltinOperator_ADD, - ::tflite_shims::ops::builtin::Register_ADD()); + ::tflite::ops::builtin::Register_ADD()); resolver.AddCustom(kCustomOpName, ::tflite::ops::custom::Register_MY_CUSTOM_OP()); diff --git a/mediapipe/tasks/cc/core/task_api_factory.h b/mediapipe/tasks/cc/core/task_api_factory.h index 631696b4c..83c2f3207 100644 --- a/mediapipe/tasks/cc/core/task_api_factory.h +++ b/mediapipe/tasks/cc/core/task_api_factory.h @@ -23,7 +23,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/port/requires.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/core/base_task_api.h" #include "mediapipe/tasks/cc/core/model_resources.h" @@ -54,6 +58,8 @@ class TaskApiFactory { std::unique_ptr resolver, PacketsCallback packets_callback = nullptr) { bool found_task_subgraph = false; + // This for-loop ensures there's only one subgraph besides + // FlowLimiterCalculator. 
for (const auto& node : graph_config.node()) { if (node.calculator() == "FlowLimiterCalculator") { continue; @@ -64,13 +70,7 @@ class TaskApiFactory { "Task graph config should only contain one task subgraph node.", MediaPipeTasksStatus::kInvalidTaskGraphConfigError); } else { - if (!node.options().HasExtension(Options::ext)) { - return CreateStatusWithPayload( - absl::StatusCode::kInvalidArgument, - absl::StrCat(node.calculator(), - " is missing the required task options field."), - MediaPipeTasksStatus::kInvalidTaskGraphConfigError); - } + MP_RETURN_IF_ERROR(CheckHasValidOptions(node)); found_task_subgraph = true; } } @@ -80,6 +80,35 @@ class TaskApiFactory { std::move(packets_callback))); return std::make_unique(std::move(runner)); } + + private: + template + static absl::Status CheckHasValidOptions( + const CalculatorGraphConfig::Node& node) { + if constexpr (mediapipe::Requires( + [](auto&& o) -> decltype(o.ext) {})) { + if (node.options().HasExtension(Options::ext)) { + return absl::OkStatus(); + } + } else { +#ifndef MEDIAPIPE_PROTO_LITE + for (const auto& option : node.node_options()) { + if (absl::StrContains(option.type_url(), + Options::descriptor()->full_name())) { + return absl::OkStatus(); + } + } +#else // MEDIAPIPE_PROTO_LITE + // Skip the check for proto lite, as Options::descriptor() is unavailable. + return absl::OkStatus(); +#endif // MEDIAPIPE_PROTO_LITE + } + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat(node.calculator(), + " is missing the required task options field."), + MediaPipeTasksStatus::kInvalidTaskGraphConfigError); + } }; } // namespace core diff --git a/mediapipe/tasks/cc/core/task_runner_test.cc b/mediapipe/tasks/cc/core/task_runner_test.cc index fdd32eec4..75c6260af 100644 --- a/mediapipe/tasks/cc/core/task_runner_test.cc +++ b/mediapipe/tasks/cc/core/task_runner_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -112,7 +112,7 @@ CalculatorGraphConfig GetModelSidePacketsToStreamPacketsGraphConfig( } // namespace -class TaskRunnerTest : public tflite_shims::testing::Test {}; +class TaskRunnerTest : public tflite::testing::Test {}; TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) { CalculatorGraphConfig proto = ParseTextProtoOrDie(R"pb( diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD new file mode 100644 index 000000000..19f843c4e --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD @@ -0,0 +1,172 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + +filegroup( + name = "config_fbs", + srcs = ["config.fbs"], +) + +flatbuffer_cc_library( + name = "config", + srcs = [ + "config.fbs", + ], +) + +flatbuffer_cc_library( + name = "encoder_config", + srcs = [ + "encoder_config.fbs", + ], + includes = [":config_fbs"], +) + +cc_library( + name = "utils", + hdrs = [ + "utils.h", + ], +) + +cc_library( + name = "double_array_trie", + hdrs = [ + "double_array_trie.h", + ], + deps = [ + ":config", + ":utils", + ], +) + +cc_library( + name = "double_array_trie_builder", + srcs = [ + "double_array_trie_builder.cc", + ], + hdrs = [ + "double_array_trie_builder.h", + ], + deps = ["@darts_clone"], +) + +cc_test( + name = "double_array_trie_test", + srcs = [ + "double_array_trie_test.cc", + ], + deps = [ + ":double_array_trie", + ":double_array_trie_builder", + ":encoder_config", + ":utils", + "//mediapipe/framework/port:gtest_main", + ], +) + +cc_library( + name = "sentencepiece_constants", + hdrs = ["sentencepiece_constants.h"], +) + +cc_library( + name = "model_converter", + srcs = [ + "model_converter.cc", + ], + hdrs = [ + "model_converter.h", + ], + deps = [ + ":config", + ":double_array_trie_builder", + ":encoder_config", + ":sentencepiece_constants", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_sentencepiece//src:sentencepiece_model_cc_proto", + ], +) + +cc_library( + name = "optimized_encoder", + srcs = [ + "optimized_encoder.cc", + ], + hdrs = [ + "optimized_encoder.h", + ], + deps = [ + ":double_array_trie", + ":encoder_config", + ":utils", + ], +) + +cc_library( + name = "sentencepiece_tokenizer_tflite", + srcs = ["sentencepiece_tokenizer_tflite.cc"], + hdrs = ["sentencepiece_tokenizer_tflite.h"], + visibility = [ + "//visibility:public", + ], + deps = + [ + ":optimized_encoder", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", + ], +) + +cc_test( + name = "optimized_encoder_test", + srcs = [ + "optimized_encoder_test.cc", + ], + data = [ + ":testdata", + ], + deps = [ + ":double_array_trie_builder", + ":encoder_config", + ":model_converter", + ":optimized_encoder", + "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@com_google_sentencepiece//src:sentencepiece_cc_proto", + "@com_google_sentencepiece//src:sentencepiece_processor", + "@org_tensorflow//tensorflow/core:lib", + ], +) diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs new file mode 100644 index 000000000..16408ffee --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs @@ -0,0 +1,25 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +namespace mediapipe.tflite_operations.sentencepiece; + +table Trie { + nodes: [uint32]; +} + + +enum EncoderVersion: byte { + SENTENCE_PIECE = 0, +} diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h new file mode 100644 index 000000000..c3b568f1c --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h @@ -0,0 +1,111 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +// A trie node specifies a node in the tree, either an intermediate node or +// a leaf node. +// A leaf node contains the id as an int of the string match. This id is encoded +// in the lower 31 bits, thus the number of distinct ids is 2^31. +// An intermediate node has an associated label and an offset to its children. +// The label is encoded in the least significant byte and must match the input +// character during matching. + +// A memory mappable trie, compatible with Darts::DoubleArray. +class DoubleArrayTrie { + public: + struct Match { + Match() {} + Match(int id, int match_length) : id(id), match_length(match_length) {} + int id = -1; + int match_length = -1; + bool empty() const { return match_length == -1; } + bool operator==(const Match& m) const { + return m.id == id && m.match_length == match_length; + } + }; + + // nodes and nodes_length specify the array of the nodes of the trie. + explicit DoubleArrayTrie(const flatbuffers::Vector* nodes) + : nodes_(nodes) {} + + // Finds matches that are prefixes of a string. + template + void IteratePrefixMatches(const utils::string_view& input, + callback update_fn) const; + + // Finds the longest prefix match of a string. + Match LongestPrefixMatch(const utils::string_view& input) const { + Match match; + IteratePrefixMatches(input, [&match](const Match& m) { match = m; }); + return match; + } + + private: + // Returns whether a node as a leaf as a child. + bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; } + + // Returns a value associated with a node. 
Available when a node is a leaf. + int value(uint32_t i) const { + return static_cast(((*nodes_)[i]) & 0x7fffffff); + } + + // Returns a label associated with a node. + // A leaf node will have the MSB set and thus return an invalid label. + int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; } + + // Returns offset to children. + int32_t offset(uint32_t i) const { + const uint32_t node = (*nodes_)[i]; + return (node >> 10) << ((node & 0x200) >> 6); + } + + const flatbuffers::Vector* nodes_; +}; + +template +void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input, + callback update_fn) const { + if (nodes_->size() == 0) { + return; + } + uint32_t pos = offset(0); + for (int i = 0; i < input.length(); ++i) { + pos ^= static_cast(input.at(i)); + if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) { + // No match, exit. + return; + } + const bool node_has_leaf = has_leaf(pos); + pos ^= offset(pos); + if (pos < 0 || pos >= nodes_->size()) { + // We can get here only if the structure is corrupted. + return; + } + if (node_has_leaf) { + update_fn(Match(value(pos), i + 1)); + } + } +} + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc new file mode 100644 index 000000000..f492b5c48 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.cc @@ -0,0 +1,75 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" + +#include +#include + +#include "include/darts.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +std::vector BuildTrie(const std::vector& data) { + std::vector ids; + ids.reserve(data.size()); + for (int i = 0; i < data.size(); ++i) { + ids.push_back(i); + } + return BuildTrie(data, ids); +} + +std::vector BuildTrie(const std::vector& data, + const std::vector& ids) { + // We make strong assumptions about binary structure of trie. + struct OneElement { + OneElement(const std::string* key_, int index_) + : key(key_), index(index_) {} + const std::string* key; + int index; + bool operator<(const OneElement& el) const { return *key < *el.key; } + }; + std::vector elements; + elements.reserve(data.size()); + auto data_iterator = std::begin(data); + auto ids_iterator = std::begin(ids); + for (; data_iterator != std::end(data) && ids_iterator != std::end(ids); + ++data_iterator, ++ids_iterator) { + elements.emplace_back(&(*data_iterator), *ids_iterator); + } + // Sort by keys. + std::sort(elements.begin(), elements.end()); + + // Create vectors to build the trie. 
+ std::vector strings; + std::vector indexes; + strings.reserve(data.size()); + indexes.reserve(data.size()); + for (const auto& el : elements) { + strings.push_back(el.key->c_str()); + indexes.push_back(el.index); + } + auto trie = std::make_unique(); + trie->build(data.size(), const_cast(&strings[0]), nullptr, + &indexes[0]); + // We make strong assumptions about internal Darts trie structure: + // - it is a vector of 32 bit signed integers + // - the "array" is the only one structure that contains all information about + // the trie. + const uint32_t* trie_data = static_cast(trie->array()); + return std::vector(trie_data, trie_data + trie->size()); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h new file mode 100644 index 000000000..94c50bffc --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h @@ -0,0 +1,32 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ + +#include +#include + +namespace mediapipe::tflite_operations::sentencepiece { + +std::vector BuildTrie(const std::vector& data, + const std::vector& ids); + +// A variant where ids are indexes in data. +std::vector BuildTrie(const std::vector& data); + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc new file mode 100644 index 000000000..60a78e126 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_test.cc @@ -0,0 +1,73 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h" + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +TEST(DoubleArrayTrieTest, Match) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector test_strings = {"A", "AAX", "AA", "B"}; + const auto trie_vector = builder.CreateVector(BuildTrie(test_strings)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto pieces = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_pieces(pieces); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + DoubleArrayTrie dat(config->pieces()->nodes()); + EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")), + DoubleArrayTrie::Match(2, 2)); + + std::vector matches; + dat.IteratePrefixMatches( + utils::string_view("AAXL"), + [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); + EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1), + DoubleArrayTrie::Match(2, 2), + DoubleArrayTrie::Match(1, 3))); +} + +TEST(DoubleArrayTrieTest, ComplexMatch) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector test_strings = {"\xe2\x96\x81the", ",", "s", + "\xe2\x96\x81Hello"}; + const std::vector test_ids = {0, 5, 10, 15}; + const auto trie_vector = + builder.CreateVector(BuildTrie(test_strings, test_ids)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto pieces = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_pieces(pieces); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + DoubleArrayTrie dat(config->pieces()->nodes()); + + std::vector matches; + dat.IteratePrefixMatches( + utils::string_view("\xe2\x96\x81Hello"), + [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); + EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8))); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs new file mode 100644 index 000000000..2e7836803 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config.fbs @@ -0,0 +1,52 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+include "config.fbs"; + +namespace mediapipe.tflite_operations.sentencepiece; + +table EncoderConfig { + // Version of the encoder. + version: EncoderVersion = SENTENCE_PIECE; + start_code: int32 = 0; + end_code: int32 = 0; + + unknown_code: int32 = -1; + // Weight of "unknown code" when encoding. "Penalty" because it usually has a + // big negative weight,less than any other sentencepiece. + unknown_penalty: float = 0; + + // The offset for encoding, usually used when codes with low codes are reserved + // for some special needs. + encoding_offset: int32; + + // String pieces for encoding. + pieces: Trie; + pieces_scores: [float]; + + // Normalization related parameters. + remove_extra_whitespaces: bool; + + // Add a whitespace prefix before encoding. + add_dummy_prefix: bool; + + // Escape whitespaces during encoding so the decoder can restore them exactly as + // in the input. + escape_whitespaces: bool; + + // Normalization parameters. + normalized_prefixes: Trie; + normalized_replacements: [byte]; +} + +root_type EncoderConfig; diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc new file mode 100644 index 000000000..3a831f3d7 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h" +#include "src/sentencepiece_model.pb.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +std::tuple, std::vector> +DecodePrecompiledCharsmap( + const ::sentencepiece::NormalizerSpec& normalizer_spec) { + // This function "undoes" encoding done by + // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap. 
+ const char* precompiled_map = normalizer_spec.precompiled_charsmap().data(); + const uint32_t trie_size = + *reinterpret_cast(precompiled_map); + const uint32_t* trie_ptr = + reinterpret_cast(precompiled_map + sizeof(uint32_t)); + const int8_t* normalized_ptr = reinterpret_cast( + precompiled_map + sizeof(uint32_t) + trie_size); + const int normalized_size = normalizer_spec.precompiled_charsmap().length() - + sizeof(uint32_t) - trie_size; + return std::make_tuple( + std::vector(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)), + std::vector(normalized_ptr, normalized_ptr + normalized_size)); +} + +absl::StatusOr ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset) { + ::sentencepiece::ModelProto model_config; + if (!model_config.ParseFromString(model_config_str)) { + return absl::InvalidArgumentError( + "Invalid configuration, can't parse SentencePiece model config " + + model_config.InitializationErrorString()); + } + // Convert sentencepieces. + std::vector pieces; + pieces.reserve(model_config.pieces_size()); + std::vector scores; + scores.reserve(model_config.pieces_size()); + std::vector ids; + ids.reserve(model_config.pieces_size()); + float min_score = 0.0; + int index = 0; + for (const auto& piece : model_config.pieces()) { + switch (piece.type()) { + case ::sentencepiece::ModelProto::SentencePiece::NORMAL: + case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: + pieces.push_back(piece.piece()); + ids.push_back(index); + if (piece.score() < min_score) { + min_score = piece.score(); + } + break; + case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: + case ::sentencepiece::ModelProto::SentencePiece::CONTROL: + // Ignore unknown and control codes. + break; + default: + return absl::InvalidArgumentError("Invalid SentencePiece piece type " + + piece.piece()); + } + scores.push_back(piece.score()); + ++index; + } + flatbuffers::FlatBufferBuilder builder(1024); + const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids)); + const auto pieces_score_vector = builder.CreateVector(scores); + TrieBuilder pieces_trie_builder(builder); + pieces_trie_builder.add_nodes(pieces_trie_vector); + const auto pieces_trie_fbs = pieces_trie_builder.Finish(); + + // Converting normalization. 
+ const auto normalization = + DecodePrecompiledCharsmap(model_config.normalizer_spec()); + const auto normalization_trie = std::get<0>(normalization); + const auto normalization_strings = std::get<1>(normalization); + const auto normalization_trie_vector = + builder.CreateVector(normalization_trie); + TrieBuilder normalization_trie_builder(builder); + normalization_trie_builder.add_nodes(normalization_trie_vector); + const auto normalization_trie_fbs = normalization_trie_builder.Finish(); + const auto normalization_strings_fbs = + builder.CreateVector(normalization_strings); + + EncoderConfigBuilder ecb(builder); + ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); + ecb.add_start_code(model_config.trainer_spec().bos_id()); + ecb.add_end_code(model_config.trainer_spec().eos_id()); + ecb.add_unknown_code(model_config.trainer_spec().unk_id()); + ecb.add_unknown_penalty(min_score - kUnkPenalty); + ecb.add_encoding_offset(encoding_offset); + ecb.add_pieces(pieces_trie_fbs); + ecb.add_pieces_scores(pieces_score_vector); + ecb.add_remove_extra_whitespaces( + model_config.normalizer_spec().remove_extra_whitespaces()); + ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix()); + ecb.add_escape_whitespaces( + model_config.normalizer_spec().escape_whitespaces()); + ecb.add_normalized_prefixes(normalization_trie_fbs); + ecb.add_normalized_replacements(normalization_strings_fbs); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + return std::string(reinterpret_cast(builder.GetBufferPointer()), + builder.GetSize()); +} + +std::string ConvertSentencepieceModel(const std::string& model_string) { + const auto result = ConvertSentencepieceModelToFlatBuffer(model_string); + assert(result.status().ok()); + return result.value(); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h new file mode 100644 index 000000000..828db16da --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h @@ -0,0 +1,33 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ + +#include + +#include "absl/status/statusor.h" + +namespace mediapipe::tflite_operations::sentencepiece { + +// Converts Sentencepiece configuration to flatbuffer format. +// encoding_offset is used by some encoders that combine different encodings. 
+absl::StatusOr ConvertSentencepieceModelToFlatBuffer( + const std::string& model_config_str, int encoding_offset = 0); +std::string ConvertSentencepieceModel(const std::string& model_string); + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc new file mode 100644 index 000000000..365b1a5ad --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.cc @@ -0,0 +1,236 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" + +#include +#include + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" + +namespace mediapipe::tflite_operations::sentencepiece { +namespace { + +const char kSpaceSymbol[] = "\xe2\x96\x81"; + +template +std::tuple> process_string( + const std::string& input, const std::vector& offsets, + const processing_callback& pc) { + std::string result_string; + result_string.reserve(input.size()); + std::vector result_offsets; + result_offsets.reserve(offsets.size()); + for (int i = 0, j = 0; i < input.size();) { + auto result = pc(input.data() + i, input.size() - i); + auto consumed = std::get<0>(result); + auto new_string = std::get<1>(result); + if (consumed == 0) { + // Skip the current byte and move forward. + result_string.push_back(input[i]); + result_offsets.push_back(offsets[j]); + i++; + j++; + continue; + } + result_string.append(new_string.data(), new_string.length()); + for (int i = 0; i < new_string.length(); ++i) { + result_offsets.push_back(offsets[j]); + } + j += consumed; + i += consumed; + } + return std::make_tuple(result_string, result_offsets); +} + +inline char is_whitespace(char c) { + return c == ' ' || c == '\t' || c == '\r' || c == '\n'; +} + +std::tuple remove_extra_whitespaces(const char* data, + int len) { + if (len == 0 || !is_whitespace(*data)) { + return std::make_tuple(0, utils::string_view(nullptr, 0)); + } + int num_consumed = 1; + for (; num_consumed < len && is_whitespace(data[num_consumed]); + ++num_consumed) { + } + return num_consumed > 1 + ? 
std::make_tuple(num_consumed, utils::string_view(" ", 1)) + : std::make_tuple(0, utils::string_view(nullptr, 0)); +} + +std::tuple find_replacement( + const char* data, int len, const DoubleArrayTrie& dat, + const flatbuffers::Vector& replacements) { + const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); + if (!max_match.empty()) { + // Because flatbuffer byte is signed char which is not the same as char, + // there is the reinterpret_cast here. + const char* replaced_string_ptr = + reinterpret_cast(replacements.data() + max_match.id); + return std::make_tuple(max_match.match_length, + utils::string_view(replaced_string_ptr)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); +} +} // namespace + +std::tuple> NormalizeString( + const std::string& in_string, const EncoderConfig& config) { + std::vector output_offsets; + std::string result = in_string; + output_offsets.reserve(in_string.length()); + for (int i = 0; i < in_string.length(); ++i) { + output_offsets.push_back(i); + } + if (in_string.empty()) { + return std::make_tuple(result, output_offsets); + } + if (config.add_dummy_prefix()) { + result.insert(result.begin(), ' '); + output_offsets.insert(output_offsets.begin(), 0); + } + // Greedely replace normalized_prefixes with normalized_replacements + if (config.normalized_prefixes() != nullptr && + config.normalized_replacements() != nullptr) { + const DoubleArrayTrie normalized_prefixes_matcher( + config.normalized_prefixes()->nodes()); + const auto norm_replace = [&config, &normalized_prefixes_matcher]( + const char* data, int len) { + return find_replacement(data, len, normalized_prefixes_matcher, + *config.normalized_replacements()); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, norm_replace); + } + if (config.remove_extra_whitespaces()) { + std::tie(result, output_offsets) = + process_string(result, output_offsets, remove_extra_whitespaces); + if (!result.empty() && is_whitespace(result.back())) { + result.pop_back(); + output_offsets.pop_back(); + } + } + if (config.escape_whitespaces()) { + const auto replace_whitespaces = [](const char* data, int len) { + if (len > 0 && is_whitespace(*data)) { + return std::make_tuple(1, utils::string_view(kSpaceSymbol)); + } + return std::make_tuple(0, utils::string_view(nullptr, 0)); + }; + std::tie(result, output_offsets) = + process_string(result, output_offsets, replace_whitespaces); + } + + return std::make_tuple(result, output_offsets); +} + +EncoderResult EncodeNormalizedString(const std::string& str, + const std::vector& offsets, + const EncoderConfig& config, bool add_bos, + bool add_eos, bool reverse) { + const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); + const flatbuffers::Vector* piece_scores = config.pieces_scores(); + const int unknown_code = config.unknown_code(); + const float unknown_penalty = config.unknown_penalty(); + struct LatticeElement { + float score = 0; + int code = -1; + int prev_position = -1; + LatticeElement(float score_, int code_, int prev_position_) + : score(score_), code(code_), prev_position(prev_position_) {} + LatticeElement() {} + }; + const int length = str.length(); + std::vector lattice(length + 1); + for (int i = 0; i < length; ++i) { + if (i > 0 && lattice[i].prev_position < 0) { + // This state is unreachable. + continue; + } + if (unknown_code >= 0) { + // Put unknown code. 
+ const float penalized_score = lattice[i].score + unknown_penalty; + const int pos = i + 1; + LatticeElement& current_element = lattice[pos]; + if (current_element.prev_position < 0 || + current_element.score < penalized_score) { + current_element = LatticeElement( + penalized_score, unknown_code, + // If the current state is already reached by unknown code, merge + // states. + lattice[i].code == unknown_code ? lattice[i].prev_position : i); + } + } + auto lattice_update = [&lattice, i, + piece_scores](const DoubleArrayTrie::Match& m) { + LatticeElement& target_element = lattice[i + m.match_length]; + const float score = lattice[i].score + (*piece_scores)[m.id]; + if (target_element.prev_position < 0 || target_element.score < score) { + target_element = LatticeElement(score, m.id, i); + } + }; + piece_matcher.IteratePrefixMatches( + utils::string_view(str.data() + i, length - i), lattice_update); + } + + EncoderResult result; + if (add_eos) { + result.codes.push_back(config.end_code()); + result.offsets.push_back(length); + } + if (lattice[length].prev_position >= 0) { + for (int pos = length; pos > 0;) { + auto code = lattice[pos].code; + if (code != config.unknown_code()) { + code += config.encoding_offset(); + } + result.codes.push_back(code); + pos = lattice[pos].prev_position; + result.offsets.push_back(offsets[pos]); + } + } + if (add_bos) { + result.codes.push_back(config.start_code()); + result.offsets.push_back(0); + } + if (!reverse) { + std::reverse(result.codes.begin(), result.codes.end()); + std::reverse(result.offsets.begin(), result.offsets.end()); + } + return result; +} + +EncoderResult EncodeString(const std::string& string, const void* config_buffer, + bool add_bos, bool add_eos, bool reverse) { + // Get the config from the buffer. + const EncoderConfig* config = GetEncoderConfig(config_buffer); + if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { + EncoderResult result; + result.type = EncoderResultType::WRONG_CONFIG; + return result; + } + std::string normalized_string; + std::vector offsets; + std::tie(normalized_string, offsets) = NormalizeString(string, *config); + return EncodeNormalizedString(normalized_string, offsets, *config, add_bos, + add_eos, reverse); +} + +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h new file mode 100644 index 000000000..849a47849 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h @@ -0,0 +1,46 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ + +// Sentencepiece encoder optimized with memmapped model. 
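The header that follows exposes a deliberately small surface: NormalizeString plus EncodeString, both driven entirely by the flatbuffer produced via model_converter.h, so the runtime path never touches the sentencepiece library itself. A minimal end-to-end sketch of how a caller might exercise this API (illustrative only, not part of the patch; the model path and the main() wrapper are placeholders):

#include <fstream>
#include <iostream>
#include <iterator>
#include <string>

#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"

namespace sp = mediapipe::tflite_operations::sentencepiece;

int main() {
  // Placeholder path to a trained sentencepiece model proto.
  std::ifstream in("/tmp/sentencepiece.model", std::ios::binary);
  const std::string model_proto((std::istreambuf_iterator<char>(in)),
                                std::istreambuf_iterator<char>());
  // Convert once; the resulting flatbuffer drives every EncodeString call.
  const std::string config = sp::ConvertSentencepieceModel(model_proto);
  const sp::EncoderResult result = sp::EncodeString(
      "hello world", config.data(), /*add_bos=*/false, /*add_eos=*/false,
      /*reverse=*/false);
  if (result.type == sp::EncoderResultType::SUCCESS) {
    for (const int id : result.codes) std::cout << id << " ";
    std::cout << "\n";
  }
  return 0;
}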
+
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
+
+namespace mediapipe::tflite_operations::sentencepiece {
+
+enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 };
+
+struct EncoderResult {
+  EncoderResultType type = EncoderResultType::SUCCESS;
+  std::vector<int> codes;
+  std::vector<int> offsets;
+};
+std::tuple<std::string, std::vector<int>> NormalizeString(
+    const std::string& in_string, const EncoderConfig& config);
+
+// Encodes one string and returns ids and offsets. Takes the configuration as a
+// type-erased buffer.
+EncoderResult EncodeString(const std::string& string, const void* config_buffer,
+                           bool add_bos, bool add_eos, bool reverse);
+
+}  // namespace mediapipe::tflite_operations::sentencepiece
+
+#endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc
new file mode 100644
index 000000000..e65bd1850
--- /dev/null
+++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder_test.cc
@@ -0,0 +1,171 @@
+/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h"
+
+#include <fstream>
+
+#include "absl/flags/flag.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "mediapipe/framework/deps/file_path.h"
+#include "mediapipe/framework/port/gmock.h"
+#include "mediapipe/framework/port/gtest.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h"
+#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h"
+#include "src/sentencepiece.pb.h"
+#include "src/sentencepiece_processor.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace mediapipe::tflite_operations::sentencepiece {
+
+namespace internal {
+
+tensorflow::Status TFReadFileToString(const std::string& filepath,
+                                      std::string* data) {
+  return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath,
+                                      data);
+}
+
+absl::Status StdReadFileToString(const std::string& filepath,
+                                 std::string* data) {
+  std::ifstream infile(filepath);
+  if (!infile.is_open()) {
+    return absl::NotFoundError(
+        absl::StrFormat("Error when opening %s", filepath));
+  }
+  std::string contents((std::istreambuf_iterator<char>(infile)),
+                       (std::istreambuf_iterator<char>()));
+  data->append(contents);
+  infile.close();
+  return absl::OkStatus();
+}
+}  // namespace internal
+
+namespace {
+
+using ::mediapipe::file::JoinPath;
+
+static char kConfigFilePath[] =
+    "/mediapipe/tasks/cc/text/custom_ops/"
+    "sentencepiece/testdata/sentencepiece.model";
+
+TEST(OptimizedEncoder, NormalizeStringWhitestpaces) {
+  flatbuffers::FlatBufferBuilder builder(1024);
+
EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_add_dummy_prefix(true); + ecb.add_escape_whitespaces(true); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("x y", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); + } + { + const auto result = NormalizeString("\tx y\n", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); + EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); + } +} + +TEST(OptimizedEncoder, NormalizeStringReplacement) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector norm_prefixes = {"A", "AA", "AAA", "AAAA"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4"; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9})); + const auto norm_r = builder.CreateVector( + reinterpret_cast(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(false); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("ABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, "A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); + } +} + +TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) { + flatbuffers::FlatBufferBuilder builder(1024); + const std::vector norm_prefixes = {"A", "AA", "AAA", "AAAA", + "X"}; + const char norm_replacements[] = "A1\0A2\0A3\0A4\0 "; + const auto trie_vector = + builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12})); + const auto norm_r = builder.CreateVector( + reinterpret_cast(norm_replacements), + sizeof(norm_replacements)); + TrieBuilder trie_builder(builder); + trie_builder.add_nodes(trie_vector); + const auto norm_p = trie_builder.Finish(); + EncoderConfigBuilder ecb(builder); + ecb.add_remove_extra_whitespaces(true); + ecb.add_normalized_prefixes(norm_p); + ecb.add_normalized_replacements(norm_r); + FinishEncoderConfigBuffer(builder, ecb.Finish()); + const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); + { + const auto result = NormalizeString("XXABAABAAABAAAA", *config); + const auto res_string = std::get<0>(result); + const auto offsets = std::get<1>(result); + EXPECT_EQ(res_string, " A1BA2BA3BA4"); + EXPECT_THAT(offsets, + ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); + } +} + +TEST(OptimizedEncoder, ConfigConverter) { + std::string config; + auto status = + internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config); + ASSERT_TRUE(status.ok()); + + ::sentencepiece::SentencePieceProcessor processor; + ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); + const auto converted_model = ConvertSentencepieceModel(config); + const std::string test_string("Hello 
world!\\xF0\\x9F\\x8D\\x95"); + const auto encoded = + EncodeString(test_string, converted_model.data(), false, false, false); + ASSERT_EQ(encoded.codes.size(), encoded.offsets.size()); + + ::sentencepiece::SentencePieceText reference_encoded; + ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); + EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size()); + for (int i = 0; i < encoded.codes.size(); ++i) { + EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id()); + EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin()); + } +} + +} // namespace +} // namespace mediapipe::tflite_operations::sentencepiece diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h new file mode 100644 index 000000000..faf481844 --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ + +namespace mediapipe::tflite_operations::sentencepiece { + +// The constant is copied from +// https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc +constexpr float kUnkPenalty = 10.0; + +// These constants are copied from +// https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc +// +// Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK). +constexpr char kSpaceSymbol[] = "\xe2\x96\x81"; + +// Encodes into U+2047 (DOUBLE QUESTION MARK), +// since this character can be useful both for user and +// developer. We can easily figure out that is emitted. +constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 "; + +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc new file mode 100644 index 000000000..468a3a54f --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.cc @@ -0,0 +1,129 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" + +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/context.h" +#include "tensorflow/lite/kernels/internal/tensor.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace mediapipe::tflite_operations { +namespace sentencepiece::tokenizer { +namespace { + +using ::tflite::SetTensorToDynamic; + +constexpr int kSPModelIndex = 0; +constexpr int kInputIndex = 1; +constexpr int kAddBOSInput = 4; +constexpr int kAddEOSInput = 5; +constexpr int kReverseInput = 6; + +constexpr int kOutputValuesInd = 0; +constexpr int kOutputSplitsInd = 1; + +TfLiteIntArray* CreateSizeArray(const std::initializer_list& sizes) { + TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); + int index = 0; + for (const int size : sizes) { + array_size->data[index++] = size; + } + return array_size; +} +} // namespace + +// Initializes text encoder object from serialized parameters. +void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, + size_t /*length*/) { + return nullptr; +} +void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // TODO: Add checks for input and output tensors. + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + SetTensorToDynamic(&output_values); + + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + SetTensorToDynamic(&output_splits); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor& model_tensor = + context->tensors[node->inputs->data[kSPModelIndex]]; + const auto model_buffer_data = model_tensor.data.data; + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputIndex]]; + + const TfLiteTensor add_bos_tensor = + context->tensors[node->inputs->data[kAddBOSInput]]; + const bool add_bos = add_bos_tensor.data.b[0]; + const TfLiteTensor add_eos_tensor = + context->tensors[node->inputs->data[kAddEOSInput]]; + const bool add_eos = add_eos_tensor.data.b[0]; + const TfLiteTensor reverse_tensor = + context->tensors[node->inputs->data[kReverseInput]]; + const bool reverse = reverse_tensor.data.b[0]; + + std::vector encoded; + std::vector splits; + const int num_strings = tflite::GetStringCount(&input_text); + for (int i = 0; i < num_strings; ++i) { + const auto strref = tflite::GetString(&input_text, i); + const auto res = EncodeString(std::string(strref.str, strref.len), + model_buffer_data, add_bos, add_eos, reverse); + TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS, + "Sentencepiece conversion failed"); + std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded)); + splits.emplace_back(encoded.size()); + } + + TfLiteTensor& output_values = + context->tensors[node->outputs->data[kOutputValuesInd]]; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor( + context, &output_values, + CreateSizeArray({static_cast(encoded.size())}))); + int32_t* output_values_flat = output_values.data.i32; + std::copy(encoded.begin(), 
encoded.end(), output_values_flat); + TfLiteTensor& output_splits = + context->tensors[node->outputs->data[kOutputSplitsInd]]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_splits, + CreateSizeArray({static_cast(splits.size() + 1)}))); + int32_t* output_splits_flat = output_splits.data.i32; + *output_splits_flat = 0; + std::copy(splits.begin(), splits.end(), output_splits_flat + 1); + return kTfLiteOk; +} +} // namespace sentencepiece::tokenizer + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() { + static TfLiteRegistration r = { + sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free, + sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval}; + return &r; +} + +} // namespace mediapipe::tflite_operations diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h new file mode 100644 index 000000000..8a9fa8aef --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace mediapipe::tflite_operations { + +TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); + +} // namespace mediapipe::tflite_operations + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model new file mode 100644 index 000000000..041188ffd Binary files /dev/null and b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model differ diff --git a/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h new file mode 100644 index 000000000..c1b7728cc --- /dev/null +++ b/mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h @@ -0,0 +1,60 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ + +#include +#include + +namespace mediapipe::tflite_operations::sentencepiece { + +// AOSP and WASM doesn't support string_view, +// we put here a minimal re-implementation. +namespace utils { + +class string_view { + public: + explicit string_view(const std::string& s) + : str_(s.data()), len_(s.length()) {} + string_view(const char* str, int len) : str_(str), len_(len) {} + // A constructor from c string. + explicit string_view(const char* s) : str_(s), len_(strlen(s)) {} + + int length() const { return len_; } + const char* data() const { return str_; } + bool empty() const { return len_ == 0; } + unsigned char at(int i) const { return str_[i]; } + + private: + const char* str_ = nullptr; + const int len_ = 0; +}; + +inline std::ostream& operator<<(std::ostream& os, const string_view& sv) { + os << std::string(sv.data(), sv.length()); + return os; +} +inline bool operator==(const string_view& view1, const string_view& view2) { + if (view1.length() != view2.length()) { + return false; + } + return memcmp(view1.data(), view2.data(), view1.length()) == 0; +} + +} // namespace utils +} // namespace mediapipe::tflite_operations::sentencepiece + +#endif // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc index 92dc493e0..2e9a58409 100644 --- a/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc +++ b/mediapipe/tasks/cc/text/language_detector/language_detector_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe::tasks::text::language_detector { namespace { @@ -75,7 +75,7 @@ absl::Status MatchesLanguageDetectorResult( } // namespace -class LanguageDetectorTest : public tflite_shims::testing::Test {}; +class LanguageDetectorTest : public tflite::testing::Test {}; TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) { auto options = std::make_unique(); diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 3c9c3fc0e..4bf773270 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -89,7 +89,7 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_sentencepiece//src:sentencepiece_processor", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], ) diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 98ca9e903..f800a0e52 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -36,7 +36,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/components/containers/category.h" #include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe::tasks::text::text_classifier { namespace { @@ -87,7 +87,7 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, } // namespace -class TextClassifierTest : public tflite_shims::testing::Test {}; +class TextClassifierTest : public tflite::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { auto options = std::make_unique(); diff --git a/mediapipe/tasks/cc/text/text_embedder/BUILD b/mediapipe/tasks/cc/text/text_embedder/BUILD index c6a2616b0..addb971f1 100644 --- a/mediapipe/tasks/cc/text/text_embedder/BUILD +++ b/mediapipe/tasks/cc/text/text_embedder/BUILD @@ -91,6 +91,6 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_sentencepiece//src:sentencepiece_processor", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], ) diff --git a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc index 533d829b9..474f0ca35 100644 --- a/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc +++ b/mediapipe/tasks/cc/text/text_embedder/text_embedder_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/components/containers/embedding_result.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe::tasks::text::text_embedder { namespace { @@ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite"; // Embedding model with regex preprocessing. constexpr char kRegexOneEmbeddingModel[] = "regex_one_embedding_with_metadata.tflite"; +constexpr char kUniversalSentenceEncoderModel[] = + "universal_sentence_encoder_qa_with_metadata.tflite"; // Tolerance for embedding vector coordinate values. 
constexpr float kEpsilon = 1e-4; @@ -49,7 +51,7 @@ using ::mediapipe::file::JoinPath; using ::testing::HasSubstr; using ::testing::Optional; -class EmbedderTest : public tflite_shims::testing::Test {}; +class EmbedderTest : public tflite::testing::Test {}; TEST_F(EmbedderTest, FailsWithMissingModel) { auto text_embedder = @@ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { MP_ASSERT_OK(text_embedder->Close()); } +TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + auto result0, + text_embedder->Embed("it's a charming and often affecting journey")); + ASSERT_EQ(result0.embeddings.size(), 1); + ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100); + ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon); + + MP_ASSERT_OK_AND_ASSIGN( + auto result1, text_embedder->Embed("what a great and fantastic trip")); + ASSERT_EQ(result1.embeddings.size(), 1); + ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100); + ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon); + + // Check cosine similarity. + MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { auto options = std::make_unique(); options->base_options.model_asset_path = @@ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { MP_ASSERT_OK(text_embedder->Close()); } +TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr text_embedder, + TextEmbedder::Create(std::move(options))); + + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result0, + text_embedder->Embed("When you go to this restaurant, they hold the " + "pancake upside-down before they hand it " + "to you. It's a great gimmick.")); + MP_ASSERT_OK_AND_ASSIGN( + TextEmbedderResult result1, + text_embedder->Embed( + "Let's make a plan to steal the declaration of independence.")); + + // Check cosine similarity. 
+ MP_ASSERT_OK_AND_ASSIGN( + double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], + result1.embeddings[0])); + EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy); + + MP_ASSERT_OK(text_embedder->Close()); +} + } // namespace } // namespace mediapipe::tasks::text::text_embedder diff --git a/mediapipe/tasks/cc/text/utils/BUILD b/mediapipe/tasks/cc/text/utils/BUILD index 7b979189c..15af7683b 100644 --- a/mediapipe/tasks/cc/text/utils/BUILD +++ b/mediapipe/tasks/cc/text/utils/BUILD @@ -81,6 +81,6 @@ cc_test( "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], ) diff --git a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc index 3e29ff0c3..2ec5686f7 100644 --- a/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc +++ b/mediapipe/tasks/cc/text/utils/text_model_utils_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" #include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe::tasks::text::utils { @@ -76,7 +76,7 @@ absl::StatusOr GetModelTypeFromFile( } // namespace -class TextModelUtilsTest : public tflite_shims::testing::Test {}; +class TextModelUtilsTest : public tflite::testing::Test {}; TEST_F(TextModelUtilsTest, BertClassifierModelTest) { MP_ASSERT_OK_AND_ASSIGN(auto model_type, diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc index b9351b891..5c342a8e9 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_blendshapes_graph_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -105,7 +105,7 @@ absl::StatusOr> CreateTaskRunner() { graph.GetConfig(), absl::make_unique()); } -class FaceBlendshapesTest : public tflite_shims::testing::Test {}; +class FaceBlendshapesTest : public tflite::testing::Test {}; TEST_F(FaceBlendshapesTest, SmokeTest) { // Prepare graph inputs. 
diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc index 411693ecf..97af42da7 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc @@ -67,7 +67,7 @@ constexpr char kPortraitExpectedFaceLandmarksName[] = "portrait_expected_face_landmarks.pbtxt"; constexpr char kPortraitExpectedBlendshapesName[] = "portrait_expected_blendshapes.pbtxt"; -constexpr char kPortaitExpectedFaceGeomertyName[] = +constexpr char kPortraitExpectedFaceGeometryName[] = "portrait_expected_face_geometry.pbtxt"; constexpr float kLandmarksDiffMargin = 0.03; @@ -100,7 +100,7 @@ struct FaceLandmarkerTestParams { mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() { auto face_geometry = GetExpectedProto( - kPortaitExpectedFaceGeomertyName); + kPortraitExpectedFaceGeometryName); return face_geometry.pose_transform_matrix(); } diff --git a/mediapipe/tasks/cc/vision/face_stylizer/BUILD b/mediapipe/tasks/cc/vision/face_stylizer/BUILD index bdbf340b8..27b2f482d 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/BUILD +++ b/mediapipe/tasks/cc/vision/face_stylizer/BUILD @@ -23,18 +23,12 @@ cc_library( srcs = ["face_stylizer_graph.cc"], deps = [ "//mediapipe/calculators/core:split_vector_calculator_cc_proto", - "//mediapipe/calculators/image:image_cropping_calculator", - "//mediapipe/calculators/image:image_cropping_calculator_cc_proto", - "//mediapipe/calculators/image:warp_affine_calculator", - "//mediapipe/calculators/image:warp_affine_calculator_cc_proto", + "//mediapipe/calculators/image:image_clone_calculator_cc_proto", "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/calculators/util:face_to_rect_calculator", - "//mediapipe/calculators/util:from_image_calculator", - "//mediapipe/calculators/util:inverse_matrix_calculator", "//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto", - "//mediapipe/calculators/util:to_image_calculator", "//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:image", @@ -53,7 +47,6 @@ cc_library( "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", - "//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator", "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h index 14c23b7a8..36bb11bd7 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h @@ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // The input image can be of any size with format RGB or RGBA. // When no face is detected on the input image, the method returns a // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. 
To ensure that the output image has reasonable quality, the stylized - // output image size is the smaller of the model output size and the size of - // the 'region_of_interest' specified in 'image_processing_options'. + // face. The stylized output image size is the same as the model output size. absl::StatusOr> Stylize( mediapipe::Image image, std::optional image_processing_options = @@ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // must be monotonically increasing. // When no face is detected on the input image, the method returns a // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. To ensure that the output image has reasonable quality, the stylized - // output image size is the smaller of the model output size and the size of - // the 'region_of_interest' specified in 'image_processing_options'. + // face. The stylized output image size is the same as the model output size. absl::StatusOr> StylizeForVideo( mediapipe::Image image, int64_t timestamp_ms, std::optional image_processing_options = @@ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // The "result_callback" provides: // - When no face is detected on the input image, the method returns a // std::nullopt. Otherwise, returns the stylized image of the most visible - // face. To ensure that the output image has reasonable quality, the - // stylized output image size is the smaller of the model output size and - // the size of the 'region_of_interest' specified in - // 'image_processing_options'. + // face. The stylized output image size is the same as the model output + // size. // - The input timestamp in milliseconds. absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms, std::optional diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc index bf717a71d..27b8dacc1 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer_graph.cc @@ -19,8 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "mediapipe/calculators/core/split_vector_calculator.pb.h" -#include "mediapipe/calculators/image/image_cropping_calculator.pb.h" -#include "mediapipe/calculators/image/warp_affine_calculator.pb.h" +#include "mediapipe/calculators/image/image_clone_calculator.pb.h" #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" #include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h" #include "mediapipe/framework/api2/builder.h" @@ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph { image_in >> preprocessing.In(kImageTag); face_rect >> preprocessing.In(kNormRectTag); auto preprocessed_tensors = preprocessing.Out(kTensorsTag); - auto transform_matrix = preprocessing.Out(kMatrixTag); // Adds inference subgraph and connects its input stream to the output // tensors produced by the ImageToTensorCalculator. 
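Aside on the caller-visible effect of the documentation change above: because the warp-affine/cropping tail is dropped below in favor of ImageCloneCalculator, the stylized image always comes back at the model's output resolution. A hedged sketch of what that looks like from the FaceStylizer task API (the model path, the wrapper function, and the pre-decoded `image` are placeholders, not part of the patch):

#include <memory>
#include <optional>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h"

using ::mediapipe::tasks::vision::face_stylizer::FaceStylizer;
using ::mediapipe::tasks::vision::face_stylizer::FaceStylizerOptions;

absl::Status RunFaceStylizer(mediapipe::Image image) {
  auto options = std::make_unique<FaceStylizerOptions>();
  options->base_options.model_asset_path = "/tmp/face_stylizer.task";  // placeholder
  absl::StatusOr<std::unique_ptr<FaceStylizer>> stylizer =
      FaceStylizer::Create(std::move(options));
  if (!stylizer.ok()) return stylizer.status();
  absl::StatusOr<std::optional<mediapipe::Image>> stylized =
      (*stylizer)->Stylize(image);
  if (!stylized.ok()) return stylized.status();
  if (stylized->has_value()) {
    // With this change these always equal the model's output size, even when
    // a region_of_interest is supplied via ImageProcessingOptions.
    const int out_width = (**stylized).width();
    const int out_height = (**stylized).height();
    (void)out_width;
    (void)out_height;
  }
  return (*stylizer)->Close();
}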
@@ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph { model_output_tensors >> tensors_to_image.In(kTensorsTag); auto tensor_image = tensors_to_image.Out(kImageTag); - auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); - transform_matrix >> inverse_matrix.In(kMatrixTag); - auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag); + auto& image_converter = graph.AddNode("ImageCloneCalculator"); + image_converter.GetOptions() + .set_output_on_gpu(false); + tensor_image >> image_converter.In(""); - auto& warp_affine = graph.AddNode("WarpAffineCalculator"); - auto& warp_affine_options = - warp_affine.GetOptions(); - warp_affine_options.set_border_mode( - WarpAffineCalculatorOptions::BORDER_ZERO); - warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT); - tensor_image >> warp_affine.In(kImageTag); - inverse_transform_matrix >> warp_affine.In(kMatrixTag); - image_size >> warp_affine.In(kOutputSizeTag); - auto image_to_crop = warp_affine.Out(kImageTag); - - // The following calculators are for cropping and resizing the output image - // based on the roi and the model output size. As the WarpAffineCalculator - // rotates the image based on the transform matrix, the rotation info in the - // rect proto is stripped to prevent the ImageCroppingCalculator from - // performing extra rotation. - auto& strip_rotation = - graph.AddNode("mediapipe.tasks.StripRotationCalculator"); - face_rect >> strip_rotation.In(kNormRectTag); - auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag); - auto& from_image = graph.AddNode("FromImageCalculator"); - image_to_crop >> from_image.In(kImageTag); - auto& image_cropping = graph.AddNode("ImageCroppingCalculator"); - auto& image_cropping_opts = - image_cropping.GetOptions(); - image_cropping_opts.set_output_max_width( - image_to_tensor_options.output_tensor_width()); - image_cropping_opts.set_output_max_height( - image_to_tensor_options.output_tensor_height()); - norm_rect_no_rotation >> image_cropping.In(kNormRectTag); - auto& to_image = graph.AddNode("ToImageCalculator"); - // ImageCroppingCalculator currently doesn't support mediapipe::Image, the - // graph selects its cpu or gpu path based on the image preprocessing - // backend. - if (use_gpu) { - from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag); - image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag); - } else { - from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag); - image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag); - } - - return {{/*stylized_image=*/to_image.Out(kImageTag).Cast(), + return {{/*stylized_image=*/image_converter.Out("").Cast(), /*original_image=*/preprocessing.Out(kImageTag).Cast()}}; } }; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc index c28df2c05..fc73e7787 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph_test.cc @@ -43,7 +43,7 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -137,7 +137,7 @@ absl::StatusOr> CreateTaskRunner() { graph.GetConfig(), absl::make_unique()); } -class HandLandmarkerTest : public tflite_shims::testing::Test {}; +class HandLandmarkerTest : public tflite::testing::Test {}; TEST_F(HandLandmarkerTest, Succeeds) { MP_ASSERT_OK_AND_ASSIGN( diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc index 48e517977..bb7d1a905 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_test.cc @@ -41,7 +41,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index f7fa83a11..9409303ba 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -59,7 +59,6 @@ using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::utils::AllowIf; -using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::hand_landmarker::proto:: HandLandmarksDetectorGraphOptions; using LabelItems = mediapipe::proto_ns::Map; diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc index f28907d2f..bbf3a7cde 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph_test.cc @@ -146,7 +146,7 @@ absl::StatusOr> CreateSingleHandTaskRunner( return TaskRunner::Create( graph.GetConfig(), - absl::make_unique()); + absl::make_unique()); } // Helper function to create a Multi Hand Landmark TaskRunner. @@ -188,7 +188,7 @@ absl::StatusOr> CreateMultiHandTaskRunner( return TaskRunner::Create( graph.GetConfig(), - absl::make_unique()); + absl::make_unique()); } NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index a9d6f55be..e8812d9fd 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -39,9 +39,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -148,7 +148,7 @@ class MobileNetQuantizedOpResolverMissingOps const MobileNetQuantizedOpResolverMissingOps& r) = delete; }; -class CreateTest : public tflite_shims::testing::Test {}; +class CreateTest : public tflite::testing::Test {}; TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); @@ -265,7 +265,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) { MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); } -class ImageModeTest : public tflite_shims::testing::Test {}; +class ImageModeTest : public tflite::testing::Test {}; TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( @@ -605,7 +605,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); } -class VideoModeTest : public tflite_shims::testing::Test {}; +class VideoModeTest : public tflite::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( @@ -707,7 +707,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK(image_classifier->Close()); } -class LiveStreamModeTest : public tflite_shims::testing::Test {}; +class LiveStreamModeTest : public tflite::testing::Test {}; TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index 748333f7d..7a0e9e9dc 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -30,9 +30,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -103,7 +103,7 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver { delete; }; -class CreateTest : public tflite_shims::testing::Test {}; +class CreateTest : public tflite::testing::Test {}; TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); @@ -181,7 +181,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) { MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); } -class ImageModeTest : public tflite_shims::testing::Test {}; +class ImageModeTest : public tflite::testing::Test {}; TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( @@ -410,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); } -class VideoModeTest : public tflite_shims::testing::Test {}; +class VideoModeTest : public tflite::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( @@ -494,7 +494,7 @@ TEST_F(VideoModeTest, Succeeds) { MP_ASSERT_OK(image_embedder->Close()); } -class LiveStreamModeTest : public tflite_shims::testing::Test {}; +class LiveStreamModeTest : public tflite::testing::Test {}; TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index 0c5a61486..21f73e103 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -39,9 +39,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -180,7 +180,7 @@ class DeepLabOpResolver : public ::tflite::MutableOpResolver { DeepLabOpResolver(const DeepLabOpResolver& r) = delete; }; -class CreateFromOptionsTest : public tflite_shims::testing::Test {}; +class CreateFromOptionsTest : public tflite::testing::Test {}; class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { public: @@ -268,7 +268,7 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) { } } -class ImageModeTest : public tflite_shims::testing::Test {}; +class ImageModeTest : public tflite::testing::Test {}; TEST_F(ImageModeTest, SucceedsWithCategoryMask) { MP_ASSERT_OK_AND_ASSIGN( @@ -521,7 +521,7 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } -class VideoModeTest : public tflite_shims::testing::Test {}; +class VideoModeTest : public tflite::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN( @@ -581,7 +581,7 @@ TEST_F(VideoModeTest, Succeeds) { MP_ASSERT_OK(segmenter->Close()); } -class LiveStreamModeTest : public tflite_shims::testing::Test {}; +class LiveStreamModeTest : public tflite::testing::Test {}; TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc index 9d7111e75..b17b3b0d3 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -64,7 +64,6 @@ using ::mediapipe::CalculatorGraphConfig; using ::mediapipe::Image; using ::mediapipe::NormalizedRect; using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult; -using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: image_segmenter::proto::ImageSegmenterGraphOptions; diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc index 40c2bb342..c761678d0 100644 --- a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -39,9 +39,9 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/test_util.h" #include "testing/base/public/gmock.h" namespace mediapipe { @@ -124,7 +124,7 @@ MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold, similarity_threshold; } -class CreateFromOptionsTest : public tflite_shims::testing::Test {}; +class CreateFromOptionsTest : public tflite::testing::Test {}; class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { public: @@ -261,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P( [](const ::testing::TestParamInfo& info) { return info.param.test_name; }); -class ImageModeTest : public tflite_shims::testing::Test {}; +class ImageModeTest : public tflite::testing::Test {}; // TODO: fix this unit test after image segmenter handled post // processing correctly with rotated image. diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 178f95168..c992cf67e 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -43,9 +43,9 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/api/op_resolver.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/test_util.h" namespace tflite { namespace ops { @@ -159,7 +159,7 @@ class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver { MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete; }; -class CreateFromOptionsTest : public tflite_shims::testing::Test {}; +class CreateFromOptionsTest : public tflite::testing::Test {}; TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { auto options = std::make_unique(); @@ -332,7 +332,7 @@ TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) { // TODO: Add NumThreadsTest back after having an // "acceleration configuration" field in the ObjectDetectorOptions. 
-class ImageModeTest : public tflite_shims::testing::Test {}; +class ImageModeTest : public tflite::testing::Test {}; TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( @@ -618,7 +618,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); } -class VideoModeTest : public tflite_shims::testing::Test {}; +class VideoModeTest : public tflite::testing::Test {}; TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( @@ -673,7 +673,7 @@ TEST_F(VideoModeTest, Succeeds) { MP_ASSERT_OK(object_detector->Close()); } -class LiveStreamModeTest : public tflite_shims::testing::Test {}; +class LiveStreamModeTest : public tflite::testing::Test {}; TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD index 1b1b818ce..19f546257 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/pose_landmarker/BUILD @@ -97,8 +97,10 @@ cc_library( "//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/util:graph_builder_utils", "@com_google_absl//absl/status:statusor", ], + alwayslink = 1, ) cc_library( diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc index 4c734c423..01c86c122 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.cc @@ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; // limit the number of frames in flight. 
CalculatorGraphConfig CreateGraphConfig( std::unique_ptr options, - bool enable_flow_limiting) { + bool enable_flow_limiting, bool output_segmentation_masks) { api2::builder::Graph graph; auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName); subgraph.GetOptions().Swap(options.get()); graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kNormRectTag).SetName(kNormRectStreamName); - subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> - graph.Out(kSegmentationMaskTag); subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >> graph.Out(kNormLandmarksTag); subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >> @@ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig( .SetName(kPoseAuxiliaryLandmarksStreamName) >> graph.Out(kPoseAuxiliaryLandmarksTag); subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (output_segmentation_masks) { + subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> + graph.Out(kSegmentationMaskTag); + } if (enable_flow_limiting) { return tasks::core::AddFlowLimiterCalculator( graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag); @@ -187,7 +189,8 @@ absl::StatusOr> PoseLandmarker::Create( PoseLandmarkerGraphOptionsProto>( CreateGraphConfig( std::move(options_proto), - options->running_mode == core::RunningMode::LIVE_STREAM), + options->running_mode == core::RunningMode::LIVE_STREAM, + options->output_segmentation_masks), std::move(options->base_options.op_resolver), options->running_mode, std::move(packets_callback)))); diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc index ae3a7482e..456a6efd1 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_graph.cc @@ -90,7 +90,7 @@ struct PoseLandmarkerOutputs { Source> auxiliary_landmark_lists; Source> pose_rects_next_frame; Source> pose_detections; - Source> segmentation_masks; + std::optional>> segmentation_masks; Source image; }; @@ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // input_stream: "IMAGE:image_in" // input_stream: "NORM_RECT:norm_rect" // output_stream: "NORM_LANDMARKS:pose_landmarks" -// output_stream: "LANDMARKS:world_landmarks" -// output_stream: "NORM_LANDMAKRS:auxiliary_landmarks" +// output_stream: "WORLD_LANDMARKS:world_landmarks" +// output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks" // output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame" // output_stream: "POSE_RECTS:pose_rects" // output_stream: "SEGMENTATION_MASK:segmentation_masks" @@ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + bool output_segmentation_masks = + HasOutput(sc->OriginalNode(), kSegmentationMaskTag); if (sc->Options() .base_options() .has_model_asset()) { @@ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); } - ASSIGN_OR_RETURN( - auto outs, - BuildPoseLandmarkerGraph( - *sc->MutableOptions(), - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + ASSIGN_OR_RETURN(auto outs, + BuildPoseLandmarkerGraph( + *sc->MutableOptions(), + graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], + graph, 
output_segmentation_masks)); outs.landmark_lists >> graph[Output>(kNormLandmarksTag)]; outs.world_landmark_lists >> @@ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { kAuxiliaryLandmarksTag)]; outs.pose_rects_next_frame >> graph[Output>(kPoseRectsNextFrameTag)]; - outs.segmentation_masks >> - graph[Output>(kSegmentationMaskTag)]; outs.pose_detections >> graph[Output>(kDetectionsTag)]; outs.image >> graph[Output(kImageTag)]; + if (outs.segmentation_masks) { + *outs.segmentation_masks >> + graph[Output>(kSegmentationMaskTag)]; + } // TODO remove when support is fixed. // As mediapipe GraphBuilder currently doesn't support configuring @@ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { // graph: the mediapipe graph instance to be updated. absl::StatusOr BuildPoseLandmarkerGraph( PoseLandmarkerGraphOptions& tasks_options, Source image_in, - Source norm_rect_in, Graph& graph) { + Source norm_rect_in, Graph& graph, + bool output_segmentation_masks) { const int max_num_poses = tasks_options.pose_detector_graph_options().num_poses(); @@ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { auto pose_rects_for_next_frame = pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag) .Cast>(); - auto segmentation_masks = - pose_landmarks_detector_graph.Out(kSegmentationMaskTag) - .Cast>(); + std::optional>> segmentation_masks; + if (output_segmentation_masks) { + segmentation_masks = + pose_landmarks_detector_graph.Out(kSegmentationMaskTag) + .Cast>(); + } if (tasks_options.base_options().use_stream_mode()) { auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator"); diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc index 062d0746d..87bc97274 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_test.cc @@ -38,7 +38,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker_result.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" #include "util/tuple/dump_vars.h" namespace mediapipe { diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc index c71fc2d58..f8488db02 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph.cc @@ -37,6 +37,7 @@ limitations under the License. 
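// A minimal sketch, not taken from this patch: with the changes above, the
// SEGMENTATION_MASK stream is only wired into the graph when the caller opts in.
// The `output_segmentation_masks` field and `PoseLandmarker::Create` appear in the
// hunks above; the header path, namespace aliases, and model path below are
// assumptions for illustration only.
#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarker.h"

using ::mediapipe::tasks::vision::pose_landmarker::PoseLandmarker;
using ::mediapipe::tasks::vision::pose_landmarker::PoseLandmarkerOptions;

absl::StatusOr<std::unique_ptr<PoseLandmarker>> MakeLandmarkerWithMasks() {
  auto options = std::make_unique<PoseLandmarkerOptions>();
  options->base_options.model_asset_path = "pose_landmarker.task";  // placeholder path
  options->output_segmentation_masks = true;  // requests the optional SEGMENTATION_MASK output
  return PoseLandmarker::Create(std::move(options));
}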
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/vision/pose_landmarker/proto/pose_landmarks_detector_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h" +#include "mediapipe/util/graph_builder_utils.h" namespace mediapipe { namespace tasks { @@ -48,6 +49,7 @@ using ::mediapipe::api2::Input; using ::mediapipe::api2::Output; using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Source; +using ::mediapipe::api2::builder::Stream; using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::vision::pose_landmarker::proto:: PoseLandmarksDetectorGraphOptions; @@ -89,7 +91,7 @@ struct SinglePoseLandmarkerOutputs { Source pose_rect_next_frame; Source pose_presence; Source pose_presence_score; - Source segmentation_mask; + std::optional> segmentation_mask; }; struct PoseLandmarkerOutputs { @@ -99,7 +101,7 @@ struct PoseLandmarkerOutputs { Source> pose_rects_next_frame; Source> presences; Source> presence_scores; - Source> segmentation_masks; + std::optional>> segmentation_masks; }; absl::Status SanityCheckOptions( @@ -269,16 +271,18 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { public: absl::StatusOr GetConfig( SubgraphContext* sc) override { + bool output_segmentation_mask = + HasOutput(sc->OriginalNode(), kSegmentationMaskTag); ASSIGN_OR_RETURN( const auto* model_resources, CreateModelResources(sc)); Graph graph; - ASSIGN_OR_RETURN( - auto pose_landmark_detection_outs, - BuildSinglePoseLandmarksDetectorGraph( - sc->Options(), *model_resources, - graph[Input(kImageTag)], - graph[Input::Optional(kNormRectTag)], graph)); + ASSIGN_OR_RETURN(auto pose_landmark_detection_outs, + BuildSinglePoseLandmarksDetectorGraph( + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input::Optional(kNormRectTag)], + graph, output_segmentation_mask)); pose_landmark_detection_outs.pose_landmarks >> graph[Output(kLandmarksTag)]; pose_landmark_detection_outs.world_pose_landmarks >> @@ -291,8 +295,10 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { graph[Output(kPresenceTag)]; pose_landmark_detection_outs.pose_presence_score >> graph[Output(kPresenceScoreTag)]; - pose_landmark_detection_outs.segmentation_mask >> - graph[Output(kSegmentationMaskTag)]; + if (pose_landmark_detection_outs.segmentation_mask) { + *pose_landmark_detection_outs.segmentation_mask >> + graph[Output(kSegmentationMaskTag)]; + } return graph.GetConfig(); } @@ -302,7 +308,8 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { BuildSinglePoseLandmarksDetectorGraph( const PoseLandmarksDetectorGraphOptions& subgraph_options, const ModelResources& model_resources, Source image_in, - Source pose_rect, Graph& graph) { + Source pose_rect, Graph& graph, + bool output_segmentation_mask) { MP_RETURN_IF_ERROR(SanityCheckOptions(subgraph_options)); auto& preprocessing = graph.AddNode( @@ -380,17 +387,6 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto raw_landmarks = tensors_to_landmarks[Output(kNormLandmarksTag)]; - // Decodes the segmentation tensor into a mask image with pixel values in - // [0, 1] (1 for person and 0 for background). 
- auto& tensors_to_segmentation = - graph.AddNode("TensorsToSegmentationCalculator"); - ConfigureTensorsToSegmentationCalculator( - &tensors_to_segmentation - .GetOptions()); - ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag); - auto raw_segmentation_mask = - tensors_to_segmentation[Output(kMaskTag)]; - // Refines landmarks with the heatmap tensor. auto& refine_landmarks_from_heatmap = graph.AddNode("RefineLandmarksFromHeatmapCalculator"); @@ -493,20 +489,34 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto world_projected_landmarks = world_landmarks_projection.Out(kLandmarksTag).Cast(); - // Calculates the inverse transformation matrix. - auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); - matrix >> inverse_matrix.In(kMatrixTag); - auto inverted_matrix = inverse_matrix.Out(kMatrixTag); + std::optional> segmentation_mask; + if (output_segmentation_mask) { + // Decodes the segmentation tensor into a mask image with pixel values in + // [0, 1] (1 for person and 0 for background). + auto& tensors_to_segmentation = + graph.AddNode("TensorsToSegmentationCalculator"); + ConfigureTensorsToSegmentationCalculator( + &tensors_to_segmentation.GetOptions< + mediapipe::TensorsToSegmentationCalculatorOptions>()); + ensured_segmentation_tensors >> tensors_to_segmentation.In(kTensorsTag); + auto raw_segmentation_mask = + tensors_to_segmentation[Output(kMaskTag)]; - // Projects the segmentation mask from the letterboxed ROI back to the full - // image. - auto& warp_affine = graph.AddNode("WarpAffineCalculator"); - ConfigureWarpAffineCalculator( - &warp_affine.GetOptions()); - image_size >> warp_affine.In(kOutputSizeTag); - inverted_matrix >> warp_affine.In(kMatrixTag); - raw_segmentation_mask >> warp_affine.In(kImageTag); - auto projected_segmentation_mask = warp_affine.Out(kImageTag).Cast(); + // Calculates the inverse transformation matrix. + auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); + matrix >> inverse_matrix.In(kMatrixTag); + auto inverted_matrix = inverse_matrix.Out(kMatrixTag); + + // Projects the segmentation mask from the letterboxed ROI back to the + // full image. + auto& warp_affine = graph.AddNode("WarpAffineCalculator"); + ConfigureWarpAffineCalculator( + &warp_affine.GetOptions()); + image_size >> warp_affine.In(kOutputSizeTag); + inverted_matrix >> warp_affine.In(kMatrixTag); + raw_segmentation_mask >> warp_affine.In(kImageTag); + segmentation_mask = warp_affine.Out(kImageTag).Cast(); + } // Calculate region of interest based on auxiliary landmarks, to be used // in the next frame. 
Consists of LandmarksToDetection + @@ -541,7 +551,7 @@ class SinglePoseLandmarksDetectorGraph : public core::ModelTaskGraph { /* pose_rect_next_frame= */ pose_rect_next_frame, /* pose_presence= */ pose_presence, /* pose_presence_score= */ pose_presence_score, - /* segmentation_mask= */ projected_segmentation_mask, + /* segmentation_mask= */ segmentation_mask, }}; } }; @@ -613,12 +623,15 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + bool output_segmentation_masks = + HasOutput(sc->OriginalNode(), kSegmentationMaskTag); ASSIGN_OR_RETURN( auto pose_landmark_detection_outputs, BuildPoseLandmarksDetectorGraph( sc->Options(), graph[Input(kImageTag)], - graph[Input>(kNormRectTag)], graph)); + graph[Input>(kNormRectTag)], graph, + output_segmentation_masks)); pose_landmark_detection_outputs.landmark_lists >> graph[Output>(kLandmarksTag)]; pose_landmark_detection_outputs.world_landmark_lists >> @@ -631,8 +644,10 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { graph[Output>(kPresenceTag)]; pose_landmark_detection_outputs.presence_scores >> graph[Output>(kPresenceScoreTag)]; - pose_landmark_detection_outputs.segmentation_masks >> - graph[Output>(kSegmentationMaskTag)]; + if (pose_landmark_detection_outputs.segmentation_masks) { + *pose_landmark_detection_outputs.segmentation_masks >> + graph[Output>(kSegmentationMaskTag)]; + } return graph.GetConfig(); } @@ -641,7 +656,8 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { absl::StatusOr BuildPoseLandmarksDetectorGraph( const PoseLandmarksDetectorGraphOptions& subgraph_options, Source image_in, - Source> multi_pose_rects, Graph& graph) { + Source> multi_pose_rects, Graph& graph, + bool output_segmentation_masks) { auto& begin_loop_multi_pose_rects = graph.AddNode("BeginLoopNormalizedRectCalculator"); image_in >> begin_loop_multi_pose_rects.In("CLONE"); @@ -664,7 +680,6 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { pose_landmark_subgraph.Out(kPoseRectNextFrameTag); auto presence = pose_landmark_subgraph.Out(kPresenceTag); auto presence_score = pose_landmark_subgraph.Out(kPresenceScoreTag); - auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag); auto& end_loop_landmarks = graph.AddNode("EndLoopNormalizedLandmarkListVectorCalculator"); @@ -708,11 +723,16 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { auto presence_scores = end_loop_presence_score[Output>(kIterableTag)]; - auto& end_loop_segmentation_mask = graph.AddNode("EndLoopImageCalculator"); - batch_end >> end_loop_segmentation_mask.In(kBatchEndTag); - segmentation_mask >> end_loop_segmentation_mask.In(kItemTag); - auto segmentation_masks = - end_loop_segmentation_mask[Output>(kIterableTag)]; + std::optional>> segmentation_masks_vector; + if (output_segmentation_masks) { + auto segmentation_mask = pose_landmark_subgraph.Out(kSegmentationMaskTag); + auto& end_loop_segmentation_mask = + graph.AddNode("EndLoopImageCalculator"); + batch_end >> end_loop_segmentation_mask.In(kBatchEndTag); + segmentation_mask >> end_loop_segmentation_mask.In(kItemTag); + segmentation_masks_vector = + end_loop_segmentation_mask[Output>(kIterableTag)]; + } return {{ /* landmark_lists= */ landmark_lists, @@ -721,7 +741,7 @@ class MultiplePoseLandmarksDetectorGraph : public core::ModelTaskGraph { /* pose_rects_next_frame= */ pose_rects_next_frame, /* presences= */ presences, /* presence_scores= 
*/ presence_scores, - /* segmentation_masks= */ segmentation_masks, + /* segmentation_masks= */ segmentation_masks_vector, }}; } }; diff --git a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc index a8592061c..d5108decf 100644 --- a/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc +++ b/mediapipe/tasks/cc/vision/pose_landmarker/pose_landmarks_detector_graph_test.cc @@ -143,7 +143,7 @@ absl::StatusOr> CreateSinglePoseTaskRunner( return TaskRunner::Create( graph.GetConfig(), - absl::make_unique()); + absl::make_unique()); } // Helper function to create a Multi Pose Landmark TaskRunner. @@ -189,7 +189,7 @@ absl::StatusOr> CreateMultiPoseTaskRunner( return TaskRunner::Create( graph.GetConfig(), - absl::make_unique()); + absl::make_unique()); } NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { diff --git a/mediapipe/tasks/cc/vision/utils/BUILD b/mediapipe/tasks/cc/vision/utils/BUILD index 2d133cc98..7e5a4dc8c 100644 --- a/mediapipe/tasks/cc/vision/utils/BUILD +++ b/mediapipe/tasks/cc/vision/utils/BUILD @@ -50,7 +50,7 @@ cc_test_with_tflite( tflite_deps = [ ":image_tensor_specs", "//mediapipe/tasks/cc/core:model_resources", - "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", + "@org_tensorflow//tensorflow/lite:test_util", ], deps = [ "//mediapipe/framework/deps:file_path", diff --git a/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc b/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc index 5d6fcf98c..8c7b7d595 100644 --- a/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc +++ b/mediapipe/tasks/cc/vision/utils/image_tensor_specs_test.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/metadata/metadata_schema_generated.h" -#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/test_util.h" namespace mediapipe { namespace tasks { @@ -69,7 +69,7 @@ constexpr char kMobileNetMetadata[] = constexpr char kMobileNetQuantizedPartialMetadata[] = "mobilenet_v1_0.25_224_quant_without_subgraph_metadata.tflite"; -class ImageTensorSpecsTest : public tflite_shims::testing::Test {}; +class ImageTensorSpecsTest : public tflite::testing::Test {}; TEST_F(ImageTensorSpecsTest, BuildInputImageTensorSpecsWorks) { auto model_file = std::make_unique(); diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h index cd464c6a1..bbc9aa8a5 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.h @@ -90,7 +90,7 @@ NS_SWIFT_NAME(ClassificationResult) * amount of data to process might exceed the maximum size that the model can process: to solve * this, the input data is split into multiple chunks starting at different timestamps. 
*/ -@property(nonatomic, readonly) NSInteger timestampMs; +@property(nonatomic, readonly) NSInteger timestampInMilliseconds; /** * Initializes a new `MPPClassificationResult` with the given array of classifications and time @@ -98,14 +98,15 @@ NS_SWIFT_NAME(ClassificationResult) * * @param classifications An Array of `MPPClassifications` objects containing the predicted * categories for each head of the model. - * @param timestampMs The timestamp (in milliseconds) of the start of the chunk of data + * @param timestampInMilliseconds The timestamp (in milliseconds) of the start of the chunk of data * corresponding to these results. * * @return An instance of `MPPClassificationResult` initialized with the given array of - * classifications and timestampMs. + * classifications and timestamp (in milliseconds). */ - (instancetype)initWithClassifications:(NSArray *)classifications - timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + NS_DESIGNATED_INITIALIZER; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m index 6d42d22ca..8d9440492 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPClassificationResult.m @@ -38,11 +38,11 @@ @implementation MPPClassificationResult - (instancetype)initWithClassifications:(NSArray *)classifications - timestampMs:(NSInteger)timestampMs { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { self = [super init]; if (self) { _classifications = classifications; - _timestampMs = timestampMs; + _timestampInMilliseconds = timestampInMilliseconds; } return self; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h index 8fd9b9dff..4cfd8890d 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.h @@ -33,7 +33,7 @@ NS_SWIFT_NAME(EmbeddingResult) * cases, the amount of data to process might exceed the maximum size that the model can process. To * solve this, the input data is split into multiple chunks starting at different timestamps. */ -@property(nonatomic, readonly) NSInteger timestampMs; +@property(nonatomic, readonly) NSInteger timestampInMilliseconds; /** * Initializes a new `MPPEmbedding` with the given array of embeddings and timestamp (in @@ -41,14 +41,14 @@ NS_SWIFT_NAME(EmbeddingResult) * * @param embeddings An Array of `MPPEmbedding` objects containing the embedding results for each * head of the model. - * @param timestampMs The optional timestamp (in milliseconds) of the start of the chunk of data - * corresponding to these results. Pass `0` if timestamp is absent. + * @param timestampInMilliseconds The optional timestamp (in milliseconds) of the start of the chunk + * of data corresponding to these results. Pass `0` if timestamp is absent. * * @return An instance of `MPPEmbeddingResult` initialized with the given array of embeddings and - * timestampMs. + * timestamp (in milliseconds). 
*/ - (instancetype)initWithEmbeddings:(NSArray *)embeddings - timestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds NS_DESIGNATED_INITIALIZER; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m index 56dd30fdd..1f4828583 100644 --- a/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m +++ b/mediapipe/tasks/ios/components/containers/sources/MPPEmbeddingResult.m @@ -17,11 +17,11 @@ @implementation MPPEmbeddingResult - (instancetype)initWithEmbeddings:(NSArray *)embeddings - timestampMs:(NSInteger)timestampMs { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { self = [super init]; if (self) { _embeddings = embeddings; - _timestampMs = timestampMs; + _timestampInMilliseconds = timestampInMilliseconds; } return self; diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm index ff0983139..12cfa5627 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPCategory+Helpers.mm @@ -21,20 +21,20 @@ using ClassificationProto = ::mediapipe::Classification; @implementation MPPCategory (Helpers) -+ (MPPCategory *)categoryWithProto:(const ClassificationProto &)clasificationProto { ++ (MPPCategory *)categoryWithProto:(const ClassificationProto &)classificationProto { NSString *categoryName; NSString *displayName; - if (clasificationProto.has_label()) { - categoryName = [NSString stringWithCppString:clasificationProto.label()]; + if (classificationProto.has_label()) { + categoryName = [NSString stringWithCppString:classificationProto.label()]; } - if (clasificationProto.has_display_name()) { - displayName = [NSString stringWithCppString:clasificationProto.display_name()]; + if (classificationProto.has_display_name()) { + displayName = [NSString stringWithCppString:classificationProto.display_name()]; } - return [[MPPCategory alloc] initWithIndex:clasificationProto.index() - score:clasificationProto.score() + return [[MPPCategory alloc] initWithIndex:classificationProto.index() + score:classificationProto.score() categoryName:categoryName displayName:displayName]; } diff --git a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm index b02b032bb..47f1cf45c 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.mm @@ -55,13 +55,13 @@ using ClassificationResultProto = [classifications addObject:[MPPClassifications classificationsWithProto:classificationsProto]]; } - NSInteger timestampMs = 0; + NSInteger timestampInMilliseconds = 0; if (classificationResultProto.has_timestamp_ms()) { - timestampMs = (NSInteger)classificationResultProto.timestamp_ms(); + timestampInMilliseconds = (NSInteger)classificationResultProto.timestamp_ms(); } return [[MPPClassificationResult alloc] initWithClassifications:classifications - timestampMs:timestampMs]; + timestampInMilliseconds:timestampInMilliseconds]; ; } diff --git 
a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm index f9863e9ca..cf5569c07 100644 --- a/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm +++ b/mediapipe/tasks/ios/components/containers/utils/sources/MPPEmbeddingResult+Helpers.mm @@ -31,12 +31,13 @@ using EmbeddingResultProto = ::mediapipe::tasks::components::containers::proto:: [embeddings addObject:[MPPEmbedding embeddingWithProto:embeddingProto]]; } - NSInteger timestampMs = 0; + NSInteger timestampInMilliseconds = 0; if (embeddingResultProto.has_timestamp_ms()) { - timestampMs = (NSInteger)embeddingResultProto.timestamp_ms(); + timestampInMilliseconds = (NSInteger)embeddingResultProto.timestamp_ms(); } - return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings timestampMs:timestampMs]; + return [[MPPEmbeddingResult alloc] initWithEmbeddings:embeddings + timestampInMilliseconds:timestampInMilliseconds]; } @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h index 4ee7b2fc6..664a94ba6 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.h @@ -26,11 +26,12 @@ NS_SWIFT_NAME(TaskResult) /** * Timestamp that is associated with the task result object. */ -@property(nonatomic, assign, readonly) NSInteger timestampMs; +@property(nonatomic, assign, readonly) NSInteger timestampInMilliseconds; - (instancetype)init NS_UNAVAILABLE; -- (instancetype)initWithTimestampMs:(NSInteger)timestampMs NS_DESIGNATED_INITIALIZER; +- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds + NS_DESIGNATED_INITIALIZER; @end diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m index 6c08014ff..8a7fa6b5b 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskResult.m +++ b/mediapipe/tasks/ios/core/sources/MPPTaskResult.m @@ -16,16 +16,16 @@ @implementation MPPTaskResult -- (instancetype)initWithTimestampMs:(NSInteger)timestampMs { +- (instancetype)initWithTimestampInMilliseconds:(NSInteger)timestampInMilliseconds { self = [super init]; if (self) { - _timestampMs = timestampMs; + _timestampInMilliseconds = timestampInMilliseconds; } return self; } - (id)copyWithZone:(NSZone *)zone { - return [[MPPTaskResult alloc] initWithTimestampMs:self.timestampMs]; + return [[MPPTaskResult alloc] initWithTimestampInMilliseconds:self.timestampInMilliseconds]; } @end diff --git a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m index 613239944..d3a027b6c 100644 --- a/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m +++ b/mediapipe/tasks/ios/test/vision/image_classifier/MPPImageClassifierTests.m @@ -487,7 +487,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; NSError *liveStreamApiCallError; XCTAssertFalse([imageClassifier classifyAsyncImage:image - timestampMs:0 + timestampInMilliseconds:0 error:&liveStreamApiCallError]); NSError *expectedLiveStreamApiCallError = @@ -501,7 +501,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; AssertEqualErrors(liveStreamApiCallError, expectedLiveStreamApiCallError); NSError *videoApiCallError; - XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 
error:&videoApiCallError]); + XCTAssertFalse([imageClassifier classifyVideoFrame:image + timestampInMilliseconds:0 + error:&videoApiCallError]); NSError *expectedVideoApiCallError = [NSError errorWithDomain:kExpectedErrorDomain @@ -524,7 +526,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; NSError *liveStreamApiCallError; XCTAssertFalse([imageClassifier classifyAsyncImage:image - timestampMs:0 + timestampInMilliseconds:0 error:&liveStreamApiCallError]); NSError *expectedLiveStreamApiCallError = @@ -575,7 +577,9 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; AssertEqualErrors(imageApiCallError, expectedImageApiCallError); NSError *videoApiCallError; - XCTAssertFalse([imageClassifier classifyVideoFrame:image timestampMs:0 error:&videoApiCallError]); + XCTAssertFalse([imageClassifier classifyVideoFrame:image + timestampInMilliseconds:0 + error:&videoApiCallError]); NSError *expectedVideoApiCallError = [NSError errorWithDomain:kExpectedErrorDomain @@ -601,7 +605,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; for (int i = 0; i < 3; i++) { MPPImageClassifierResult *imageClassifierResult = [imageClassifier classifyVideoFrame:image - timestampMs:i + timestampInMilliseconds:i error:nil]; [self assertImageClassifierResult:imageClassifierResult hasExpectedCategoriesCount:maxResults @@ -630,10 +634,10 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; MPPImage *image = [self imageWithFileInfo:kBurgerImage]; - XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:1 error:nil]); + XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:1 error:nil]); NSError *error; - XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampMs:0 error:&error]); + XCTAssertFalse([imageClassifier classifyAsyncImage:image timestampInMilliseconds:0 error:&error]); NSError *expectedError = [NSError errorWithDomain:kExpectedErrorDomain @@ -668,7 +672,7 @@ static NSString *const kExpectedErrorDomain = @"com.google.mediapipe.tasks"; MPPImage *image = [self imageWithFileInfo:kBurgerImage]; for (int i = 0; i < 3; i++) { - XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampMs:i error:nil]); + XCTAssertTrue([imageClassifier classifyAsyncImage:image timestampInMilliseconds:i error:nil]); } } diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h index 6744a8e16..9ce7fcec2 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.h @@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextClassifierResult) * * @param classificationResult The `MPPClassificationResult` instance containing one set of results * per classifier head. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPTextClassifierResult` initialized with the given * `MPPClassificationResult` and timestamp (in milliseconds). 
*/ - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m index 4d5c1104a..09a2097cc 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifierResult.m @@ -17,8 +17,8 @@ @implementation MPPTextClassifierResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _classificationResult = classificationResult; } diff --git a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm index f5d6aa1d3..5a924016e 100644 --- a/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.mm @@ -35,7 +35,7 @@ using ::mediapipe::Packet; return [[MPPTextClassifierResult alloc] initWithClassificationResult:classificationResult - timestampMs:(NSInteger)(packet.Timestamp().Value() / + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; } diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h index e4697dcef..ab8edd16b 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.h @@ -31,13 +31,13 @@ NS_SWIFT_NAME(TextEmbedderResult) * * @param embeddingResult The `MPPEmbeddingResult` instance containing one set of results * per classifier head. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPTextEmbedderResult` initialized with the given * `MPPEmbeddingResult` and timestamp (in milliseconds).
*/ - (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m index 5483e3c3f..d764f63d6 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedderResult.m @@ -17,8 +17,8 @@ @implementation MPPTextEmbedderResult - (instancetype)initWithEmbeddingResult:(MPPEmbeddingResult *)embeddingResult - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _embeddingResult = embeddingResult; } diff --git a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm index b769292ce..3534ea66d 100644 --- a/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm +++ b/mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.mm @@ -34,7 +34,7 @@ using ::mediapipe::Packet; return [[MPPTextEmbedderResult alloc] initWithEmbeddingResult:embeddingResult - timestampMs:(NSInteger)(packet.Timestamp().Value() / + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; } diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h index eaf059ad2..ed07c6d90 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h @@ -41,7 +41,7 @@ * timestamp. * * @param image The image to send to the MediaPipe graph. - * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet. * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no * error will be saved. * @@ -49,7 +49,7 @@ * occurred during the conversion. */ + (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error; /** @@ -66,11 +66,11 @@ * specified timestamp. * * @param image The `NormalizedRect` to send to the MediaPipe graph. - * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * @param timestampInMilliseconds The timestamp (in milliseconds) to assign to the packet. * * @return The MediaPipe packet containing the normalized rect. 
*/ + (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm index bf136a759..af419c6d0 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm @@ -42,7 +42,7 @@ using ::mediapipe::Timestamp; } + (Packet)createPacketWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error { std::unique_ptr imageFrame = [image imageFrameWithError:error]; @@ -51,7 +51,7 @@ using ::mediapipe::Timestamp; } return MakePacket(std::move(imageFrame)) - .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); + .At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond))); } + (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect { @@ -59,9 +59,9 @@ using ::mediapipe::Timestamp; } + (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect - timestampMs:(NSInteger)timestampMs { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { return MakePacket(std::move(normalizedRect)) - .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); + .At(Timestamp(int64(timestampInMilliseconds * kMicroSecondsPerMilliSecond))); } @end diff --git a/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m b/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m index 1fa1a9d37..3ffb15392 100644 --- a/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m +++ b/mediapipe/tasks/ios/vision/gesture_recognizer/sources/MPPGestureRecognizerResult.m @@ -21,7 +21,7 @@ handedness:(NSArray *> *)handedness gestures:(NSArray *> *)gestures timestampInMilliseconds:(NSInteger)timestampInMilliseconds { - self = [super initWithTimestampMs:timestampInMilliseconds]; + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _landmarks = landmarks; _worldLandmarks = worldLandmarks; diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h index 581c8d95b..345687877 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h @@ -122,17 +122,17 @@ NS_SWIFT_NAME(ImageClassifier) * `MPPRunningModeVideo`. * * @param image The `MPPImage` on which image classification is to be performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing image * classification on the input video frame. * * @return An `MPPImageClassifierResult` object that contains a list of image classifications. 
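// A minimal sketch, not taken from this patch: the MPPVisionPacketCreator and the
// result helpers above convert between the caller-facing millisecond timestamps
// (now spelled timestampInMilliseconds) and MediaPipe's microsecond packet
// timestamps. The same arithmetic restated as standalone C++; mediapipe::Timestamp
// and the factor of 1000 match the patch, the helper names are invented.
#include <cstdint>

#include "mediapipe/framework/timestamp.h"

constexpr int64_t kMicroSecondsPerMilliSecond = 1000;

// Input direction: stamp an outgoing packet from a millisecond timestamp.
mediapipe::Timestamp PacketTimestampFromMilliseconds(int64_t timestamp_in_milliseconds) {
  return mediapipe::Timestamp(timestamp_in_milliseconds * kMicroSecondsPerMilliSecond);
}

// Output direction: recover the millisecond timestamp reported on a task result.
int64_t MillisecondsFromPacketTimestamp(const mediapipe::Timestamp& timestamp) {
  return timestamp.Value() / kMicroSecondsPerMilliSecond;
}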
*/ - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error - NS_SWIFT_NAME(classify(videoFrame:timestampMs:)); + NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:)); /** * Performs image classification on the provided video frame of type `MPPImage` cropped to the @@ -145,8 +145,8 @@ NS_SWIFT_NAME(ImageClassifier) * * @param image A live stream image data of type `MPPImage` on which image classification is to be * performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the video frame of type * `MPPImage`, on which image classification should be performed. * @param error An optional error parameter populated when there is an error in performing image @@ -155,10 +155,10 @@ NS_SWIFT_NAME(ImageClassifier) * @return An `MPPImageClassifierResult` object that contains a list of image classifications. */ - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error - NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:)); + NS_SWIFT_NAME(classify(videoFrame:timestampInMilliseconds:regionOfInterest:)); /** * Sends live stream image data of type `MPPImage` to perform image classification using the whole @@ -172,16 +172,17 @@ NS_SWIFT_NAME(ImageClassifier) * * @param image A live stream image data of type `MPPImage` on which image classification is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the image classifier. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the image classifier. The input timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing image * classification on the input live stream image data. * * @return `YES` if the image was sent to the task successfully, otherwise `NO`. */ - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:)); /** * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the @@ -195,8 +196,8 @@ NS_SWIFT_NAME(ImageClassifier) * * @param image A live stream image data of type `MPPImage` on which image classification is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the image classifier. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the image classifier. The input timestamps must be monotonically increasing. 
* @param roi A `CGRect` specifying the region of interest within the given live stream image data * of type `MPPImage`, on which image classification should be performed. * @param error An optional error parameter populated when there is an error in performing image @@ -205,10 +206,10 @@ NS_SWIFT_NAME(ImageClassifier) * @return `YES` if the image was sent to the task successfully, otherwise `NO`. */ - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error - NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classifyAsync(image:timestampInMilliseconds:regionOfInterest:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm index 8051fbf3d..18c1bb56a 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm @@ -149,7 +149,7 @@ static NSString *const kTaskGraphName = } - (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional rect = @@ -162,14 +162,15 @@ static NSString *const kTaskGraphName = } Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds error:error]; if (imagePacket.IsEmpty()) { return std::nullopt; } - Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() - timestampMs:timestampMs]; + Packet normalizedRectPacket = + [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampInMilliseconds:timestampInMilliseconds]; PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); return inputPacketMap; @@ -180,11 +181,11 @@ static NSString *const kTaskGraphName = } - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -204,20 +205,20 @@ static NSString *const kTaskGraphName = } - (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error { return [self classifyVideoFrame:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if 
(!inputPacketMap.has_value()) { @@ -228,10 +229,10 @@ static NSString *const kTaskGraphName = } - (BOOL)classifyAsyncImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error { return [self classifyAsyncImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h index 92fdb13cb..478bd452a 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h @@ -31,13 +31,13 @@ NS_SWIFT_NAME(ImageClassifierResult) * * @param classificationResult The `MPPClassificationResult` instance containing one set of results * per classifier head. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPImageClassifierResult` initialized with the given * `MPPClassificationResult` and timestamp (in milliseconds). */ - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m index 6dcd064eb..cb17bb10e 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.m @@ -17,8 +17,8 @@ @implementation MPPImageClassifierResult - (instancetype)initWithClassificationResult:(MPPClassificationResult *)classificationResult - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _classificationResult = classificationResult; } diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm index 09e21b278..f5199765d 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm @@ -34,7 +34,7 @@ using ::mediapipe::Packet; return [[MPPImageClassifierResult alloc] initWithClassificationResult:classificationResult - timestampMs:(NSInteger)(packet.Timestamp().Value() / + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; } diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h index 590867bf8..da9899d40 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.h @@ -36,13 +36,13 @@ NS_SWIFT_NAME(ObjectDetectionResult) * @param detections An array of `MPPDetection` objects each of which has a bounding box 
that is * expressed in the unrotated input frame of reference coordinates system, i.e. in `[0,image_width) * x [0,image_height)`, which are the dimensions of the underlying image data. - * @param timestampMs The timestamp for this result. + * @param timestampInMilliseconds The timestamp (in milliseconds) for this result. * * @return An instance of `MPPObjectDetectionResult` initialized with the given array of detections * and timestamp (in milliseconds). */ - (instancetype)initWithDetections:(NSArray *)detections - timestampMs:(NSInteger)timestampMs; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds; @end diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m index ac24c19fa..47902bba4 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetectionResult.m @@ -17,8 +17,8 @@ @implementation MPPObjectDetectionResult - (instancetype)initWithDetections:(NSArray *)detections - timestampMs:(NSInteger)timestampMs { - self = [super initWithTimestampMs:timestampMs]; + timestampInMilliseconds:(NSInteger)timestampInMilliseconds { + self = [super initWithTimestampInMilliseconds:timestampInMilliseconds]; if (self) { _detections = detections; } diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h index 58344d0c7..f92c90c50 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.h @@ -138,8 +138,8 @@ NS_SWIFT_NAME(ObjectDetector) * `MPPRunningModeVideo`. * * @param image The `MPPImage` on which object detection is to be performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing object * detection on the input image. * @@ -149,9 +149,9 @@ NS_SWIFT_NAME(ObjectDetector) * image data. */ - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error - NS_SWIFT_NAME(detect(videoFrame:timestampMs:)); + NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:)); /** * Performs object detection on the provided video frame of type `MPPImage` cropped to the @@ -164,8 +164,8 @@ NS_SWIFT_NAME(ObjectDetector) * * @param image A live stream image data of type `MPPImage` on which object detection is to be * performed. - * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be - * monotonically increasing. + * @param timestampInMilliseconds The video frame's timestamp (in milliseconds). The input + * timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which * object detection should be performed. * @@ -178,10 +178,10 @@ NS_SWIFT_NAME(ObjectDetector) * image data. 
*/ - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error - NS_SWIFT_NAME(detect(videoFrame:timestampMs:regionOfInterest:)); + NS_SWIFT_NAME(detect(videoFrame:timestampInMilliseconds:regionOfInterest:)); /** * Sends live stream image data of type `MPPImage` to perform object detection using the whole @@ -195,16 +195,17 @@ NS_SWIFT_NAME(ObjectDetector) * * @param image A live stream image data of type `MPPImage` on which object detection is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the object detector. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the object detector. The input timestamps must be monotonically increasing. * @param error An optional error parameter populated when there is an error in performing object * detection on the input live stream image data. * * @return `YES` if the image was sent to the task successfully, otherwise `NO`. */ - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error NS_SWIFT_NAME(detectAsync(image:timestampMs:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error + NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:)); /** * Sends live stream image data of type `MPPImage` to perform object detection, cropped to the @@ -218,8 +219,8 @@ NS_SWIFT_NAME(ObjectDetector) * * @param image A live stream image data of type `MPPImage` on which object detection is to be * performed. - * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent - * to the object detector. The input timestamps must be monotonically increasing. + * @param timestampInMilliseconds The timestamp (in milliseconds) which indicates when the input + * image is sent to the object detector. The input timestamps must be monotonically increasing. * @param roi A `CGRect` specifying the region of interest within the given live stream image data * of type `MPPImage`, on which iobject detection should be performed. * @param error An optional error parameter populated when there is an error in performing object @@ -228,10 +229,10 @@ NS_SWIFT_NAME(ObjectDetector) * @return `YES` if the image was sent to the task successfully, otherwise `NO`. 
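// A minimal sketch, not taken from this patch: the video and live-stream methods
// documented above require monotonically increasing millisecond timestamps (the
// image-classifier test hunks earlier send 1 then 0 and expect an error). One
// simple way to satisfy this is to derive the timestamp from a frame counter and
// an assumed frame rate; the 30 fps default is a placeholder.
#include <cstdint>

int64_t TimestampInMillisecondsForFrame(int64_t frame_index,
                                        double frames_per_second = 30.0) {
  return static_cast<int64_t>(frame_index * 1000.0 / frames_per_second);
}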
*/ - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error - NS_SWIFT_NAME(detectAsync(image:timestampMs:regionOfInterest:)); + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(detectAsync(image:timestampInMilliseconds:regionOfInterest:)); - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm index 53dcad4a8..e1aa11e96 100644 --- a/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm +++ b/mediapipe/tasks/ios/vision/object_detector/sources/MPPObjectDetector.mm @@ -157,7 +157,7 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional rect = @@ -170,14 +170,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds error:error]; if (imagePacket.IsEmpty()) { return std::nullopt; } - Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() - timestampMs:timestampMs]; + Packet normalizedRectPacket = + [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampInMilliseconds:timestampInMilliseconds]; PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); return inputPacketMap; @@ -188,11 +189,11 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds regionOfInterest:(CGRect)roi error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -212,20 +213,20 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (nullable MPPObjectDetectionResult *)detectInVideoFrame:(MPPImage *)image - timestampMs:(NSInteger)timestampMs + timestampInMilliseconds:(NSInteger)timestampInMilliseconds error:(NSError **)error { return [self detectInVideoFrame:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - regionOfInterest:(CGRect)roi - error:(NSError **)error { + timestampInMilliseconds:(NSInteger)timestampInMilliseconds + regionOfInterest:(CGRect)roi + error:(NSError **)error { std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:roi error:error]; if (!inputPacketMap.has_value()) { @@ -236,10 +237,10 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.vision.ObjectDetectorG } - (BOOL)detectAsyncInImage:(MPPImage *)image - timestampMs:(NSInteger)timestampMs - error:(NSError **)error { + 
timestampInMilliseconds:(NSInteger)timestampInMilliseconds + error:(NSError **)error { return [self detectAsyncInImage:image - timestampMs:timestampMs + timestampInMilliseconds:timestampInMilliseconds regionOfInterest:CGRectZero error:error]; } diff --git a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm index 3507b7d72..225a6993d 100644 --- a/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm +++ b/mediapipe/tasks/ios/vision/object_detector/utils/sources/MPPObjectDetectionResult+Helpers.mm @@ -38,8 +38,9 @@ using ::mediapipe::Packet; } return [[MPPObjectDetectionResult alloc] - initWithDetections:detections - timestampMs:(NSInteger)(packet.Timestamp().Value() / kMicroSecondsPerMilliSecond)]; + initWithDetections:detections + timestampInMilliseconds:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; } @end diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD index 97f8dfd15..8f8abf06f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/BUILD @@ -28,7 +28,7 @@ cc_library_with_tflite( ], tflite_deps = [ "//mediapipe/tasks/cc/core:model_resources_cache", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], deps = [ "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc index aab022dec..c7e015639 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.cc @@ -19,7 +19,7 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" #include "mediapipe/tasks/cc/core/model_resources_cache.h" -#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/kernels/register.h" namespace { using ::mediapipe::tasks::core::kModelResourcesCacheService; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 641dbd3ba..d63b0e358 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -54,6 +54,9 @@ _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", ] _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ diff --git 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 3d13974c7..fc933b6f3 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -54,6 +54,7 @@ cc_binary( "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", + "//mediapipe/tasks/cc/vision/pose_landmarker:pose_landmarker_graph", "//mediapipe/tasks/java:version_script.lds", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", ], @@ -174,6 +175,37 @@ android_library( ], ) +android_library( + name = "poselandmarker", + srcs = [ + "poselandmarker/PoseLandmarker.java", + "poselandmarker/PoseLandmarkerResult.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "poselandmarker/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/framework/formats:classification_java_proto_lite", + "//mediapipe/framework/formats:landmark_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_java_proto_lite", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarker_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "handlandmarker", srcs = [ diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java index b95e9021f..d6f565c78 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizer.java @@ -198,9 +198,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *
  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @throws MediaPipeException if there is an internal error. Or if {@link FaceStylizer} is created @@ -220,9 +220,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the @@ -256,9 +256,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a @@ -281,9 +281,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the @@ -320,9 +320,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). @@ -346,9 +346,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the @@ -387,9 +387,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). @@ -414,9 +414,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *

  • {@link android.graphics.Bitmap.Config#ARGB_8888} * * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param timestampMs the input timestamp (in milliseconds). @@ -445,9 +445,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { * *

    {@link FaceStylizer} supports the following color space types: * - *

    The image can be of any size. To ensure that the output image has reasonable quality, the - * size of the stylized output is based the model output * size and can be smaller than the input - * image. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * *

      *
    • {@link android.graphics.Bitmap.Config#ARGB_8888} @@ -475,9 +475,9 @@ public final class FaceStylizer extends BaseVisionTaskApi { *
    • {@link android.graphics.Bitmap.Config#ARGB_8888} *
    * - *

    The input image can be of any size. To ensure that the output image has reasonable quality, - * the stylized output image size is the smaller of the model output size and the size of the - * {@link ImageProcessingOptions#regionOfInterest} specified in {@code imageProcessingOptions}. + *

    The input image can be of any size. The output image is the stylized image with the most + * visible face. The stylized output image size is the same as the model output size. When no face + * is detected on the input image, returns {@code Optional.empty()}. * * @param image a MediaPipe {@link MPImage} object for processing. * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java index b38bd1c86..6f47797b4 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -94,15 +94,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { "IMAGE:" + IMAGE_IN_STREAM_NAME, "ROI:" + ROI_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); - private static final List OUTPUT_STREAMS = - Collections.unmodifiableList( - Arrays.asList( - "GROUPED_SEGMENTATION:segmented_mask_out", - "IMAGE:image_out", - "SEGMENTATION:0:segmentation")); - private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; - private static final int IMAGE_OUT_STREAM_INDEX = 1; - private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; + private static final int IMAGE_OUT_STREAM_INDEX = 0; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = @@ -123,6 +115,21 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { */ public static InteractiveSegmenter createFromOptions( Context context, InteractiveSegmenterOptions segmenterOptions) { + if (!segmenterOptions.outputConfidenceMasks() && !segmenterOptions.outputCategoryMask()) { + throw new IllegalArgumentException( + "At least one of `outputConfidenceMasks` and `outputCategoryMask` must be set."); + } + List outputStreams = new ArrayList<>(); + outputStreams.add("IMAGE:image_out"); + if (segmenterOptions.outputConfidenceMasks()) { + outputStreams.add("CONFIDENCE_MASKS:confidence_masks"); + } + final int confidenceMasksOutStreamIndex = outputStreams.size() - 1; + if (segmenterOptions.outputCategoryMask()) { + outputStreams.add("CATEGORY_MASK:category_mask"); + } + final int categoryMaskOutStreamIndex = outputStreams.size() - 1; + // TODO: Consolidate OutputHandler and TaskRunner. 
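For reference, a minimal usage sketch of the new mask-output options introduced above (hypothetical caller code, not part of this change; the model path and the context/image/roi variables are placeholders):

InteractiveSegmenterOptions options =
    InteractiveSegmenterOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
        // Request only the category mask; confidence masks are enabled by default.
        .setOutputConfidenceMasks(false)
        .setOutputCategoryMask(true)
        .build();
InteractiveSegmenter segmenter = InteractiveSegmenter.createFromOptions(context, options);
ImageSegmenterResult result = segmenter.segment(image, roi);
// categoryMask() is only populated when setOutputCategoryMask(true) was requested.
MPImage categoryMask = result.categoryMask().get();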
OutputHandler handler = new OutputHandler<>(); handler.setOutputPacketConverter( @@ -130,52 +137,72 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { @Override public ImageSegmenterResult convertToTaskResult(List packets) throws MediaPipeException { - if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { + if (packets.get(IMAGE_OUT_STREAM_INDEX).isEmpty()) { return ImageSegmenterResult.create( Optional.empty(), Optional.empty(), - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); + packets.get(IMAGE_OUT_STREAM_INDEX).getTimestamp()); } - List segmentedMasks = new ArrayList<>(); - int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); - int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); - int imageFormat = - segmenterOptions.outputType() - == InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK - ? MPImage.IMAGE_FORMAT_VEC32F1 - : MPImage.IMAGE_FORMAT_ALPHA; - int imageListSize = - PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); - ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; - // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe - // graph. If provided, the result MPImage is wrapping the mediapipe packet memory. - if (!segmenterOptions.resultListener().isPresent()) { - for (int i = 0; i < imageListSize; i++) { - buffersArray[i] = - ByteBuffer.allocateDirect( - width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 4 : 1)); + // If resultListener is not provided, the resulted MPImage is deep copied from + // mediapipe graph. If provided, the result MPImage is wrapping the mediapipe packet + // memory. + boolean copyImage = !segmenterOptions.resultListener().isPresent(); + Optional> confidenceMasks = Optional.empty(); + if (segmenterOptions.outputConfidenceMasks()) { + confidenceMasks = Optional.of(new ArrayList<>()); + int width = + PacketGetter.getImageWidthFromImageList( + packets.get(confidenceMasksOutStreamIndex)); + int height = + PacketGetter.getImageHeightFromImageList( + packets.get(confidenceMasksOutStreamIndex)); + int imageListSize = + PacketGetter.getImageListSize(packets.get(confidenceMasksOutStreamIndex)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + // confidence masks are float type image. + final int numBytes = 4; + if (copyImage) { + for (int i = 0; i < imageListSize; i++) { + buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes); + } + } + if (!PacketGetter.getImageList( + packets.get(confidenceMasksOutStreamIndex), buffersArray, copyImage)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting confidence masks."); + } + for (ByteBuffer buffer : buffersArray) { + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); + confidenceMasks.get().add(builder.build()); } } - if (!PacketGetter.getImageList( - packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), - buffersArray, - !segmenterOptions.resultListener().isPresent())) { - throw new MediaPipeException( - MediaPipeException.StatusCode.INTERNAL.ordinal(), - "There is an error getting segmented masks. 
It usually results from incorrect" - + " options of unsupported OutputType of given model."); - } - for (ByteBuffer buffer : buffersArray) { + Optional categoryMask = Optional.empty(); + if (segmenterOptions.outputCategoryMask()) { + int width = PacketGetter.getImageWidth(packets.get(categoryMaskOutStreamIndex)); + int height = PacketGetter.getImageHeight(packets.get(categoryMaskOutStreamIndex)); + ByteBuffer buffer; + if (copyImage) { + buffer = ByteBuffer.allocateDirect(width * height); + if (!PacketGetter.getImageData(packets.get(categoryMaskOutStreamIndex), buffer)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting category mask."); + } + } else { + buffer = PacketGetter.getImageDataDirectly(packets.get(categoryMaskOutStreamIndex)); + } ByteBufferImageBuilder builder = - new ByteBufferImageBuilder(buffer, width, height, imageFormat); - segmentedMasks.add(builder.build()); + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_ALPHA); + categoryMask = Optional.of(builder.build()); } return ImageSegmenterResult.create( - Optional.of(segmentedMasks), - Optional.empty(), + confidenceMasks, + categoryMask, BaseVisionTaskApi.generateResultTimestampMs( - RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); + RunningMode.IMAGE, packets.get(IMAGE_OUT_STREAM_INDEX))); } @Override @@ -195,7 +222,7 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { .setTaskRunningModeName(RunningMode.IMAGE.name()) .setTaskGraphName(TASK_GRAPH_NAME) .setInputStreams(INPUT_STREAMS) - .setOutputStreams(OUTPUT_STREAMS) + .setOutputStreams(outputStreams) .setTaskOptions(segmenterOptions) .setEnableFlowLimiting(false) .build(), @@ -394,8 +421,11 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { /** Sets the base options for the image segmenter task. */ public abstract Builder setBaseOptions(BaseOptions value); - /** The output type from image segmenter. */ - public abstract Builder setOutputType(OutputType value); + /** Sets whether to output confidence masks. Default to true. */ + public abstract Builder setOutputConfidenceMasks(boolean value); + + /** Sets whether to output category mask. Default to false. */ + public abstract Builder setOutputCategoryMask(boolean value); /** * Sets an optional {@link ResultListener} to receive the segmentation results when the graph @@ -417,25 +447,18 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { abstract BaseOptions baseOptions(); - abstract OutputType outputType(); + abstract boolean outputConfidenceMasks(); + + abstract boolean outputCategoryMask(); abstract Optional> resultListener(); abstract Optional errorListener(); - /** The output type of segmentation results. */ - public enum OutputType { - // Gives a single output mask where each pixel represents the class which - // the pixel in the original image was predicted to belong to. - CATEGORY_MASK, - // Gives a list of output masks where, for each mask, each pixel represents - // the prediction confidence, usually in the [0, 1] range. 
- CONFIDENCE_MASK - } - public static Builder builder() { return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder() - .setOutputType(OutputType.CATEGORY_MASK); + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false); } /** @@ -454,14 +477,6 @@ public final class InteractiveSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = SegmenterOptionsProto.SegmenterOptions.newBuilder(); - if (outputType() == OutputType.CONFIDENCE_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); - } else if (outputType() == OutputType.CATEGORY_MASK) { - segmenterOptionsBuilder.setOutputType( - SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); - } - taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/AndroidManifest.xml new file mode 100644 index 000000000..3e5809bd8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java new file mode 100644 index 000000000..2d9aafc4b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarker.java @@ -0,0 +1,557 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.mediapipe.tasks.vision.poselandmarker; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto.LandmarkList; +import com.google.mediapipe.formats.proto.LandmarkProto.NormalizedLandmarkList; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.posedetector.proto.PoseDetectorGraphOptionsProto; +import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarkerGraphOptionsProto; +import com.google.mediapipe.tasks.vision.poselandmarker.proto.PoseLandmarksDetectorGraphOptionsProto; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs pose landmarks detection on images. + * + *

+ * This API expects a pre-trained pose landmarks model asset bundle. See .
+ *
+ *   • Input image {@link MPImage}
+ *     • The image that pose landmarks detection runs on.
+ *   • Output PoseLandmarkerResult {@link PoseLandmarkerResult}
+ *     • A PoseLandmarkerResult containing pose landmarks.
    + */ +public final class PoseLandmarker extends BaseVisionTaskApi { + private static final String TAG = PoseLandmarker.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + + private static final int LANDMARKS_OUT_STREAM_INDEX = 0; + private static final int WORLD_LANDMARKS_OUT_STREAM_INDEX = 1; + private static final int AUXILIARY_LANDMARKS_OUT_STREAM_INDEX = 2; + private static final int IMAGE_OUT_STREAM_INDEX = 3; + private static int segmentationMasksOutStreamIndex = -1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph"; + + /** + * Creates a {@link PoseLandmarker} instance from a model file and the default {@link + * PoseLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the pose landmarks model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. + */ + public static PoseLandmarker createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link PoseLandmarker} instance from a model file and the default {@link + * PoseLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the pose landmarks model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. + */ + public static PoseLandmarker createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link PoseLandmarker} instance from a model buffer and the default {@link + * PoseLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model. + * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. + */ + public static PoseLandmarker createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, PoseLandmarkerOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link PoseLandmarker} instance from a {@link PoseLandmarkerOptions}. + * + * @param context an Android {@link Context}. + * @param landmarkerOptions a {@link PoseLandmarkerOptions} instance. + * @throws MediaPipeException if there is an error during {@link PoseLandmarker} creation. 
+ */ + public static PoseLandmarker createFromOptions( + Context context, PoseLandmarkerOptions landmarkerOptions) { + List outputStreams = new ArrayList<>(); + outputStreams.add("NORM_LANDMARKS:pose_landmarks"); + outputStreams.add("WORLD_LANDMARKS:world_landmarks"); + outputStreams.add("AUXILIARY_LANDMARKS:auxiliary_landmarks"); + outputStreams.add("IMAGE:image_out"); + if (landmarkerOptions.outputSegmentationMasks()) { + outputStreams.add("SEGMENTATION_MASK:segmentation_masks"); + segmentationMasksOutStreamIndex = outputStreams.size() - 1; + } + + // TODO: Consolidate OutputHandler and TaskRunner. + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public PoseLandmarkerResult convertToTaskResult(List packets) { + // If there is no poses detected in the image, just returns empty lists. + if (packets.get(LANDMARKS_OUT_STREAM_INDEX).isEmpty()) { + return PoseLandmarkerResult.create( + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + Optional.empty(), + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); + } + /** Get segmentation masks */ + Optional> segmentedMasks = Optional.empty(); + if (landmarkerOptions.outputSegmentationMasks()) { + segmentedMasks = getSegmentationMasks(packets); + } + + return PoseLandmarkerResult.create( + PacketGetter.getProtoVector( + packets.get(LANDMARKS_OUT_STREAM_INDEX), NormalizedLandmarkList.parser()), + PacketGetter.getProtoVector( + packets.get(WORLD_LANDMARKS_OUT_STREAM_INDEX), LandmarkList.parser()), + PacketGetter.getProtoVector( + packets.get(AUXILIARY_LANDMARKS_OUT_STREAM_INDEX), + NormalizedLandmarkList.parser()), + segmentedMasks, + BaseVisionTaskApi.generateResultTimestampMs( + landmarkerOptions.runningMode(), packets.get(LANDMARKS_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + landmarkerOptions.resultListener().ifPresent(handler::setResultListener); + landmarkerOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(PoseLandmarker.class.getSimpleName()) + .setTaskRunningModeName(landmarkerOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(outputStreams) + .setTaskOptions(landmarkerOptions) + .setEnableFlowLimiting(landmarkerOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new PoseLandmarker(runner, landmarkerOptions.runningMode()); + } + + /** + * Constructor to initialize a {@link PoseLandmarker} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private PoseLandmarker(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs pose landmarks detection on the provided single image with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * PoseLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java doc + * for input image format. + * + *

+ * {@link PoseLandmarker} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public PoseLandmarkerResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs pose landmarks detection on the provided single image. Only use this method when the + * {@link PoseLandmarker} is created with {@link RunningMode.IMAGE}. TODO update java + * doc for input image format. + * + *

+ * {@link PoseLandmarker} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public PoseLandmarkerResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (PoseLandmarkerResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs pose landmarks detection on the provided video frame with default image processing + * options, i.e. without any rotation applied. Only use this method when the {@link + * PoseLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ * {@link PoseLandmarker} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public PoseLandmarkerResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs pose landmarks detection on the provided video frame. Only use this method when the + * {@link PoseLandmarker} is created with {@link RunningMode.VIDEO}. + * + *

+ * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
+ * must be monotonically increasing.
+ *
+ * {@link PoseLandmarker} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public PoseLandmarkerResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (PoseLandmarkerResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform pose landmarks detection with default image processing + * options, i.e. without any rotation applied, and the results will be available via the {@link + * ResultListener} provided in the {@link PoseLandmarkerOptions}. Only use this method when the + * {@link PoseLandmarker } is created with {@link RunningMode.LIVE_STREAM}. + * + *

+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the pose landmarker. The input timestamps must be monotonically increasing.
+ *
+ * {@link PoseLandmarker} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform pose landmarks detection, and the results will be available + * via the {@link ResultListener} provided in the {@link PoseLandmarkerOptions}. Only use this + * method when the {@link PoseLandmarker} is created with {@link RunningMode.LIVE_STREAM}. + * + *

+ * It's required to provide a timestamp (in milliseconds) to indicate when the input image is
+ * sent to the pose landmarker. The input timestamps must be monotonically increasing.
+ *
+ * {@link PoseLandmarker} supports the following color space types:
+ *
+ *   • {@link Bitmap.Config.ARGB_8888}
    + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up an {@link PoseLandmarker}. */ + @AutoValue + public abstract static class PoseLandmarkerOptions extends TaskOptions { + + /** Builder for {@link PoseLandmarkerOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the pose landmarker task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the running mode for the pose landmarker task. Default to the image mode. Pose + * landmarker has three modes: + * + *
      + *
+ *   • IMAGE: The mode for detecting pose landmarks on single image inputs.
+ *   • VIDEO: The mode for detecting pose landmarks on the decoded frames of a video.
+ *   • LIVE_STREAM: The mode for detecting pose landmarks on a live stream of input
+ *     data, such as from camera. In this mode, {@code setResultListener} must be called to
+ *     set up a listener to receive the detection results asynchronously.
    + */ + public abstract Builder setRunningMode(RunningMode value); + + /** Sets the maximum number of poses can be detected by the PoseLandmarker. */ + public abstract Builder setNumPoses(Integer value); + + /** Sets minimum confidence score for the pose detection to be considered successful */ + public abstract Builder setMinPoseDetectionConfidence(Float value); + + /** Sets minimum confidence score of pose presence score in the pose landmark detection. */ + public abstract Builder setMinPosePresenceConfidence(Float value); + + /** Sets the minimum confidence score for the pose tracking to be considered successful. */ + public abstract Builder setMinTrackingConfidence(Float value); + + public abstract Builder setOutputSegmentationMasks(Boolean value); + + /** + * Sets the result listener to receive the detection results asynchronously when the pose + * landmarker is in the live stream mode. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional error listener. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract PoseLandmarkerOptions autoBuild(); + + /** + * Validates and builds the {@link PoseLandmarkerOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the pose landmarker is + * in the live stream mode. + */ + public final PoseLandmarkerOptions build() { + PoseLandmarkerOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The pose landmarker is in the live stream mode, a user-defined result listener" + + " must be provided in PoseLandmarkerOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The pose landmarker is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in PoseLandmarkerOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract Optional numPoses(); + + abstract Optional minPoseDetectionConfidence(); + + abstract Optional minPosePresenceConfidence(); + + abstract Optional minTrackingConfidence(); + + abstract Boolean outputSegmentationMasks(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_PoseLandmarker_PoseLandmarkerOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setNumPoses(1) + .setMinPoseDetectionConfidence(0.5f) + .setMinPosePresenceConfidence(0.5f) + .setMinTrackingConfidence(0.5f) + .setOutputSegmentationMasks(false); + } + + /** Converts a {@link PoseLandmarkerOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.Builder taskOptionsBuilder = + PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(runningMode() != RunningMode.IMAGE) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); + + // Setup PoseDetectorGraphOptions. 
+ PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions.Builder + poseDetectorGraphOptionsBuilder = + PoseDetectorGraphOptionsProto.PoseDetectorGraphOptions.newBuilder(); + numPoses().ifPresent(poseDetectorGraphOptionsBuilder::setNumPoses); + minPoseDetectionConfidence() + .ifPresent(poseDetectorGraphOptionsBuilder::setMinDetectionConfidence); + + // Setup PoseLandmarkerGraphOptions. + PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions.Builder + poseLandmarksDetectorGraphOptionsBuilder = + PoseLandmarksDetectorGraphOptionsProto.PoseLandmarksDetectorGraphOptions.newBuilder(); + minPosePresenceConfidence() + .ifPresent(poseLandmarksDetectorGraphOptionsBuilder::setMinDetectionConfidence); + minTrackingConfidence().ifPresent(taskOptionsBuilder::setMinTrackingConfidence); + + taskOptionsBuilder + .setPoseDetectorGraphOptions(poseDetectorGraphOptionsBuilder.build()) + .setPoseLandmarksDetectorGraphOptions(poseLandmarksDetectorGraphOptionsBuilder.build()); + + return CalculatorOptions.newBuilder() + .setExtension( + PoseLandmarkerGraphOptionsProto.PoseLandmarkerGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. + */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("PoseLandmarker doesn't support region-of-interest."); + } + } + + private static Optional> getSegmentationMasks(List packets) { + Optional> segmentedMasks = Optional.of(new ArrayList<>()); + int width = + PacketGetter.getImageWidthFromImageList(packets.get(segmentationMasksOutStreamIndex)); + int height = + PacketGetter.getImageHeightFromImageList(packets.get(segmentationMasksOutStreamIndex)); + int imageListSize = PacketGetter.getImageListSize(packets.get(segmentationMasksOutStreamIndex)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + + // Segmentation mask is a float type image. + int numBytes = 4; + for (int i = 0; i < imageListSize; i++) { + buffersArray[i] = ByteBuffer.allocateDirect(width * height * numBytes); + } + + if (!PacketGetter.getImageList( + packets.get(segmentationMasksOutStreamIndex), + buffersArray, + /** deepCopy= */ + true)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting segmented masks."); + } + for (ByteBuffer buffer : buffersArray) { + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, MPImage.IMAGE_FORMAT_VEC32F1); + segmentedMasks.get().add(builder.build()); + } + return segmentedMasks; + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java new file mode 100644 index 000000000..bb632d3b8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerResult.java @@ -0,0 +1,113 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.poselandmarker; + +import com.google.auto.value.AutoValue; +import com.google.mediapipe.formats.proto.LandmarkProto; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.Landmark; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; +import com.google.mediapipe.tasks.core.TaskResult; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** Represents the pose landmarks deection results generated by {@link PoseLandmarker}. */ +@AutoValue +public abstract class PoseLandmarkerResult implements TaskResult { + + /** + * Creates a {@link PoseLandmarkerResult} instance from the lists of landmarks and + * segmentationMask protobuf messages. + * + * @param landmarksProto a List of {@link NormalizedLandmarkList} + * @param worldLandmarksProto a List of {@link LandmarkList} + * @param segmentationMasksData a List of {@link MPImage} + */ + static PoseLandmarkerResult create( + List landmarksProto, + List worldLandmarksProto, + List auxiliaryLandmarksProto, + Optional> segmentationMasksData, + long timestampMs) { + + Optional> multiPoseSegmentationMasks = Optional.empty(); + if (segmentationMasksData.isPresent()) { + multiPoseSegmentationMasks = + Optional.of(Collections.unmodifiableList(segmentationMasksData.get())); + } + + List> multiPoseLandmarks = new ArrayList<>(); + List> multiPoseWorldLandmarks = new ArrayList<>(); + List> multiPoseAuxiliaryLandmarks = new ArrayList<>(); + for (LandmarkProto.NormalizedLandmarkList poseLandmarksProto : landmarksProto) { + List poseLandmarks = new ArrayList<>(); + multiPoseLandmarks.add(poseLandmarks); + for (LandmarkProto.NormalizedLandmark poseLandmarkProto : + poseLandmarksProto.getLandmarkList()) { + poseLandmarks.add( + NormalizedLandmark.create( + poseLandmarkProto.getX(), poseLandmarkProto.getY(), poseLandmarkProto.getZ())); + } + } + for (LandmarkProto.LandmarkList poseWorldLandmarksProto : worldLandmarksProto) { + List poseWorldLandmarks = new ArrayList<>(); + multiPoseWorldLandmarks.add(poseWorldLandmarks); + for (LandmarkProto.Landmark poseWorldLandmarkProto : + poseWorldLandmarksProto.getLandmarkList()) { + poseWorldLandmarks.add( + Landmark.create( + poseWorldLandmarkProto.getX(), + poseWorldLandmarkProto.getY(), + poseWorldLandmarkProto.getZ())); + } + } + for (LandmarkProto.NormalizedLandmarkList poseAuxiliaryLandmarksProto : + auxiliaryLandmarksProto) { + List poseAuxiliaryLandmarks = new ArrayList<>(); + multiPoseAuxiliaryLandmarks.add(poseAuxiliaryLandmarks); + for (LandmarkProto.NormalizedLandmark poseAuxiliaryLandmarkProto : + poseAuxiliaryLandmarksProto.getLandmarkList()) { + poseAuxiliaryLandmarks.add( + NormalizedLandmark.create( + poseAuxiliaryLandmarkProto.getX(), + poseAuxiliaryLandmarkProto.getY(), + poseAuxiliaryLandmarkProto.getZ())); + } + } + return new AutoValue_PoseLandmarkerResult( + timestampMs, + Collections.unmodifiableList(multiPoseLandmarks), + Collections.unmodifiableList(multiPoseWorldLandmarks), + 
Collections.unmodifiableList(multiPoseAuxiliaryLandmarks), + multiPoseSegmentationMasks); + } + + @Override + public abstract long timestampMs(); + + /** Pose landmarks of detected poses. */ + public abstract List> landmarks(); + + /** Pose landmarks in world coordniates of detected poses. */ + public abstract List> worldLandmarks(); + + /** Pose auxiliary landmarks. */ + public abstract List> auxiliaryLandmarks(); + + /** Pose segmentation masks. */ + public abstract Optional> segmentationMasks(); +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java index 5b880f419..4f6cc2d68 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facestylizer/FaceStylizerTest.java @@ -234,8 +234,8 @@ public class FaceStylizerTest { FaceStylizerResult actualResult = faceStylizer.stylize(inputImage); MPImage stylizedImage = actualResult.stylizedImage().get(); assertThat(stylizedImage).isNotNull(); - assertThat(stylizedImage.getWidth()).isEqualTo(83); - assertThat(stylizedImage.getHeight()).isEqualTo(83); + assertThat(stylizedImage.getWidth()).isEqualTo(modelImageSize); + assertThat(stylizedImage.getHeight()).isEqualTo(modelImageSize); } @Test diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java index f32ab7976..3a6854949 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -53,18 +53,15 @@ public class InteractiveSegmenterTest { InteractiveSegmenterOptions options = InteractiveSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK) + .setOutputConfidenceMasks(false) + .setOutputCategoryMask(true) .build(); InteractiveSegmenter imageSegmenter = InteractiveSegmenter.createFromOptions( ApplicationProvider.getApplicationContext(), options); MPImage image = getImageFromAsset(inputImageName); ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); - // TODO update to correct category mask output. - // After InteractiveSegmenter updated according to (b/276519300), update this to use - // categoryMask field instead of confidenceMasks. 
- List segmentations = actualResult.confidenceMasks().get(); - assertThat(segmentations.size()).isEqualTo(1); + assertThat(actualResult.categoryMask().isPresent()).isTrue(); } @Test @@ -75,15 +72,17 @@ public class InteractiveSegmenterTest { InteractiveSegmenterOptions options = InteractiveSegmenterOptions.builder() .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) - .setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK) + .setOutputConfidenceMasks(true) + .setOutputCategoryMask(false) .build(); InteractiveSegmenter imageSegmenter = InteractiveSegmenter.createFromOptions( ApplicationProvider.getApplicationContext(), options); ImageSegmenterResult actualResult = imageSegmenter.segment(getImageFromAsset(inputImageName), roi); - List segmentations = actualResult.confidenceMasks().get(); - assertThat(segmentations.size()).isEqualTo(2); + assertThat(actualResult.confidenceMasks().isPresent()).isTrue(); + List confidenceMasks = actualResult.confidenceMasks().get(); + assertThat(confidenceMasks.size()).isEqualTo(2); } } diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/AndroidManifest.xml new file mode 100644 index 000000000..7c17b77d3 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java new file mode 100644 index 000000000..30ced66f5 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/poselandmarker/PoseLandmarkerTest.java @@ -0,0 +1,365 @@ +// Copyright 2022 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.poselandmarker; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.truth.Correspondence; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.NormalizedLandmark; +import com.google.mediapipe.tasks.components.containers.proto.LandmarksDetectionResultProto.LandmarksDetectionResult; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.poselandmarker.PoseLandmarker.PoseLandmarkerOptions; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link PoseLandmarker}. */ +@RunWith(Suite.class) +@SuiteClasses({PoseLandmarkerTest.General.class, PoseLandmarkerTest.RunningModeTest.class}) +public class PoseLandmarkerTest { + private static final String POSE_LANDMARKER_BUNDLE_ASSET_FILE = "pose_landmarker.task"; + private static final String POSE_IMAGE = "pose.jpg"; + private static final String POSE_LANDMARKS = "pose_landmarks.pb"; + private static final String NO_POSES_IMAGE = "burger.jpg"; + private static final String TAG = "Pose Landmarker Test"; + private static final float LANDMARKS_ERROR_TOLERANCE = 0.03f; + private static final int IMAGE_WIDTH = 1000; + private static final int IMAGE_HEIGHT = 667; + + @RunWith(AndroidJUnit4.class) + public static final class General extends PoseLandmarkerTest { + + @Test + public void detect_successWithValidModels() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void detect_successWithEmptyResult() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .build(); + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(NO_POSES_IMAGE)); + assertThat(actualResult.landmarks()).isEmpty(); + assertThat(actualResult.worldLandmarks()).isEmpty(); + // TODO: Add additional tests for MP Tasks Pose Graphs + // Add tests for segmentation masks. 
+ } + + @Test + public void recognize_failsWithRegionOfInterest() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setNumPoses(1) + .build(); + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("PoseLandmarker doesn't support region-of-interest"); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends PoseLandmarkerTest { + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(mode) + .setResultListener((PoseLandmarkerResults, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE) + .build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void recognize_failsWithCallingWrongApiInImageMode() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + poseLandmarker.detectForVideo( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> poseLandmarker.detectAsync(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void recognize_failsWithCallingWrongApiInVideoMode() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + 
assertThrows( + MediaPipeException.class, () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> poseLandmarker.detectAsync(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void recognize_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((PoseLandmarkerResults, inputImage) -> {}) + .build(); + + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, () -> poseLandmarker.detect(getImageFromAsset(POSE_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + poseLandmarker.detectForVideo( + getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void recognize_successWithImageMode() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + PoseLandmarkerResult actualResult = poseLandmarker.detect(getImageFromAsset(POSE_IMAGE)); + PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + + @Test + public void recognize_successWithVideoMode() throws Exception { + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options); + PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); + for (int i = 0; i < 3; i++) { + PoseLandmarkerResult actualResult = + poseLandmarker.detectForVideo(getImageFromAsset(POSE_IMAGE), /* timestampsMs= */ i); + assertActualResultApproximatelyEqualsToExpectedResult(actualResult, expectedResult); + } + } + + @Test + public void recognize_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(POSE_IMAGE); + PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try 
(PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + poseLandmarker.detectAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> poseLandmarker.detectAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void recognize_successWithLiveSteamMode() throws Exception { + MPImage image = getImageFromAsset(POSE_IMAGE); + PoseLandmarkerResult expectedResult = getExpectedPoseLandmarkerResult(POSE_LANDMARKS); + PoseLandmarkerOptions options = + PoseLandmarkerOptions.builder() + .setBaseOptions( + BaseOptions.builder().setModelAssetPath(POSE_LANDMARKER_BUNDLE_ASSET_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (actualResult, inputImage) -> { + assertActualResultApproximatelyEqualsToExpectedResult( + actualResult, expectedResult); + assertImageSizeIsExpected(inputImage); + }) + .build(); + try (PoseLandmarker poseLandmarker = + PoseLandmarker.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + poseLandmarker.detectAsync(image, /* timestampsMs= */ i); + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static PoseLandmarkerResult getExpectedPoseLandmarkerResult(String filePath) + throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + LandmarksDetectionResult landmarksDetectionResultProto = + LandmarksDetectionResult.parser().parseFrom(istr); + return PoseLandmarkerResult.create( + Arrays.asList(landmarksDetectionResultProto.getLandmarks()), + Arrays.asList(landmarksDetectionResultProto.getWorldLandmarks()), + Arrays.asList(), + Optional.empty(), + /* timestampMs= */ 0); + } + + private static void assertActualResultApproximatelyEqualsToExpectedResult( + PoseLandmarkerResult actualResult, PoseLandmarkerResult expectedResult) { + // TODO: Add additional tests for MP Tasks Pose Graphs + // Add additional tests for auxiliary, world landmarks and segmentation masks. + // Expects to have the same number of poses detected. + assertThat(actualResult.landmarks()).hasSize(expectedResult.landmarks().size()); + + // Actual landmarks match expected landmarks. 
+ assertThat(actualResult.landmarks().get(0)) + .comparingElementsUsing( + Correspondence.from( + (Correspondence.BinaryPredicate) + (actual, expected) -> { + return Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) + .compare(actual.x(), expected.x()) + && Correspondence.tolerance(LANDMARKS_ERROR_TOLERANCE) + .compare(actual.y(), expected.y()); + }, + "landmarks approximately equal to")) + .containsExactlyElementsIn(expectedResult.landmarks().get(0)); + } + + private static void assertImageSizeIsExpected(MPImage inputImage) { + assertThat(inputImage).isNotNull(); + assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH); + assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT); + } +} diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index aa48a1a9a..250c0fa62 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -204,6 +204,11 @@ This can be useful for resetting a stateful task graph to process new data. Raises: RuntimeError: The underlying medipaipe graph fails to reset and restart. )doc"); + + task_runner.def( + "get_graph_config", + [](TaskRunner* self) { return self->GetGraphConfig(); }, + R"doc(Returns the canonicalized CalculatorGraphConfig of the underlying graph.)doc"); } } // namespace python diff --git a/mediapipe/tasks/python/test/text/text_embedder_test.py b/mediapipe/tasks/python/test/text/text_embedder_test.py index 78e98a1b4..62d162f6e 100644 --- a/mediapipe/tasks/python/test/text/text_embedder_test.py +++ b/mediapipe/tasks/python/test/text/text_embedder_test.py @@ -32,6 +32,7 @@ _TextEmbedderOptions = text_embedder.TextEmbedderOptions _BERT_MODEL_FILE = 'mobilebert_embedding_with_metadata.tflite' _REGEX_MODEL_FILE = 'regex_one_embedding_with_metadata.tflite' +_USE_MODEL_FILE = 'universal_sentence_encoder_qa_with_metadata.tflite' _TEST_DATA_DIR = 'mediapipe/tasks/testdata/text' # Tolerance for embedding vector coordinate values. _EPSILON = 1e-4 @@ -138,6 +139,24 @@ class TextEmbedderTest(parameterized.TestCase): 16, (0.549632, 0.552879), ), + ( + False, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.851961, + 100, + (1.422951, 1.404664), + ), + ( + True, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.851961, + 100, + (0.127049, 0.125416), + ), ) def test_embed(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, expected_first_values): @@ -213,6 +232,24 @@ class TextEmbedderTest(parameterized.TestCase): 16, (0.549632, 0.552879), ), + ( + False, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_NAME, + 0.851961, + 100, + (1.422951, 1.404664), + ), + ( + True, + False, + _USE_MODEL_FILE, + ModelFileType.FILE_CONTENT, + 0.851961, + 100, + (0.127049, 0.125416), + ), ) def test_embed_in_context(self, l2_normalize, quantize, model_name, model_file_type, expected_similarity, expected_size, @@ -251,6 +288,7 @@ class TextEmbedderTest(parameterized.TestCase): @parameterized.parameters( # TODO: The similarity should likely be lower (_BERT_MODEL_FILE, 0.980880), + (_USE_MODEL_FILE, 0.780334), ) def test_embed_with_different_themes(self, model_file, expected_similarity): # Creates embedder. 
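The get_graph_config binding added to TaskRunner above, together with the BaseVisionTaskApi.get_graph_config wrapper later in this change, exposes the canonicalized CalculatorGraphConfig of a task's underlying graph to Python. A minimal sketch of how it could be called, assuming the public mediapipe.tasks.python import surface and a placeholder model path:

# Sketch only: exercises the new get_graph_config() accessor; the model path is a placeholder.
from mediapipe.tasks.python import BaseOptions
from mediapipe.tasks.python import vision

options = vision.ImageSegmenterOptions(
    base_options=BaseOptions(model_asset_path='deeplabv3.tflite'))  # placeholder asset
with vision.ImageSegmenter.create_from_options(options) as segmenter:
  # Inherited from BaseVisionTaskApi; returns the resolved CalculatorGraphConfig proto.
  graph_config = segmenter.get_graph_config()
  print(graph_config)

Any task built on TaskRunner gets the same accessor; it is mainly useful for debugging which calculators and streams the task graph actually resolved to.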
diff --git a/mediapipe/tasks/python/test/vision/image_segmenter_test.py b/mediapipe/tasks/python/test/vision/image_segmenter_test.py index 7f0b47eb7..1a534c98d 100644 --- a/mediapipe/tasks/python/test/vision/image_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/image_segmenter_test.py @@ -15,7 +15,6 @@ import enum import os -from typing import List from unittest import mock from absl.testing import absltest @@ -30,11 +29,10 @@ from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import image_segmenter from mediapipe.tasks.python.vision.core import vision_task_running_mode +ImageSegmenterResult = image_segmenter.ImageSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat -_OutputType = image_segmenter.ImageSegmenterOptions.OutputType -_Activation = image_segmenter.ImageSegmenterOptions.Activation _ImageSegmenter = image_segmenter.ImageSegmenter _ImageSegmenterOptions = image_segmenter.ImageSegmenterOptions _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode @@ -42,9 +40,54 @@ _RUNNING_MODE = vision_task_running_mode.VisionTaskRunningMode _MODEL_FILE = 'deeplabv3.tflite' _IMAGE_FILE = 'segmentation_input_rotation0.jpg' _SEGMENTATION_FILE = 'segmentation_golden_rotation0.png' +_CAT_IMAGE = 'cat.jpg' +_CAT_MASK = 'cat_mask.jpg' _MASK_MAGNIFICATION_FACTOR = 10 _MASK_SIMILARITY_THRESHOLD = 0.98 _TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' +_EXPECTED_LABELS = [ + 'background', + 'aeroplane', + 'bicycle', + 'bird', + 'boat', + 'bottle', + 'bus', + 'car', + 'cat', + 'chair', + 'cow', + 'dining table', + 'dog', + 'horse', + 'motorbike', + 'person', + 'potted plant', + 'sheep', + 'sofa', + 'train', + 'tv', +] + + +def _calculate_soft_iou(m1, m2): + intersection_sum = np.sum(m1 * m2) + union_sum = np.sum(m1 * m1) + np.sum(m2 * m2) - intersection_sum + + if union_sum > 0: + return intersection_sum / union_sum + else: + return 0 + + +def _similar_to_float_mask(actual_mask, expected_mask, similarity_threshold): + actual_mask = actual_mask.numpy_view() + expected_mask = expected_mask.numpy_view() / 255.0 + + return ( + actual_mask.shape == expected_mask.shape + and _calculate_soft_iou(actual_mask, expected_mask) > similarity_threshold + ) def _similar_to_uint8_mask(actual_mask, expected_mask): @@ -56,8 +99,9 @@ def _similar_to_uint8_mask(actual_mask, expected_mask): for index in range(num_pixels): consistent_pixels += ( - actual_mask_pixels[index] * - _MASK_MAGNIFICATION_FACTOR == expected_mask_pixels[index]) + actual_mask_pixels[index] * _MASK_MAGNIFICATION_FACTOR + == expected_mask_pixels[index] + ) return consistent_pixels / num_pixels >= _MASK_SIMILARITY_THRESHOLD @@ -73,16 +117,27 @@ class ImageSegmenterTest(parameterized.TestCase): super().setUp() # Load the test input image. self.test_image = _Image.create_from_file( - test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _IMAGE_FILE))) + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _IMAGE_FILE)) + ) # Loads ground truth segmentation file. 
gt_segmentation_data = cv2.imread( test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE)), - cv2.IMREAD_GRAYSCALE) + os.path.join(_TEST_DATA_DIR, _SEGMENTATION_FILE) + ), + cv2.IMREAD_GRAYSCALE, + ) self.test_seg_image = _Image(_ImageFormat.GRAY8, gt_segmentation_data) self.model_path = test_utils.get_test_data_path( - os.path.join(_TEST_DATA_DIR, _MODEL_FILE)) + os.path.join(_TEST_DATA_DIR, _MODEL_FILE) + ) + + def _load_segmentation_mask(self, file_path: str): + # Loads ground truth segmentation file. + gt_segmentation_data = cv2.imread( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, file_path)), + cv2.IMREAD_GRAYSCALE, + ) + return _Image(_ImageFormat.GRAY8, gt_segmentation_data) def test_create_from_file_succeeds_with_valid_model_path(self): # Creates with default option and valid model file successfully. @@ -98,9 +153,11 @@ class ImageSegmenterTest(parameterized.TestCase): def test_create_from_options_fails_with_invalid_model_path(self): with self.assertRaisesRegex( - RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite'): + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): base_options = _BaseOptions( - model_asset_path='/path/to/invalid/model.tflite') + model_asset_path='/path/to/invalid/model.tflite' + ) options = _ImageSegmenterOptions(base_options=base_options) _ImageSegmenter.create_from_options(options) @@ -112,8 +169,9 @@ class ImageSegmenterTest(parameterized.TestCase): segmenter = _ImageSegmenter.create_from_options(options) self.assertIsInstance(segmenter, _ImageSegmenter) - @parameterized.parameters((ModelFileType.FILE_NAME,), - (ModelFileType.FILE_CONTENT,)) + @parameterized.parameters( + (ModelFileType.FILE_NAME,), (ModelFileType.FILE_CONTENT,) + ) def test_segment_succeeds_with_category_mask(self, model_file_type): # Creates segmenter. if model_file_type is ModelFileType.FILE_NAME: @@ -127,22 +185,27 @@ class ImageSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) + base_options=base_options, + output_category_mask=True, + output_confidence_masks=False, + ) segmenter = _ImageSegmenter.create_from_options(options) # Performs image segmentation on the input. - category_masks = segmenter.segment(self.test_image) - self.assertLen(category_masks, 1) - category_mask = category_masks[0] + segmentation_result = segmenter.segment(self.test_image) + category_mask = segmentation_result.category_mask result_pixels = category_mask.numpy_view().flatten() # Check if data type of `category_mask` is correct. self.assertEqual(result_pixels.dtype, np.uint8) self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_uint8_mask(category_mask, self.test_seg_image), + ( + 'Number of pixels in the candidate mask differing from that of the' + f' ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.' + ), + ) # Closes the segmenter explicitly when the segmenter is not used in # a context. @@ -152,74 +215,60 @@ class ImageSegmenterTest(parameterized.TestCase): # Creates segmenter. base_options = _BaseOptions(model_asset_path=self.model_path) - # Run segmentation on the model in CATEGORY_MASK mode. 
- options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) - segmenter = _ImageSegmenter.create_from_options(options) - category_masks = segmenter.segment(self.test_image) - category_mask = category_masks[0].numpy_view() + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) # Run segmentation on the model in CONFIDENCE_MASK mode. options = _ImageSegmenterOptions( base_options=base_options, - output_type=_OutputType.CONFIDENCE_MASK, - activation=_Activation.SOFTMAX) - segmenter = _ImageSegmenter.create_from_options(options) - confidence_masks = segmenter.segment(self.test_image) + output_category_mask=False, + output_confidence_masks=True, + ) - # Check if confidence mask shape is correct. - self.assertLen( - confidence_masks, 21, - 'Number of confidence masks must match with number of categories.') - - # Gather the confidence masks in a single array `confidence_mask_array`. - confidence_mask_array = np.array( - [confidence_mask.numpy_view() for confidence_mask in confidence_masks]) - - # Check if data type of `confidence_masks` are correct. - self.assertEqual(confidence_mask_array.dtype, np.float32) - - # Compute the category mask from the created confidence mask. - calculated_category_mask = np.argmax(confidence_mask_array, axis=0) - self.assertListEqual( - calculated_category_mask.tolist(), category_mask.tolist(), - 'Confidence mask does not match with the category mask.') - - # Closes the segmenter explicitly when the segmenter is not used in - # a context. - segmenter.close() - - @parameterized.parameters((ModelFileType.FILE_NAME), - (ModelFileType.FILE_CONTENT)) - def test_segment_in_context(self, model_file_type): - if model_file_type is ModelFileType.FILE_NAME: - base_options = _BaseOptions(model_asset_path=self.model_path) - elif model_file_type is ModelFileType.FILE_CONTENT: - with open(self.model_path, 'rb') as f: - model_contents = f.read() - base_options = _BaseOptions(model_asset_buffer=model_contents) - else: - # Should never happen - raise ValueError('model_file_type is invalid.') - - options = _ImageSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK) with _ImageSegmenter.create_from_options(options) as segmenter: - # Performs image segmentation on the input. - category_masks = segmenter.segment(self.test_image) - self.assertLen(category_masks, 1) + segmentation_result = segmenter.segment(test_image) + confidence_masks = segmentation_result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, + 21, + 'Number of confidence masks must match with number of categories.', + ) + + # Loads ground truth segmentation file. 
+ expected_mask = self._load_segmentation_mask(_CAT_MASK) self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) + + @parameterized.parameters((True, False), (False, True)) + def test_labels_succeeds(self, output_category_mask, output_confidence_masks): + expected_labels = _EXPECTED_LABELS + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _ImageSegmenterOptions( + base_options=base_options, + output_category_mask=output_category_mask, + output_confidence_masks=output_confidence_masks, + ) + with _ImageSegmenter.create_from_options(options) as segmenter: + # Performs image segmentation on the input. + actual_labels = segmenter.labels + self.assertListEqual(actual_labels, expected_labels) def test_missing_result_callback(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.LIVE_STREAM) - with self.assertRaisesRegex(ValueError, - r'result callback must be provided'): + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): with _ImageSegmenter.create_from_options(options) as unused_segmenter: pass @@ -228,130 +277,236 @@ class ImageSegmenterTest(parameterized.TestCase): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=running_mode, - result_callback=mock.MagicMock()) - with self.assertRaisesRegex(ValueError, - r'result callback should not be provided'): + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): with _ImageSegmenter.create_from_options(options) as unused_segmenter: pass def test_calling_segment_for_video_in_image_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): segmenter.segment_for_video(self.test_image, 0) def test_calling_segment_async_in_image_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.IMAGE) + running_mode=_RUNNING_MODE.IMAGE, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): segmenter.segment_async(self.test_image, 0) def test_calling_segment_in_video_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): segmenter.segment(self.test_image) def test_calling_segment_async_in_video_mode(self): 
options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the live stream mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): segmenter.segment_async(self.test_image, 0) def test_segment_for_video_with_out_of_order_timestamp(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - running_mode=_RUNNING_MODE.VIDEO) + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: unused_result = segmenter.segment_for_video(self.test_image, 1) with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): segmenter.segment_for_video(self.test_image, 0) - def test_segment_for_video(self): + def test_segment_for_video_in_category_mask_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - output_type=_OutputType.CATEGORY_MASK, - running_mode=_RUNNING_MODE.VIDEO) + output_category_mask=True, + output_confidence_masks=False, + running_mode=_RUNNING_MODE.VIDEO, + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): - category_masks = segmenter.segment_for_video(self.test_image, timestamp) - self.assertLen(category_masks, 1) + segmentation_result = segmenter.segment_for_video( + self.test_image, timestamp + ) + category_mask = segmentation_result.category_mask self.assertTrue( - _similar_to_uint8_mask(category_masks[0], self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + _similar_to_uint8_mask(category_mask, self.test_seg_image), + ( + 'Number of pixels in the candidate mask differing from that of' + f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.' + ), + ) + + def test_segment_for_video_in_confidence_mask_mode(self): + # Load the cat image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) + + options = _ImageSegmenterOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + output_category_mask=False, + output_confidence_masks=True, + ) + with _ImageSegmenter.create_from_options(options) as segmenter: + for timestamp in range(0, 300, 30): + segmentation_result = segmenter.segment_for_video(test_image, timestamp) + confidence_masks = segmentation_result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, + 21, + 'Number of confidence masks must match with number of categories.', + ) + + # Loads ground truth segmentation file. 
+ expected_mask = self._load_segmentation_mask(_CAT_MASK) + self.assertTrue( + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) def test_calling_segment_in_live_stream_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the image mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): segmenter.segment(self.test_image) def test_calling_segment_for_video_in_live_stream_mode(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageSegmenter.create_from_options(options) as segmenter: - with self.assertRaisesRegex(ValueError, - r'not initialized with the video mode'): + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): segmenter.segment_for_video(self.test_image, 0) def test_segment_async_calls_with_illegal_timestamp(self): options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=mock.MagicMock()) + result_callback=mock.MagicMock(), + ) with _ImageSegmenter.create_from_options(options) as segmenter: segmenter.segment_async(self.test_image, 100) with self.assertRaisesRegex( - ValueError, r'Input timestamp must be monotonically increasing'): + ValueError, r'Input timestamp must be monotonically increasing' + ): segmenter.segment_async(self.test_image, 0) - def test_segment_async_calls(self): + def test_segment_async_calls_in_category_mask_mode(self): observed_timestamp_ms = -1 - def check_result(result: List[image_module.Image], output_image: _Image, - timestamp_ms: int): + def check_result( + result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int + ): # Get the output category mask. - category_mask = result[0] + category_mask = result.category_mask self.assertEqual(output_image.width, self.test_image.width) self.assertEqual(output_image.height, self.test_image.height) self.assertEqual(output_image.width, self.test_seg_image.width) self.assertEqual(output_image.height, self.test_seg_image.height) self.assertTrue( _similar_to_uint8_mask(category_mask, self.test_seg_image), - f'Number of pixels in the candidate mask differing from that of the ' - f'ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.') + ( + 'Number of pixels in the candidate mask differing from that of' + f' the ground truth mask exceeds {_MASK_SIMILARITY_THRESHOLD}.' + ), + ) self.assertLess(observed_timestamp_ms, timestamp_ms) self.observed_timestamp_ms = timestamp_ms options = _ImageSegmenterOptions( base_options=_BaseOptions(model_asset_path=self.model_path), - output_type=_OutputType.CATEGORY_MASK, + output_category_mask=True, + output_confidence_masks=False, running_mode=_RUNNING_MODE.LIVE_STREAM, - result_callback=check_result) + result_callback=check_result, + ) with _ImageSegmenter.create_from_options(options) as segmenter: for timestamp in range(0, 300, 30): segmenter.segment_async(self.test_image, timestamp) + def test_segment_async_calls_in_confidence_mask_mode(self): + # Load the cat image. 
+ test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) + + # Loads ground truth segmentation file. + expected_mask = self._load_segmentation_mask(_CAT_MASK) + observed_timestamp_ms = -1 + + def check_result( + result: ImageSegmenterResult, output_image: _Image, timestamp_ms: int + ): + # Get the output category mask. + confidence_masks = result.confidence_masks + + # Check if confidence mask shape is correct. + self.assertLen( + confidence_masks, + 21, + 'Number of confidence masks must match with number of categories.', + ) + self.assertEqual(output_image.width, test_image.width) + self.assertEqual(output_image.height, test_image.height) + self.assertTrue( + _similar_to_float_mask( + confidence_masks[8], expected_mask, _MASK_SIMILARITY_THRESHOLD + ) + ) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + options = _ImageSegmenterOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + output_category_mask=False, + output_confidence_masks=True, + result_callback=check_result, + ) + with _ImageSegmenter.create_from_options(options) as segmenter: + for timestamp in range(0, 300, 30): + segmenter.segment_async(test_image, timestamp) + if __name__ == '__main__': absltest.main() diff --git a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py index e8c52ae3e..2e0039b15 100644 --- a/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py +++ b/mediapipe/tasks/python/test/vision/interactive_segmenter_test.py @@ -30,12 +30,12 @@ from mediapipe.tasks.python.test import test_utils from mediapipe.tasks.python.vision import interactive_segmenter from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +InteractiveSegmenterResult = interactive_segmenter.InteractiveSegmenterResult _BaseOptions = base_options_module.BaseOptions _Image = image_module.Image _ImageFormat = image_frame.ImageFormat _NormalizedKeypoint = keypoint_module.NormalizedKeypoint _Rect = rect.Rect -_OutputType = interactive_segmenter.InteractiveSegmenterOptions.OutputType _InteractiveSegmenter = interactive_segmenter.InteractiveSegmenter _InteractiveSegmenterOptions = interactive_segmenter.InteractiveSegmenterOptions _RegionOfInterest = interactive_segmenter.RegionOfInterest @@ -200,15 +200,16 @@ class InteractiveSegmenterTest(parameterized.TestCase): raise ValueError('model_file_type is invalid.') options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CATEGORY_MASK + base_options=base_options, + output_category_mask=True, + output_confidence_masks=False, ) segmenter = _InteractiveSegmenter.create_from_options(options) # Performs image segmentation on the input. roi = _RegionOfInterest(format=roi_format, keypoint=keypoint) - category_masks = segmenter.segment(self.test_image, roi) - self.assertLen(category_masks, 1) - category_mask = category_masks[0] + segmentation_result = segmenter.segment(self.test_image, roi) + category_mask = segmentation_result.category_mask result_pixels = category_mask.numpy_view().flatten() # Check if data type of `category_mask` is correct. 
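The image segmenter and interactive segmenter test changes above and below exercise the new mask-selection API: the OutputType enum is replaced by independent output_category_mask / output_confidence_masks booleans, segment() now returns a result object with category_mask and confidence_masks fields, and the segmenter exposes a labels property. A minimal Python sketch of the new shape, assuming the public API surface and placeholder file paths:

# Sketch only: new option/result shape; model and image paths are placeholders.
import mediapipe as mp
from mediapipe.tasks.python import BaseOptions
from mediapipe.tasks.python import vision

options = vision.ImageSegmenterOptions(
    base_options=BaseOptions(model_asset_path='deeplabv3.tflite'),  # placeholder
    output_category_mask=True,      # uint8 mask of label indices per pixel
    output_confidence_masks=True)   # one float32 mask per category
with vision.ImageSegmenter.create_from_options(options) as segmenter:
  image = mp.Image.create_from_file('input.jpg')  # placeholder image
  result = segmenter.segment(image)               # ImageSegmenterResult
  category_mask = result.category_mask
  confidence_masks = result.confidence_masks      # length matches len(segmenter.labels)
  labels = segmenter.labels                       # e.g. 'background', ..., 'cat', ...

The same pair of booleans replaces OutputType in InteractiveSegmenterOptions (in both the Python and Java changes here), so callers opt into exactly the masks they need instead of picking a single output mode.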
@@ -219,7 +220,7 @@ class InteractiveSegmenterTest(parameterized.TestCase): self.assertTrue( _similar_to_uint8_mask( - category_masks[0], test_seg_image, similarity_threshold + category_mask, test_seg_image, similarity_threshold ), ( 'Number of pixels in the candidate mask differing from that of the' @@ -254,12 +255,15 @@ class InteractiveSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK + base_options=base_options, + output_category_mask=False, + output_confidence_masks=True, ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation - confidence_masks = segmenter.segment(self.test_image, roi) + segmentation_result = segmenter.segment(self.test_image, roi) + confidence_masks = segmentation_result.confidence_masks # Check if confidence mask shape is correct. self.assertLen( @@ -287,15 +291,18 @@ class InteractiveSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK + base_options=base_options, + output_category_mask=False, + output_confidence_masks=True, ) with _InteractiveSegmenter.create_from_options(options) as segmenter: # Perform segmentation image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) - confidence_masks = segmenter.segment( + segmentation_result = segmenter.segment( self.test_image, roi, image_processing_options ) + confidence_masks = segmentation_result.confidence_masks # Check if confidence mask shape is correct. self.assertLen( @@ -314,7 +321,9 @@ class InteractiveSegmenterTest(parameterized.TestCase): # Run segmentation on the model in CONFIDENCE_MASK mode. 
options = _InteractiveSegmenterOptions( - base_options=base_options, output_type=_OutputType.CONFIDENCE_MASK + base_options=base_options, + output_category_mask=False, + output_confidence_masks=True, ) with self.assertRaisesRegex( diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 046ce2dc8..716757790 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -71,6 +71,7 @@ py_library( "//mediapipe/python:_framework_bindings", "//mediapipe/python:packet_creator", "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_py_pb2", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_py_pb2", "//mediapipe/tasks/python/components/containers:rect", diff --git a/mediapipe/tasks/python/vision/__init__.py b/mediapipe/tasks/python/vision/__init__.py index 49fe03059..53cbf026e 100644 --- a/mediapipe/tasks/python/vision/__init__.py +++ b/mediapipe/tasks/python/vision/__init__.py @@ -32,6 +32,7 @@ FaceDetectorResult = face_detector.FaceDetectorResult FaceLandmarker = face_landmarker.FaceLandmarker FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult +FaceLandmarksConnections = face_landmarker.FaceLandmarksConnections FaceStylizer = face_stylizer.FaceStylizer FaceStylizerOptions = face_stylizer.FaceStylizerOptions GestureRecognizer = gesture_recognizer.GestureRecognizer diff --git a/mediapipe/tasks/python/vision/core/base_vision_task_api.py b/mediapipe/tasks/python/vision/core/base_vision_task_api.py index eb976153e..d9195d3ce 100644 --- a/mediapipe/tasks/python/vision/core/base_vision_task_api.py +++ b/mediapipe/tasks/python/vision/core/base_vision_task_api.py @@ -208,6 +208,11 @@ class BaseVisionTaskApi(object): """ self._runner.close() + def get_graph_config(self) -> calculator_pb2.CalculatorGraphConfig: + """Returns the canonicalized CalculatorGraphConfig of the underlying graph. 
+ """ + return self._runner.get_graph_config() + def __enter__(self): """Return `self` upon entering the runtime context.""" return self diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py index 3e43e8a7f..c5b24499f 100644 --- a/mediapipe/tasks/python/vision/face_landmarker.py +++ b/mediapipe/tasks/python/vision/face_landmarker.py @@ -120,6 +120,2741 @@ class Blendshapes(enum.IntEnum): NOSE_SNEER_RIGHT = 51 +class FaceLandmarksConnections: + """The connections between face landmarks.""" + + @dataclasses.dataclass + class Connection: + """The connection class for face landmarks.""" + + start: int + end: int + + FACE_LANDMARKS_LIPS: List[Connection] = [ + Connection(61, 146), + Connection(146, 91), + Connection(91, 181), + Connection(181, 84), + Connection(84, 17), + Connection(17, 314), + Connection(314, 405), + Connection(405, 321), + Connection(321, 375), + Connection(375, 291), + Connection(61, 185), + Connection(185, 40), + Connection(40, 39), + Connection(39, 37), + Connection(37, 0), + Connection(0, 267), + Connection(267, 269), + Connection(269, 270), + Connection(270, 409), + Connection(409, 291), + Connection(78, 95), + Connection(95, 88), + Connection(88, 178), + Connection(178, 87), + Connection(87, 14), + Connection(14, 317), + Connection(317, 402), + Connection(402, 318), + Connection(318, 324), + Connection(324, 308), + Connection(78, 191), + Connection(191, 80), + Connection(80, 81), + Connection(81, 82), + Connection(82, 13), + Connection(13, 312), + Connection(312, 311), + Connection(311, 310), + Connection(310, 415), + Connection(415, 308), + ] + + FACE_LANDMARKS_LEFT_EYE: List[Connection] = [ + Connection(263, 249), + Connection(249, 390), + Connection(390, 373), + Connection(373, 374), + Connection(374, 380), + Connection(380, 381), + Connection(381, 382), + Connection(382, 362), + Connection(263, 466), + Connection(466, 388), + Connection(388, 387), + Connection(387, 386), + Connection(386, 385), + Connection(385, 384), + Connection(384, 398), + Connection(398, 362), + ] + + FACE_LANDMARKS_LEFT_EYEBROW: List[Connection] = [ + Connection(276, 283), + Connection(283, 282), + Connection(282, 295), + Connection(295, 285), + Connection(300, 293), + Connection(293, 334), + Connection(334, 296), + Connection(296, 336), + ] + + FACE_LANDMARKS_LEFT_IRIS: List[Connection] = [ + Connection(474, 475), + Connection(475, 476), + Connection(476, 477), + Connection(477, 474), + ] + + FACE_LANDMARKS_RIGHT_EYE: List[Connection] = [ + Connection(33, 7), + Connection(7, 163), + Connection(163, 144), + Connection(144, 145), + Connection(145, 153), + Connection(153, 154), + Connection(154, 155), + Connection(155, 133), + Connection(33, 246), + Connection(246, 161), + Connection(161, 160), + Connection(160, 159), + Connection(159, 158), + Connection(158, 157), + Connection(157, 173), + Connection(173, 133), + ] + + FACE_LANDMARKS_RIGHT_EYEBROW: List[Connection] = [ + Connection(46, 53), + Connection(53, 52), + Connection(52, 65), + Connection(65, 55), + Connection(70, 63), + Connection(63, 105), + Connection(105, 66), + Connection(66, 107), + ] + + FACE_LANDMARKS_RIGHT_IRIS: List[Connection] = [ + Connection(469, 470), + Connection(470, 471), + Connection(471, 472), + Connection(472, 469), + ] + + FACE_LANDMARKS_FACE_OVAL: List[Connection] = [ + Connection(10, 338), + Connection(338, 297), + Connection(297, 332), + Connection(332, 284), + Connection(284, 251), + Connection(251, 389), + Connection(389, 356), + Connection(356, 454), + 
Connection(454, 323), + Connection(323, 361), + Connection(361, 288), + Connection(288, 397), + Connection(397, 365), + Connection(365, 379), + Connection(379, 378), + Connection(378, 400), + Connection(400, 377), + Connection(377, 152), + Connection(152, 148), + Connection(148, 176), + Connection(176, 149), + Connection(149, 150), + Connection(150, 136), + Connection(136, 172), + Connection(172, 58), + Connection(58, 132), + Connection(132, 93), + Connection(93, 234), + Connection(234, 127), + Connection(127, 162), + Connection(162, 21), + Connection(21, 54), + Connection(54, 103), + Connection(103, 67), + Connection(67, 109), + Connection(109, 10), + ] + + FACE_LANDMARKS_CONTOURS: List[Connection] = ( + FACE_LANDMARKS_LIPS + + FACE_LANDMARKS_LEFT_EYE + + FACE_LANDMARKS_LEFT_EYEBROW + + FACE_LANDMARKS_RIGHT_EYE + + FACE_LANDMARKS_RIGHT_EYEBROW + + FACE_LANDMARKS_FACE_OVAL + ) + + FACE_LANDMARKS_TESSELATION: List[Connection] = [ + Connection(127, 34), + Connection(34, 139), + Connection(139, 127), + Connection(11, 0), + Connection(0, 37), + Connection(37, 11), + Connection(232, 231), + Connection(231, 120), + Connection(120, 232), + Connection(72, 37), + Connection(37, 39), + Connection(39, 72), + Connection(128, 121), + Connection(121, 47), + Connection(47, 128), + Connection(232, 121), + Connection(121, 128), + Connection(128, 232), + Connection(104, 69), + Connection(69, 67), + Connection(67, 104), + Connection(175, 171), + Connection(171, 148), + Connection(148, 175), + Connection(118, 50), + Connection(50, 101), + Connection(101, 118), + Connection(73, 39), + Connection(39, 40), + Connection(40, 73), + Connection(9, 151), + Connection(151, 108), + Connection(108, 9), + Connection(48, 115), + Connection(115, 131), + Connection(131, 48), + Connection(194, 204), + Connection(204, 211), + Connection(211, 194), + Connection(74, 40), + Connection(40, 185), + Connection(185, 74), + Connection(80, 42), + Connection(42, 183), + Connection(183, 80), + Connection(40, 92), + Connection(92, 186), + Connection(186, 40), + Connection(230, 229), + Connection(229, 118), + Connection(118, 230), + Connection(202, 212), + Connection(212, 214), + Connection(214, 202), + Connection(83, 18), + Connection(18, 17), + Connection(17, 83), + Connection(76, 61), + Connection(61, 146), + Connection(146, 76), + Connection(160, 29), + Connection(29, 30), + Connection(30, 160), + Connection(56, 157), + Connection(157, 173), + Connection(173, 56), + Connection(106, 204), + Connection(204, 194), + Connection(194, 106), + Connection(135, 214), + Connection(214, 192), + Connection(192, 135), + Connection(203, 165), + Connection(165, 98), + Connection(98, 203), + Connection(21, 71), + Connection(71, 68), + Connection(68, 21), + Connection(51, 45), + Connection(45, 4), + Connection(4, 51), + Connection(144, 24), + Connection(24, 23), + Connection(23, 144), + Connection(77, 146), + Connection(146, 91), + Connection(91, 77), + Connection(205, 50), + Connection(50, 187), + Connection(187, 205), + Connection(201, 200), + Connection(200, 18), + Connection(18, 201), + Connection(91, 106), + Connection(106, 182), + Connection(182, 91), + Connection(90, 91), + Connection(91, 181), + Connection(181, 90), + Connection(85, 84), + Connection(84, 17), + Connection(17, 85), + Connection(206, 203), + Connection(203, 36), + Connection(36, 206), + Connection(148, 171), + Connection(171, 140), + Connection(140, 148), + Connection(92, 40), + Connection(40, 39), + Connection(39, 92), + Connection(193, 189), + Connection(189, 244), + 
Connection(244, 193), + Connection(159, 158), + Connection(158, 28), + Connection(28, 159), + Connection(247, 246), + Connection(246, 161), + Connection(161, 247), + Connection(236, 3), + Connection(3, 196), + Connection(196, 236), + Connection(54, 68), + Connection(68, 104), + Connection(104, 54), + Connection(193, 168), + Connection(168, 8), + Connection(8, 193), + Connection(117, 228), + Connection(228, 31), + Connection(31, 117), + Connection(189, 193), + Connection(193, 55), + Connection(55, 189), + Connection(98, 97), + Connection(97, 99), + Connection(99, 98), + Connection(126, 47), + Connection(47, 100), + Connection(100, 126), + Connection(166, 79), + Connection(79, 218), + Connection(218, 166), + Connection(155, 154), + Connection(154, 26), + Connection(26, 155), + Connection(209, 49), + Connection(49, 131), + Connection(131, 209), + Connection(135, 136), + Connection(136, 150), + Connection(150, 135), + Connection(47, 126), + Connection(126, 217), + Connection(217, 47), + Connection(223, 52), + Connection(52, 53), + Connection(53, 223), + Connection(45, 51), + Connection(51, 134), + Connection(134, 45), + Connection(211, 170), + Connection(170, 140), + Connection(140, 211), + Connection(67, 69), + Connection(69, 108), + Connection(108, 67), + Connection(43, 106), + Connection(106, 91), + Connection(91, 43), + Connection(230, 119), + Connection(119, 120), + Connection(120, 230), + Connection(226, 130), + Connection(130, 247), + Connection(247, 226), + Connection(63, 53), + Connection(53, 52), + Connection(52, 63), + Connection(238, 20), + Connection(20, 242), + Connection(242, 238), + Connection(46, 70), + Connection(70, 156), + Connection(156, 46), + Connection(78, 62), + Connection(62, 96), + Connection(96, 78), + Connection(46, 53), + Connection(53, 63), + Connection(63, 46), + Connection(143, 34), + Connection(34, 227), + Connection(227, 143), + Connection(123, 117), + Connection(117, 111), + Connection(111, 123), + Connection(44, 125), + Connection(125, 19), + Connection(19, 44), + Connection(236, 134), + Connection(134, 51), + Connection(51, 236), + Connection(216, 206), + Connection(206, 205), + Connection(205, 216), + Connection(154, 153), + Connection(153, 22), + Connection(22, 154), + Connection(39, 37), + Connection(37, 167), + Connection(167, 39), + Connection(200, 201), + Connection(201, 208), + Connection(208, 200), + Connection(36, 142), + Connection(142, 100), + Connection(100, 36), + Connection(57, 212), + Connection(212, 202), + Connection(202, 57), + Connection(20, 60), + Connection(60, 99), + Connection(99, 20), + Connection(28, 158), + Connection(158, 157), + Connection(157, 28), + Connection(35, 226), + Connection(226, 113), + Connection(113, 35), + Connection(160, 159), + Connection(159, 27), + Connection(27, 160), + Connection(204, 202), + Connection(202, 210), + Connection(210, 204), + Connection(113, 225), + Connection(225, 46), + Connection(46, 113), + Connection(43, 202), + Connection(202, 204), + Connection(204, 43), + Connection(62, 76), + Connection(76, 77), + Connection(77, 62), + Connection(137, 123), + Connection(123, 116), + Connection(116, 137), + Connection(41, 38), + Connection(38, 72), + Connection(72, 41), + Connection(203, 129), + Connection(129, 142), + Connection(142, 203), + Connection(64, 98), + Connection(98, 240), + Connection(240, 64), + Connection(49, 102), + Connection(102, 64), + Connection(64, 49), + Connection(41, 73), + Connection(73, 74), + Connection(74, 41), + Connection(212, 216), + Connection(216, 207), + Connection(207, 
212), + Connection(42, 74), + Connection(74, 184), + Connection(184, 42), + Connection(169, 170), + Connection(170, 211), + Connection(211, 169), + Connection(170, 149), + Connection(149, 176), + Connection(176, 170), + Connection(105, 66), + Connection(66, 69), + Connection(69, 105), + Connection(122, 6), + Connection(6, 168), + Connection(168, 122), + Connection(123, 147), + Connection(147, 187), + Connection(187, 123), + Connection(96, 77), + Connection(77, 90), + Connection(90, 96), + Connection(65, 55), + Connection(55, 107), + Connection(107, 65), + Connection(89, 90), + Connection(90, 180), + Connection(180, 89), + Connection(101, 100), + Connection(100, 120), + Connection(120, 101), + Connection(63, 105), + Connection(105, 104), + Connection(104, 63), + Connection(93, 137), + Connection(137, 227), + Connection(227, 93), + Connection(15, 86), + Connection(86, 85), + Connection(85, 15), + Connection(129, 102), + Connection(102, 49), + Connection(49, 129), + Connection(14, 87), + Connection(87, 86), + Connection(86, 14), + Connection(55, 8), + Connection(8, 9), + Connection(9, 55), + Connection(100, 47), + Connection(47, 121), + Connection(121, 100), + Connection(145, 23), + Connection(23, 22), + Connection(22, 145), + Connection(88, 89), + Connection(89, 179), + Connection(179, 88), + Connection(6, 122), + Connection(122, 196), + Connection(196, 6), + Connection(88, 95), + Connection(95, 96), + Connection(96, 88), + Connection(138, 172), + Connection(172, 136), + Connection(136, 138), + Connection(215, 58), + Connection(58, 172), + Connection(172, 215), + Connection(115, 48), + Connection(48, 219), + Connection(219, 115), + Connection(42, 80), + Connection(80, 81), + Connection(81, 42), + Connection(195, 3), + Connection(3, 51), + Connection(51, 195), + Connection(43, 146), + Connection(146, 61), + Connection(61, 43), + Connection(171, 175), + Connection(175, 199), + Connection(199, 171), + Connection(81, 82), + Connection(82, 38), + Connection(38, 81), + Connection(53, 46), + Connection(46, 225), + Connection(225, 53), + Connection(144, 163), + Connection(163, 110), + Connection(110, 144), + Connection(52, 65), + Connection(65, 66), + Connection(66, 52), + Connection(229, 228), + Connection(228, 117), + Connection(117, 229), + Connection(34, 127), + Connection(127, 234), + Connection(234, 34), + Connection(107, 108), + Connection(108, 69), + Connection(69, 107), + Connection(109, 108), + Connection(108, 151), + Connection(151, 109), + Connection(48, 64), + Connection(64, 235), + Connection(235, 48), + Connection(62, 78), + Connection(78, 191), + Connection(191, 62), + Connection(129, 209), + Connection(209, 126), + Connection(126, 129), + Connection(111, 35), + Connection(35, 143), + Connection(143, 111), + Connection(117, 123), + Connection(123, 50), + Connection(50, 117), + Connection(222, 65), + Connection(65, 52), + Connection(52, 222), + Connection(19, 125), + Connection(125, 141), + Connection(141, 19), + Connection(221, 55), + Connection(55, 65), + Connection(65, 221), + Connection(3, 195), + Connection(195, 197), + Connection(197, 3), + Connection(25, 7), + Connection(7, 33), + Connection(33, 25), + Connection(220, 237), + Connection(237, 44), + Connection(44, 220), + Connection(70, 71), + Connection(71, 139), + Connection(139, 70), + Connection(122, 193), + Connection(193, 245), + Connection(245, 122), + Connection(247, 130), + Connection(130, 33), + Connection(33, 247), + Connection(71, 21), + Connection(21, 162), + Connection(162, 71), + Connection(170, 169), + 
Connection(169, 150), + Connection(150, 170), + Connection(188, 174), + Connection(174, 196), + Connection(196, 188), + Connection(216, 186), + Connection(186, 92), + Connection(92, 216), + Connection(2, 97), + Connection(97, 167), + Connection(167, 2), + Connection(141, 125), + Connection(125, 241), + Connection(241, 141), + Connection(164, 167), + Connection(167, 37), + Connection(37, 164), + Connection(72, 38), + Connection(38, 12), + Connection(12, 72), + Connection(38, 82), + Connection(82, 13), + Connection(13, 38), + Connection(63, 68), + Connection(68, 71), + Connection(71, 63), + Connection(226, 35), + Connection(35, 111), + Connection(111, 226), + Connection(101, 50), + Connection(50, 205), + Connection(205, 101), + Connection(206, 92), + Connection(92, 165), + Connection(165, 206), + Connection(209, 198), + Connection(198, 217), + Connection(217, 209), + Connection(165, 167), + Connection(167, 97), + Connection(97, 165), + Connection(220, 115), + Connection(115, 218), + Connection(218, 220), + Connection(133, 112), + Connection(112, 243), + Connection(243, 133), + Connection(239, 238), + Connection(238, 241), + Connection(241, 239), + Connection(214, 135), + Connection(135, 169), + Connection(169, 214), + Connection(190, 173), + Connection(173, 133), + Connection(133, 190), + Connection(171, 208), + Connection(208, 32), + Connection(32, 171), + Connection(125, 44), + Connection(44, 237), + Connection(237, 125), + Connection(86, 87), + Connection(87, 178), + Connection(178, 86), + Connection(85, 86), + Connection(86, 179), + Connection(179, 85), + Connection(84, 85), + Connection(85, 180), + Connection(180, 84), + Connection(83, 84), + Connection(84, 181), + Connection(181, 83), + Connection(201, 83), + Connection(83, 182), + Connection(182, 201), + Connection(137, 93), + Connection(93, 132), + Connection(132, 137), + Connection(76, 62), + Connection(62, 183), + Connection(183, 76), + Connection(61, 76), + Connection(76, 184), + Connection(184, 61), + Connection(57, 61), + Connection(61, 185), + Connection(185, 57), + Connection(212, 57), + Connection(57, 186), + Connection(186, 212), + Connection(214, 207), + Connection(207, 187), + Connection(187, 214), + Connection(34, 143), + Connection(143, 156), + Connection(156, 34), + Connection(79, 239), + Connection(239, 237), + Connection(237, 79), + Connection(123, 137), + Connection(137, 177), + Connection(177, 123), + Connection(44, 1), + Connection(1, 4), + Connection(4, 44), + Connection(201, 194), + Connection(194, 32), + Connection(32, 201), + Connection(64, 102), + Connection(102, 129), + Connection(129, 64), + Connection(213, 215), + Connection(215, 138), + Connection(138, 213), + Connection(59, 166), + Connection(166, 219), + Connection(219, 59), + Connection(242, 99), + Connection(99, 97), + Connection(97, 242), + Connection(2, 94), + Connection(94, 141), + Connection(141, 2), + Connection(75, 59), + Connection(59, 235), + Connection(235, 75), + Connection(24, 110), + Connection(110, 228), + Connection(228, 24), + Connection(25, 130), + Connection(130, 226), + Connection(226, 25), + Connection(23, 24), + Connection(24, 229), + Connection(229, 23), + Connection(22, 23), + Connection(23, 230), + Connection(230, 22), + Connection(26, 22), + Connection(22, 231), + Connection(231, 26), + Connection(112, 26), + Connection(26, 232), + Connection(232, 112), + Connection(189, 190), + Connection(190, 243), + Connection(243, 189), + Connection(221, 56), + Connection(56, 190), + Connection(190, 221), + Connection(28, 56), + 
Connection(56, 221), + Connection(221, 28), + Connection(27, 28), + Connection(28, 222), + Connection(222, 27), + Connection(29, 27), + Connection(27, 223), + Connection(223, 29), + Connection(30, 29), + Connection(29, 224), + Connection(224, 30), + Connection(247, 30), + Connection(30, 225), + Connection(225, 247), + Connection(238, 79), + Connection(79, 20), + Connection(20, 238), + Connection(166, 59), + Connection(59, 75), + Connection(75, 166), + Connection(60, 75), + Connection(75, 240), + Connection(240, 60), + Connection(147, 177), + Connection(177, 215), + Connection(215, 147), + Connection(20, 79), + Connection(79, 166), + Connection(166, 20), + Connection(187, 147), + Connection(147, 213), + Connection(213, 187), + Connection(112, 233), + Connection(233, 244), + Connection(244, 112), + Connection(233, 128), + Connection(128, 245), + Connection(245, 233), + Connection(128, 114), + Connection(114, 188), + Connection(188, 128), + Connection(114, 217), + Connection(217, 174), + Connection(174, 114), + Connection(131, 115), + Connection(115, 220), + Connection(220, 131), + Connection(217, 198), + Connection(198, 236), + Connection(236, 217), + Connection(198, 131), + Connection(131, 134), + Connection(134, 198), + Connection(177, 132), + Connection(132, 58), + Connection(58, 177), + Connection(143, 35), + Connection(35, 124), + Connection(124, 143), + Connection(110, 163), + Connection(163, 7), + Connection(7, 110), + Connection(228, 110), + Connection(110, 25), + Connection(25, 228), + Connection(356, 389), + Connection(389, 368), + Connection(368, 356), + Connection(11, 302), + Connection(302, 267), + Connection(267, 11), + Connection(452, 350), + Connection(350, 349), + Connection(349, 452), + Connection(302, 303), + Connection(303, 269), + Connection(269, 302), + Connection(357, 343), + Connection(343, 277), + Connection(277, 357), + Connection(452, 453), + Connection(453, 357), + Connection(357, 452), + Connection(333, 332), + Connection(332, 297), + Connection(297, 333), + Connection(175, 152), + Connection(152, 377), + Connection(377, 175), + Connection(347, 348), + Connection(348, 330), + Connection(330, 347), + Connection(303, 304), + Connection(304, 270), + Connection(270, 303), + Connection(9, 336), + Connection(336, 337), + Connection(337, 9), + Connection(278, 279), + Connection(279, 360), + Connection(360, 278), + Connection(418, 262), + Connection(262, 431), + Connection(431, 418), + Connection(304, 408), + Connection(408, 409), + Connection(409, 304), + Connection(310, 415), + Connection(415, 407), + Connection(407, 310), + Connection(270, 409), + Connection(409, 410), + Connection(410, 270), + Connection(450, 348), + Connection(348, 347), + Connection(347, 450), + Connection(422, 430), + Connection(430, 434), + Connection(434, 422), + Connection(313, 314), + Connection(314, 17), + Connection(17, 313), + Connection(306, 307), + Connection(307, 375), + Connection(375, 306), + Connection(387, 388), + Connection(388, 260), + Connection(260, 387), + Connection(286, 414), + Connection(414, 398), + Connection(398, 286), + Connection(335, 406), + Connection(406, 418), + Connection(418, 335), + Connection(364, 367), + Connection(367, 416), + Connection(416, 364), + Connection(423, 358), + Connection(358, 327), + Connection(327, 423), + Connection(251, 284), + Connection(284, 298), + Connection(298, 251), + Connection(281, 5), + Connection(5, 4), + Connection(4, 281), + Connection(373, 374), + Connection(374, 253), + Connection(253, 373), + Connection(307, 320), + 
Connection(320, 321), + Connection(321, 307), + Connection(425, 427), + Connection(427, 411), + Connection(411, 425), + Connection(421, 313), + Connection(313, 18), + Connection(18, 421), + Connection(321, 405), + Connection(405, 406), + Connection(406, 321), + Connection(320, 404), + Connection(404, 405), + Connection(405, 320), + Connection(315, 16), + Connection(16, 17), + Connection(17, 315), + Connection(426, 425), + Connection(425, 266), + Connection(266, 426), + Connection(377, 400), + Connection(400, 369), + Connection(369, 377), + Connection(322, 391), + Connection(391, 269), + Connection(269, 322), + Connection(417, 465), + Connection(465, 464), + Connection(464, 417), + Connection(386, 257), + Connection(257, 258), + Connection(258, 386), + Connection(466, 260), + Connection(260, 388), + Connection(388, 466), + Connection(456, 399), + Connection(399, 419), + Connection(419, 456), + Connection(284, 332), + Connection(332, 333), + Connection(333, 284), + Connection(417, 285), + Connection(285, 8), + Connection(8, 417), + Connection(346, 340), + Connection(340, 261), + Connection(261, 346), + Connection(413, 441), + Connection(441, 285), + Connection(285, 413), + Connection(327, 460), + Connection(460, 328), + Connection(328, 327), + Connection(355, 371), + Connection(371, 329), + Connection(329, 355), + Connection(392, 439), + Connection(439, 438), + Connection(438, 392), + Connection(382, 341), + Connection(341, 256), + Connection(256, 382), + Connection(429, 420), + Connection(420, 360), + Connection(360, 429), + Connection(364, 394), + Connection(394, 379), + Connection(379, 364), + Connection(277, 343), + Connection(343, 437), + Connection(437, 277), + Connection(443, 444), + Connection(444, 283), + Connection(283, 443), + Connection(275, 440), + Connection(440, 363), + Connection(363, 275), + Connection(431, 262), + Connection(262, 369), + Connection(369, 431), + Connection(297, 338), + Connection(338, 337), + Connection(337, 297), + Connection(273, 375), + Connection(375, 321), + Connection(321, 273), + Connection(450, 451), + Connection(451, 349), + Connection(349, 450), + Connection(446, 342), + Connection(342, 467), + Connection(467, 446), + Connection(293, 334), + Connection(334, 282), + Connection(282, 293), + Connection(458, 461), + Connection(461, 462), + Connection(462, 458), + Connection(276, 353), + Connection(353, 383), + Connection(383, 276), + Connection(308, 324), + Connection(324, 325), + Connection(325, 308), + Connection(276, 300), + Connection(300, 293), + Connection(293, 276), + Connection(372, 345), + Connection(345, 447), + Connection(447, 372), + Connection(352, 345), + Connection(345, 340), + Connection(340, 352), + Connection(274, 1), + Connection(1, 19), + Connection(19, 274), + Connection(456, 248), + Connection(248, 281), + Connection(281, 456), + Connection(436, 427), + Connection(427, 425), + Connection(425, 436), + Connection(381, 256), + Connection(256, 252), + Connection(252, 381), + Connection(269, 391), + Connection(391, 393), + Connection(393, 269), + Connection(200, 199), + Connection(199, 428), + Connection(428, 200), + Connection(266, 330), + Connection(330, 329), + Connection(329, 266), + Connection(287, 273), + Connection(273, 422), + Connection(422, 287), + Connection(250, 462), + Connection(462, 328), + Connection(328, 250), + Connection(258, 286), + Connection(286, 384), + Connection(384, 258), + Connection(265, 353), + Connection(353, 342), + Connection(342, 265), + Connection(387, 259), + Connection(259, 257), + Connection(257, 
387), + Connection(424, 431), + Connection(431, 430), + Connection(430, 424), + Connection(342, 353), + Connection(353, 276), + Connection(276, 342), + Connection(273, 335), + Connection(335, 424), + Connection(424, 273), + Connection(292, 325), + Connection(325, 307), + Connection(307, 292), + Connection(366, 447), + Connection(447, 345), + Connection(345, 366), + Connection(271, 303), + Connection(303, 302), + Connection(302, 271), + Connection(423, 266), + Connection(266, 371), + Connection(371, 423), + Connection(294, 455), + Connection(455, 460), + Connection(460, 294), + Connection(279, 278), + Connection(278, 294), + Connection(294, 279), + Connection(271, 272), + Connection(272, 304), + Connection(304, 271), + Connection(432, 434), + Connection(434, 427), + Connection(427, 432), + Connection(272, 407), + Connection(407, 408), + Connection(408, 272), + Connection(394, 430), + Connection(430, 431), + Connection(431, 394), + Connection(395, 369), + Connection(369, 400), + Connection(400, 395), + Connection(334, 333), + Connection(333, 299), + Connection(299, 334), + Connection(351, 417), + Connection(417, 168), + Connection(168, 351), + Connection(352, 280), + Connection(280, 411), + Connection(411, 352), + Connection(325, 319), + Connection(319, 320), + Connection(320, 325), + Connection(295, 296), + Connection(296, 336), + Connection(336, 295), + Connection(319, 403), + Connection(403, 404), + Connection(404, 319), + Connection(330, 348), + Connection(348, 349), + Connection(349, 330), + Connection(293, 298), + Connection(298, 333), + Connection(333, 293), + Connection(323, 454), + Connection(454, 447), + Connection(447, 323), + Connection(15, 16), + Connection(16, 315), + Connection(315, 15), + Connection(358, 429), + Connection(429, 279), + Connection(279, 358), + Connection(14, 15), + Connection(15, 316), + Connection(316, 14), + Connection(285, 336), + Connection(336, 9), + Connection(9, 285), + Connection(329, 349), + Connection(349, 350), + Connection(350, 329), + Connection(374, 380), + Connection(380, 252), + Connection(252, 374), + Connection(318, 402), + Connection(402, 403), + Connection(403, 318), + Connection(6, 197), + Connection(197, 419), + Connection(419, 6), + Connection(318, 319), + Connection(319, 325), + Connection(325, 318), + Connection(367, 364), + Connection(364, 365), + Connection(365, 367), + Connection(435, 367), + Connection(367, 397), + Connection(397, 435), + Connection(344, 438), + Connection(438, 439), + Connection(439, 344), + Connection(272, 271), + Connection(271, 311), + Connection(311, 272), + Connection(195, 5), + Connection(5, 281), + Connection(281, 195), + Connection(273, 287), + Connection(287, 291), + Connection(291, 273), + Connection(396, 428), + Connection(428, 199), + Connection(199, 396), + Connection(311, 271), + Connection(271, 268), + Connection(268, 311), + Connection(283, 444), + Connection(444, 445), + Connection(445, 283), + Connection(373, 254), + Connection(254, 339), + Connection(339, 373), + Connection(282, 334), + Connection(334, 296), + Connection(296, 282), + Connection(449, 347), + Connection(347, 346), + Connection(346, 449), + Connection(264, 447), + Connection(447, 454), + Connection(454, 264), + Connection(336, 296), + Connection(296, 299), + Connection(299, 336), + Connection(338, 10), + Connection(10, 151), + Connection(151, 338), + Connection(278, 439), + Connection(439, 455), + Connection(455, 278), + Connection(292, 407), + Connection(407, 415), + Connection(415, 292), + Connection(358, 371), + Connection(371, 
355), + Connection(355, 358), + Connection(340, 345), + Connection(345, 372), + Connection(372, 340), + Connection(346, 347), + Connection(347, 280), + Connection(280, 346), + Connection(442, 443), + Connection(443, 282), + Connection(282, 442), + Connection(19, 94), + Connection(94, 370), + Connection(370, 19), + Connection(441, 442), + Connection(442, 295), + Connection(295, 441), + Connection(248, 419), + Connection(419, 197), + Connection(197, 248), + Connection(263, 255), + Connection(255, 359), + Connection(359, 263), + Connection(440, 275), + Connection(275, 274), + Connection(274, 440), + Connection(300, 383), + Connection(383, 368), + Connection(368, 300), + Connection(351, 412), + Connection(412, 465), + Connection(465, 351), + Connection(263, 467), + Connection(467, 466), + Connection(466, 263), + Connection(301, 368), + Connection(368, 389), + Connection(389, 301), + Connection(395, 378), + Connection(378, 379), + Connection(379, 395), + Connection(412, 351), + Connection(351, 419), + Connection(419, 412), + Connection(436, 426), + Connection(426, 322), + Connection(322, 436), + Connection(2, 164), + Connection(164, 393), + Connection(393, 2), + Connection(370, 462), + Connection(462, 461), + Connection(461, 370), + Connection(164, 0), + Connection(0, 267), + Connection(267, 164), + Connection(302, 11), + Connection(11, 12), + Connection(12, 302), + Connection(268, 12), + Connection(12, 13), + Connection(13, 268), + Connection(293, 300), + Connection(300, 301), + Connection(301, 293), + Connection(446, 261), + Connection(261, 340), + Connection(340, 446), + Connection(330, 266), + Connection(266, 425), + Connection(425, 330), + Connection(426, 423), + Connection(423, 391), + Connection(391, 426), + Connection(429, 355), + Connection(355, 437), + Connection(437, 429), + Connection(391, 327), + Connection(327, 326), + Connection(326, 391), + Connection(440, 457), + Connection(457, 438), + Connection(438, 440), + Connection(341, 382), + Connection(382, 362), + Connection(362, 341), + Connection(459, 457), + Connection(457, 461), + Connection(461, 459), + Connection(434, 430), + Connection(430, 394), + Connection(394, 434), + Connection(414, 463), + Connection(463, 362), + Connection(362, 414), + Connection(396, 369), + Connection(369, 262), + Connection(262, 396), + Connection(354, 461), + Connection(461, 457), + Connection(457, 354), + Connection(316, 403), + Connection(403, 402), + Connection(402, 316), + Connection(315, 404), + Connection(404, 403), + Connection(403, 315), + Connection(314, 405), + Connection(405, 404), + Connection(404, 314), + Connection(313, 406), + Connection(406, 405), + Connection(405, 313), + Connection(421, 418), + Connection(418, 406), + Connection(406, 421), + Connection(366, 401), + Connection(401, 361), + Connection(361, 366), + Connection(306, 408), + Connection(408, 407), + Connection(407, 306), + Connection(291, 409), + Connection(409, 408), + Connection(408, 291), + Connection(287, 410), + Connection(410, 409), + Connection(409, 287), + Connection(432, 436), + Connection(436, 410), + Connection(410, 432), + Connection(434, 416), + Connection(416, 411), + Connection(411, 434), + Connection(264, 368), + Connection(368, 383), + Connection(383, 264), + Connection(309, 438), + Connection(438, 457), + Connection(457, 309), + Connection(352, 376), + Connection(376, 401), + Connection(401, 352), + Connection(274, 275), + Connection(275, 4), + Connection(4, 274), + Connection(421, 428), + Connection(428, 262), + Connection(262, 421), + Connection(294, 
327), + Connection(327, 358), + Connection(358, 294), + Connection(433, 416), + Connection(416, 367), + Connection(367, 433), + Connection(289, 455), + Connection(455, 439), + Connection(439, 289), + Connection(462, 370), + Connection(370, 326), + Connection(326, 462), + Connection(2, 326), + Connection(326, 370), + Connection(370, 2), + Connection(305, 460), + Connection(460, 455), + Connection(455, 305), + Connection(254, 449), + Connection(449, 448), + Connection(448, 254), + Connection(255, 261), + Connection(261, 446), + Connection(446, 255), + Connection(253, 450), + Connection(450, 449), + Connection(449, 253), + Connection(252, 451), + Connection(451, 450), + Connection(450, 252), + Connection(256, 452), + Connection(452, 451), + Connection(451, 256), + Connection(341, 453), + Connection(453, 452), + Connection(452, 341), + Connection(413, 464), + Connection(464, 463), + Connection(463, 413), + Connection(441, 413), + Connection(413, 414), + Connection(414, 441), + Connection(258, 442), + Connection(442, 441), + Connection(441, 258), + Connection(257, 443), + Connection(443, 442), + Connection(442, 257), + Connection(259, 444), + Connection(444, 443), + Connection(443, 259), + Connection(260, 445), + Connection(445, 444), + Connection(444, 260), + Connection(467, 342), + Connection(342, 445), + Connection(445, 467), + Connection(459, 458), + Connection(458, 250), + Connection(250, 459), + Connection(289, 392), + Connection(392, 290), + Connection(290, 289), + Connection(290, 328), + Connection(328, 460), + Connection(460, 290), + Connection(376, 433), + Connection(433, 435), + Connection(435, 376), + Connection(250, 290), + Connection(290, 392), + Connection(392, 250), + Connection(411, 416), + Connection(416, 433), + Connection(433, 411), + Connection(341, 463), + Connection(463, 464), + Connection(464, 341), + Connection(453, 464), + Connection(464, 465), + Connection(465, 453), + Connection(357, 465), + Connection(465, 412), + Connection(412, 357), + Connection(343, 412), + Connection(412, 399), + Connection(399, 343), + Connection(360, 363), + Connection(363, 440), + Connection(440, 360), + Connection(437, 399), + Connection(399, 456), + Connection(456, 437), + Connection(420, 456), + Connection(456, 363), + Connection(363, 420), + Connection(401, 435), + Connection(435, 288), + Connection(288, 401), + Connection(372, 383), + Connection(383, 353), + Connection(353, 372), + Connection(339, 255), + Connection(255, 249), + Connection(249, 339), + Connection(448, 261), + Connection(261, 255), + Connection(255, 448), + Connection(133, 243), + Connection(243, 190), + Connection(190, 133), + Connection(133, 155), + Connection(155, 112), + Connection(112, 133), + Connection(33, 246), + Connection(246, 247), + Connection(247, 33), + Connection(33, 130), + Connection(130, 25), + Connection(25, 33), + Connection(398, 384), + Connection(384, 286), + Connection(286, 398), + Connection(362, 398), + Connection(398, 414), + Connection(414, 362), + Connection(362, 463), + Connection(463, 341), + Connection(341, 362), + Connection(263, 359), + Connection(359, 467), + Connection(467, 263), + Connection(263, 249), + Connection(249, 255), + Connection(255, 263), + Connection(466, 467), + Connection(467, 260), + Connection(260, 466), + Connection(75, 60), + Connection(60, 166), + Connection(166, 75), + Connection(238, 239), + Connection(239, 79), + Connection(79, 238), + Connection(162, 127), + Connection(127, 139), + Connection(139, 162), + Connection(72, 11), + Connection(11, 37), + Connection(37, 
72), + Connection(121, 232), + Connection(232, 120), + Connection(120, 121), + Connection(73, 72), + Connection(72, 39), + Connection(39, 73), + Connection(114, 128), + Connection(128, 47), + Connection(47, 114), + Connection(233, 232), + Connection(232, 128), + Connection(128, 233), + Connection(103, 104), + Connection(104, 67), + Connection(67, 103), + Connection(152, 175), + Connection(175, 148), + Connection(148, 152), + Connection(119, 118), + Connection(118, 101), + Connection(101, 119), + Connection(74, 73), + Connection(73, 40), + Connection(40, 74), + Connection(107, 9), + Connection(9, 108), + Connection(108, 107), + Connection(49, 48), + Connection(48, 131), + Connection(131, 49), + Connection(32, 194), + Connection(194, 211), + Connection(211, 32), + Connection(184, 74), + Connection(74, 185), + Connection(185, 184), + Connection(191, 80), + Connection(80, 183), + Connection(183, 191), + Connection(185, 40), + Connection(40, 186), + Connection(186, 185), + Connection(119, 230), + Connection(230, 118), + Connection(118, 119), + Connection(210, 202), + Connection(202, 214), + Connection(214, 210), + Connection(84, 83), + Connection(83, 17), + Connection(17, 84), + Connection(77, 76), + Connection(76, 146), + Connection(146, 77), + Connection(161, 160), + Connection(160, 30), + Connection(30, 161), + Connection(190, 56), + Connection(56, 173), + Connection(173, 190), + Connection(182, 106), + Connection(106, 194), + Connection(194, 182), + Connection(138, 135), + Connection(135, 192), + Connection(192, 138), + Connection(129, 203), + Connection(203, 98), + Connection(98, 129), + Connection(54, 21), + Connection(21, 68), + Connection(68, 54), + Connection(5, 51), + Connection(51, 4), + Connection(4, 5), + Connection(145, 144), + Connection(144, 23), + Connection(23, 145), + Connection(90, 77), + Connection(77, 91), + Connection(91, 90), + Connection(207, 205), + Connection(205, 187), + Connection(187, 207), + Connection(83, 201), + Connection(201, 18), + Connection(18, 83), + Connection(181, 91), + Connection(91, 182), + Connection(182, 181), + Connection(180, 90), + Connection(90, 181), + Connection(181, 180), + Connection(16, 85), + Connection(85, 17), + Connection(17, 16), + Connection(205, 206), + Connection(206, 36), + Connection(36, 205), + Connection(176, 148), + Connection(148, 140), + Connection(140, 176), + Connection(165, 92), + Connection(92, 39), + Connection(39, 165), + Connection(245, 193), + Connection(193, 244), + Connection(244, 245), + Connection(27, 159), + Connection(159, 28), + Connection(28, 27), + Connection(30, 247), + Connection(247, 161), + Connection(161, 30), + Connection(174, 236), + Connection(236, 196), + Connection(196, 174), + Connection(103, 54), + Connection(54, 104), + Connection(104, 103), + Connection(55, 193), + Connection(193, 8), + Connection(8, 55), + Connection(111, 117), + Connection(117, 31), + Connection(31, 111), + Connection(221, 189), + Connection(189, 55), + Connection(55, 221), + Connection(240, 98), + Connection(98, 99), + Connection(99, 240), + Connection(142, 126), + Connection(126, 100), + Connection(100, 142), + Connection(219, 166), + Connection(166, 218), + Connection(218, 219), + Connection(112, 155), + Connection(155, 26), + Connection(26, 112), + Connection(198, 209), + Connection(209, 131), + Connection(131, 198), + Connection(169, 135), + Connection(135, 150), + Connection(150, 169), + Connection(114, 47), + Connection(47, 217), + Connection(217, 114), + Connection(224, 223), + Connection(223, 53), + Connection(53, 
224), + Connection(220, 45), + Connection(45, 134), + Connection(134, 220), + Connection(32, 211), + Connection(211, 140), + Connection(140, 32), + Connection(109, 67), + Connection(67, 108), + Connection(108, 109), + Connection(146, 43), + Connection(43, 91), + Connection(91, 146), + Connection(231, 230), + Connection(230, 120), + Connection(120, 231), + Connection(113, 226), + Connection(226, 247), + Connection(247, 113), + Connection(105, 63), + Connection(63, 52), + Connection(52, 105), + Connection(241, 238), + Connection(238, 242), + Connection(242, 241), + Connection(124, 46), + Connection(46, 156), + Connection(156, 124), + Connection(95, 78), + Connection(78, 96), + Connection(96, 95), + Connection(70, 46), + Connection(46, 63), + Connection(63, 70), + Connection(116, 143), + Connection(143, 227), + Connection(227, 116), + Connection(116, 123), + Connection(123, 111), + Connection(111, 116), + Connection(1, 44), + Connection(44, 19), + Connection(19, 1), + Connection(3, 236), + Connection(236, 51), + Connection(51, 3), + Connection(207, 216), + Connection(216, 205), + Connection(205, 207), + Connection(26, 154), + Connection(154, 22), + Connection(22, 26), + Connection(165, 39), + Connection(39, 167), + Connection(167, 165), + Connection(199, 200), + Connection(200, 208), + Connection(208, 199), + Connection(101, 36), + Connection(36, 100), + Connection(100, 101), + Connection(43, 57), + Connection(57, 202), + Connection(202, 43), + Connection(242, 20), + Connection(20, 99), + Connection(99, 242), + Connection(56, 28), + Connection(28, 157), + Connection(157, 56), + Connection(124, 35), + Connection(35, 113), + Connection(113, 124), + Connection(29, 160), + Connection(160, 27), + Connection(27, 29), + Connection(211, 204), + Connection(204, 210), + Connection(210, 211), + Connection(124, 113), + Connection(113, 46), + Connection(46, 124), + Connection(106, 43), + Connection(43, 204), + Connection(204, 106), + Connection(96, 62), + Connection(62, 77), + Connection(77, 96), + Connection(227, 137), + Connection(137, 116), + Connection(116, 227), + Connection(73, 41), + Connection(41, 72), + Connection(72, 73), + Connection(36, 203), + Connection(203, 142), + Connection(142, 36), + Connection(235, 64), + Connection(64, 240), + Connection(240, 235), + Connection(48, 49), + Connection(49, 64), + Connection(64, 48), + Connection(42, 41), + Connection(41, 74), + Connection(74, 42), + Connection(214, 212), + Connection(212, 207), + Connection(207, 214), + Connection(183, 42), + Connection(42, 184), + Connection(184, 183), + Connection(210, 169), + Connection(169, 211), + Connection(211, 210), + Connection(140, 170), + Connection(170, 176), + Connection(176, 140), + Connection(104, 105), + Connection(105, 69), + Connection(69, 104), + Connection(193, 122), + Connection(122, 168), + Connection(168, 193), + Connection(50, 123), + Connection(123, 187), + Connection(187, 50), + Connection(89, 96), + Connection(96, 90), + Connection(90, 89), + Connection(66, 65), + Connection(65, 107), + Connection(107, 66), + Connection(179, 89), + Connection(89, 180), + Connection(180, 179), + Connection(119, 101), + Connection(101, 120), + Connection(120, 119), + Connection(68, 63), + Connection(63, 104), + Connection(104, 68), + Connection(234, 93), + Connection(93, 227), + Connection(227, 234), + Connection(16, 15), + Connection(15, 85), + Connection(85, 16), + Connection(209, 129), + Connection(129, 49), + Connection(49, 209), + Connection(15, 14), + Connection(14, 86), + Connection(86, 15), + 
Connection(107, 55), + Connection(55, 9), + Connection(9, 107), + Connection(120, 100), + Connection(100, 121), + Connection(121, 120), + Connection(153, 145), + Connection(145, 22), + Connection(22, 153), + Connection(178, 88), + Connection(88, 179), + Connection(179, 178), + Connection(197, 6), + Connection(6, 196), + Connection(196, 197), + Connection(89, 88), + Connection(88, 96), + Connection(96, 89), + Connection(135, 138), + Connection(138, 136), + Connection(136, 135), + Connection(138, 215), + Connection(215, 172), + Connection(172, 138), + Connection(218, 115), + Connection(115, 219), + Connection(219, 218), + Connection(41, 42), + Connection(42, 81), + Connection(81, 41), + Connection(5, 195), + Connection(195, 51), + Connection(51, 5), + Connection(57, 43), + Connection(43, 61), + Connection(61, 57), + Connection(208, 171), + Connection(171, 199), + Connection(199, 208), + Connection(41, 81), + Connection(81, 38), + Connection(38, 41), + Connection(224, 53), + Connection(53, 225), + Connection(225, 224), + Connection(24, 144), + Connection(144, 110), + Connection(110, 24), + Connection(105, 52), + Connection(52, 66), + Connection(66, 105), + Connection(118, 229), + Connection(229, 117), + Connection(117, 118), + Connection(227, 34), + Connection(34, 234), + Connection(234, 227), + Connection(66, 107), + Connection(107, 69), + Connection(69, 66), + Connection(10, 109), + Connection(109, 151), + Connection(151, 10), + Connection(219, 48), + Connection(48, 235), + Connection(235, 219), + Connection(183, 62), + Connection(62, 191), + Connection(191, 183), + Connection(142, 129), + Connection(129, 126), + Connection(126, 142), + Connection(116, 111), + Connection(111, 143), + Connection(143, 116), + Connection(118, 117), + Connection(117, 50), + Connection(50, 118), + Connection(223, 222), + Connection(222, 52), + Connection(52, 223), + Connection(94, 19), + Connection(19, 141), + Connection(141, 94), + Connection(222, 221), + Connection(221, 65), + Connection(65, 222), + Connection(196, 3), + Connection(3, 197), + Connection(197, 196), + Connection(45, 220), + Connection(220, 44), + Connection(44, 45), + Connection(156, 70), + Connection(70, 139), + Connection(139, 156), + Connection(188, 122), + Connection(122, 245), + Connection(245, 188), + Connection(139, 71), + Connection(71, 162), + Connection(162, 139), + Connection(149, 170), + Connection(170, 150), + Connection(150, 149), + Connection(122, 188), + Connection(188, 196), + Connection(196, 122), + Connection(206, 216), + Connection(216, 92), + Connection(92, 206), + Connection(164, 2), + Connection(2, 167), + Connection(167, 164), + Connection(242, 141), + Connection(141, 241), + Connection(241, 242), + Connection(0, 164), + Connection(164, 37), + Connection(37, 0), + Connection(11, 72), + Connection(72, 12), + Connection(12, 11), + Connection(12, 38), + Connection(38, 13), + Connection(13, 12), + Connection(70, 63), + Connection(63, 71), + Connection(71, 70), + Connection(31, 226), + Connection(226, 111), + Connection(111, 31), + Connection(36, 101), + Connection(101, 205), + Connection(205, 36), + Connection(203, 206), + Connection(206, 165), + Connection(165, 203), + Connection(126, 209), + Connection(209, 217), + Connection(217, 126), + Connection(98, 165), + Connection(165, 97), + Connection(97, 98), + Connection(237, 220), + Connection(220, 218), + Connection(218, 237), + Connection(237, 239), + Connection(239, 241), + Connection(241, 237), + Connection(210, 214), + Connection(214, 169), + Connection(169, 210), + 
Connection(140, 171), + Connection(171, 32), + Connection(32, 140), + Connection(241, 125), + Connection(125, 237), + Connection(237, 241), + Connection(179, 86), + Connection(86, 178), + Connection(178, 179), + Connection(180, 85), + Connection(85, 179), + Connection(179, 180), + Connection(181, 84), + Connection(84, 180), + Connection(180, 181), + Connection(182, 83), + Connection(83, 181), + Connection(181, 182), + Connection(194, 201), + Connection(201, 182), + Connection(182, 194), + Connection(177, 137), + Connection(137, 132), + Connection(132, 177), + Connection(184, 76), + Connection(76, 183), + Connection(183, 184), + Connection(185, 61), + Connection(61, 184), + Connection(184, 185), + Connection(186, 57), + Connection(57, 185), + Connection(185, 186), + Connection(216, 212), + Connection(212, 186), + Connection(186, 216), + Connection(192, 214), + Connection(214, 187), + Connection(187, 192), + Connection(139, 34), + Connection(34, 156), + Connection(156, 139), + Connection(218, 79), + Connection(79, 237), + Connection(237, 218), + Connection(147, 123), + Connection(123, 177), + Connection(177, 147), + Connection(45, 44), + Connection(44, 4), + Connection(4, 45), + Connection(208, 201), + Connection(201, 32), + Connection(32, 208), + Connection(98, 64), + Connection(64, 129), + Connection(129, 98), + Connection(192, 213), + Connection(213, 138), + Connection(138, 192), + Connection(235, 59), + Connection(59, 219), + Connection(219, 235), + Connection(141, 242), + Connection(242, 97), + Connection(97, 141), + Connection(97, 2), + Connection(2, 141), + Connection(141, 97), + Connection(240, 75), + Connection(75, 235), + Connection(235, 240), + Connection(229, 24), + Connection(24, 228), + Connection(228, 229), + Connection(31, 25), + Connection(25, 226), + Connection(226, 31), + Connection(230, 23), + Connection(23, 229), + Connection(229, 230), + Connection(231, 22), + Connection(22, 230), + Connection(230, 231), + Connection(232, 26), + Connection(26, 231), + Connection(231, 232), + Connection(233, 112), + Connection(112, 232), + Connection(232, 233), + Connection(244, 189), + Connection(189, 243), + Connection(243, 244), + Connection(189, 221), + Connection(221, 190), + Connection(190, 189), + Connection(222, 28), + Connection(28, 221), + Connection(221, 222), + Connection(223, 27), + Connection(27, 222), + Connection(222, 223), + Connection(224, 29), + Connection(29, 223), + Connection(223, 224), + Connection(225, 30), + Connection(30, 224), + Connection(224, 225), + Connection(113, 247), + Connection(247, 225), + Connection(225, 113), + Connection(99, 60), + Connection(60, 240), + Connection(240, 99), + Connection(213, 147), + Connection(147, 215), + Connection(215, 213), + Connection(60, 20), + Connection(20, 166), + Connection(166, 60), + Connection(192, 187), + Connection(187, 213), + Connection(213, 192), + Connection(243, 112), + Connection(112, 244), + Connection(244, 243), + Connection(244, 233), + Connection(233, 245), + Connection(245, 244), + Connection(245, 128), + Connection(128, 188), + Connection(188, 245), + Connection(188, 114), + Connection(114, 174), + Connection(174, 188), + Connection(134, 131), + Connection(131, 220), + Connection(220, 134), + Connection(174, 217), + Connection(217, 236), + Connection(236, 174), + Connection(236, 198), + Connection(198, 134), + Connection(134, 236), + Connection(215, 177), + Connection(177, 58), + Connection(58, 215), + Connection(156, 143), + Connection(143, 124), + Connection(124, 156), + Connection(25, 110), + 
Connection(110, 7), + Connection(7, 25), + Connection(31, 228), + Connection(228, 25), + Connection(25, 31), + Connection(264, 356), + Connection(356, 368), + Connection(368, 264), + Connection(0, 11), + Connection(11, 267), + Connection(267, 0), + Connection(451, 452), + Connection(452, 349), + Connection(349, 451), + Connection(267, 302), + Connection(302, 269), + Connection(269, 267), + Connection(350, 357), + Connection(357, 277), + Connection(277, 350), + Connection(350, 452), + Connection(452, 357), + Connection(357, 350), + Connection(299, 333), + Connection(333, 297), + Connection(297, 299), + Connection(396, 175), + Connection(175, 377), + Connection(377, 396), + Connection(280, 347), + Connection(347, 330), + Connection(330, 280), + Connection(269, 303), + Connection(303, 270), + Connection(270, 269), + Connection(151, 9), + Connection(9, 337), + Connection(337, 151), + Connection(344, 278), + Connection(278, 360), + Connection(360, 344), + Connection(424, 418), + Connection(418, 431), + Connection(431, 424), + Connection(270, 304), + Connection(304, 409), + Connection(409, 270), + Connection(272, 310), + Connection(310, 407), + Connection(407, 272), + Connection(322, 270), + Connection(270, 410), + Connection(410, 322), + Connection(449, 450), + Connection(450, 347), + Connection(347, 449), + Connection(432, 422), + Connection(422, 434), + Connection(434, 432), + Connection(18, 313), + Connection(313, 17), + Connection(17, 18), + Connection(291, 306), + Connection(306, 375), + Connection(375, 291), + Connection(259, 387), + Connection(387, 260), + Connection(260, 259), + Connection(424, 335), + Connection(335, 418), + Connection(418, 424), + Connection(434, 364), + Connection(364, 416), + Connection(416, 434), + Connection(391, 423), + Connection(423, 327), + Connection(327, 391), + Connection(301, 251), + Connection(251, 298), + Connection(298, 301), + Connection(275, 281), + Connection(281, 4), + Connection(4, 275), + Connection(254, 373), + Connection(373, 253), + Connection(253, 254), + Connection(375, 307), + Connection(307, 321), + Connection(321, 375), + Connection(280, 425), + Connection(425, 411), + Connection(411, 280), + Connection(200, 421), + Connection(421, 18), + Connection(18, 200), + Connection(335, 321), + Connection(321, 406), + Connection(406, 335), + Connection(321, 320), + Connection(320, 405), + Connection(405, 321), + Connection(314, 315), + Connection(315, 17), + Connection(17, 314), + Connection(423, 426), + Connection(426, 266), + Connection(266, 423), + Connection(396, 377), + Connection(377, 369), + Connection(369, 396), + Connection(270, 322), + Connection(322, 269), + Connection(269, 270), + Connection(413, 417), + Connection(417, 464), + Connection(464, 413), + Connection(385, 386), + Connection(386, 258), + Connection(258, 385), + Connection(248, 456), + Connection(456, 419), + Connection(419, 248), + Connection(298, 284), + Connection(284, 333), + Connection(333, 298), + Connection(168, 417), + Connection(417, 8), + Connection(8, 168), + Connection(448, 346), + Connection(346, 261), + Connection(261, 448), + Connection(417, 413), + Connection(413, 285), + Connection(285, 417), + Connection(326, 327), + Connection(327, 328), + Connection(328, 326), + Connection(277, 355), + Connection(355, 329), + Connection(329, 277), + Connection(309, 392), + Connection(392, 438), + Connection(438, 309), + Connection(381, 382), + Connection(382, 256), + Connection(256, 381), + Connection(279, 429), + Connection(429, 360), + Connection(360, 279), + 
Connection(365, 364), + Connection(364, 379), + Connection(379, 365), + Connection(355, 277), + Connection(277, 437), + Connection(437, 355), + Connection(282, 443), + Connection(443, 283), + Connection(283, 282), + Connection(281, 275), + Connection(275, 363), + Connection(363, 281), + Connection(395, 431), + Connection(431, 369), + Connection(369, 395), + Connection(299, 297), + Connection(297, 337), + Connection(337, 299), + Connection(335, 273), + Connection(273, 321), + Connection(321, 335), + Connection(348, 450), + Connection(450, 349), + Connection(349, 348), + Connection(359, 446), + Connection(446, 467), + Connection(467, 359), + Connection(283, 293), + Connection(293, 282), + Connection(282, 283), + Connection(250, 458), + Connection(458, 462), + Connection(462, 250), + Connection(300, 276), + Connection(276, 383), + Connection(383, 300), + Connection(292, 308), + Connection(308, 325), + Connection(325, 292), + Connection(283, 276), + Connection(276, 293), + Connection(293, 283), + Connection(264, 372), + Connection(372, 447), + Connection(447, 264), + Connection(346, 352), + Connection(352, 340), + Connection(340, 346), + Connection(354, 274), + Connection(274, 19), + Connection(19, 354), + Connection(363, 456), + Connection(456, 281), + Connection(281, 363), + Connection(426, 436), + Connection(436, 425), + Connection(425, 426), + Connection(380, 381), + Connection(381, 252), + Connection(252, 380), + Connection(267, 269), + Connection(269, 393), + Connection(393, 267), + Connection(421, 200), + Connection(200, 428), + Connection(428, 421), + Connection(371, 266), + Connection(266, 329), + Connection(329, 371), + Connection(432, 287), + Connection(287, 422), + Connection(422, 432), + Connection(290, 250), + Connection(250, 328), + Connection(328, 290), + Connection(385, 258), + Connection(258, 384), + Connection(384, 385), + Connection(446, 265), + Connection(265, 342), + Connection(342, 446), + Connection(386, 387), + Connection(387, 257), + Connection(257, 386), + Connection(422, 424), + Connection(424, 430), + Connection(430, 422), + Connection(445, 342), + Connection(342, 276), + Connection(276, 445), + Connection(422, 273), + Connection(273, 424), + Connection(424, 422), + Connection(306, 292), + Connection(292, 307), + Connection(307, 306), + Connection(352, 366), + Connection(366, 345), + Connection(345, 352), + Connection(268, 271), + Connection(271, 302), + Connection(302, 268), + Connection(358, 423), + Connection(423, 371), + Connection(371, 358), + Connection(327, 294), + Connection(294, 460), + Connection(460, 327), + Connection(331, 279), + Connection(279, 294), + Connection(294, 331), + Connection(303, 271), + Connection(271, 304), + Connection(304, 303), + Connection(436, 432), + Connection(432, 427), + Connection(427, 436), + Connection(304, 272), + Connection(272, 408), + Connection(408, 304), + Connection(395, 394), + Connection(394, 431), + Connection(431, 395), + Connection(378, 395), + Connection(395, 400), + Connection(400, 378), + Connection(296, 334), + Connection(334, 299), + Connection(299, 296), + Connection(6, 351), + Connection(351, 168), + Connection(168, 6), + Connection(376, 352), + Connection(352, 411), + Connection(411, 376), + Connection(307, 325), + Connection(325, 320), + Connection(320, 307), + Connection(285, 295), + Connection(295, 336), + Connection(336, 285), + Connection(320, 319), + Connection(319, 404), + Connection(404, 320), + Connection(329, 330), + Connection(330, 349), + Connection(349, 329), + Connection(334, 293), + 
Connection(293, 333), + Connection(333, 334), + Connection(366, 323), + Connection(323, 447), + Connection(447, 366), + Connection(316, 15), + Connection(15, 315), + Connection(315, 316), + Connection(331, 358), + Connection(358, 279), + Connection(279, 331), + Connection(317, 14), + Connection(14, 316), + Connection(316, 317), + Connection(8, 285), + Connection(285, 9), + Connection(9, 8), + Connection(277, 329), + Connection(329, 350), + Connection(350, 277), + Connection(253, 374), + Connection(374, 252), + Connection(252, 253), + Connection(319, 318), + Connection(318, 403), + Connection(403, 319), + Connection(351, 6), + Connection(6, 419), + Connection(419, 351), + Connection(324, 318), + Connection(318, 325), + Connection(325, 324), + Connection(397, 367), + Connection(367, 365), + Connection(365, 397), + Connection(288, 435), + Connection(435, 397), + Connection(397, 288), + Connection(278, 344), + Connection(344, 439), + Connection(439, 278), + Connection(310, 272), + Connection(272, 311), + Connection(311, 310), + Connection(248, 195), + Connection(195, 281), + Connection(281, 248), + Connection(375, 273), + Connection(273, 291), + Connection(291, 375), + Connection(175, 396), + Connection(396, 199), + Connection(199, 175), + Connection(312, 311), + Connection(311, 268), + Connection(268, 312), + Connection(276, 283), + Connection(283, 445), + Connection(445, 276), + Connection(390, 373), + Connection(373, 339), + Connection(339, 390), + Connection(295, 282), + Connection(282, 296), + Connection(296, 295), + Connection(448, 449), + Connection(449, 346), + Connection(346, 448), + Connection(356, 264), + Connection(264, 454), + Connection(454, 356), + Connection(337, 336), + Connection(336, 299), + Connection(299, 337), + Connection(337, 338), + Connection(338, 151), + Connection(151, 337), + Connection(294, 278), + Connection(278, 455), + Connection(455, 294), + Connection(308, 292), + Connection(292, 415), + Connection(415, 308), + Connection(429, 358), + Connection(358, 355), + Connection(355, 429), + Connection(265, 340), + Connection(340, 372), + Connection(372, 265), + Connection(352, 346), + Connection(346, 280), + Connection(280, 352), + Connection(295, 442), + Connection(442, 282), + Connection(282, 295), + Connection(354, 19), + Connection(19, 370), + Connection(370, 354), + Connection(285, 441), + Connection(441, 295), + Connection(295, 285), + Connection(195, 248), + Connection(248, 197), + Connection(197, 195), + Connection(457, 440), + Connection(440, 274), + Connection(274, 457), + Connection(301, 300), + Connection(300, 368), + Connection(368, 301), + Connection(417, 351), + Connection(351, 465), + Connection(465, 417), + Connection(251, 301), + Connection(301, 389), + Connection(389, 251), + Connection(394, 395), + Connection(395, 379), + Connection(379, 394), + Connection(399, 412), + Connection(412, 419), + Connection(419, 399), + Connection(410, 436), + Connection(436, 322), + Connection(322, 410), + Connection(326, 2), + Connection(2, 393), + Connection(393, 326), + Connection(354, 370), + Connection(370, 461), + Connection(461, 354), + Connection(393, 164), + Connection(164, 267), + Connection(267, 393), + Connection(268, 302), + Connection(302, 12), + Connection(12, 268), + Connection(312, 268), + Connection(268, 13), + Connection(13, 312), + Connection(298, 293), + Connection(293, 301), + Connection(301, 298), + Connection(265, 446), + Connection(446, 340), + Connection(340, 265), + Connection(280, 330), + Connection(330, 425), + Connection(425, 280), + 
Connection(322, 426), + Connection(426, 391), + Connection(391, 322), + Connection(420, 429), + Connection(429, 437), + Connection(437, 420), + Connection(393, 391), + Connection(391, 326), + Connection(326, 393), + Connection(344, 440), + Connection(440, 438), + Connection(438, 344), + Connection(458, 459), + Connection(459, 461), + Connection(461, 458), + Connection(364, 434), + Connection(434, 394), + Connection(394, 364), + Connection(428, 396), + Connection(396, 262), + Connection(262, 428), + Connection(274, 354), + Connection(354, 457), + Connection(457, 274), + Connection(317, 316), + Connection(316, 402), + Connection(402, 317), + Connection(316, 315), + Connection(315, 403), + Connection(403, 316), + Connection(315, 314), + Connection(314, 404), + Connection(404, 315), + Connection(314, 313), + Connection(313, 405), + Connection(405, 314), + Connection(313, 421), + Connection(421, 406), + Connection(406, 313), + Connection(323, 366), + Connection(366, 361), + Connection(361, 323), + Connection(292, 306), + Connection(306, 407), + Connection(407, 292), + Connection(306, 291), + Connection(291, 408), + Connection(408, 306), + Connection(291, 287), + Connection(287, 409), + Connection(409, 291), + Connection(287, 432), + Connection(432, 410), + Connection(410, 287), + Connection(427, 434), + Connection(434, 411), + Connection(411, 427), + Connection(372, 264), + Connection(264, 383), + Connection(383, 372), + Connection(459, 309), + Connection(309, 457), + Connection(457, 459), + Connection(366, 352), + Connection(352, 401), + Connection(401, 366), + Connection(1, 274), + Connection(274, 4), + Connection(4, 1), + Connection(418, 421), + Connection(421, 262), + Connection(262, 418), + Connection(331, 294), + Connection(294, 358), + Connection(358, 331), + Connection(435, 433), + Connection(433, 367), + Connection(367, 435), + Connection(392, 289), + Connection(289, 439), + Connection(439, 392), + Connection(328, 462), + Connection(462, 326), + Connection(326, 328), + Connection(94, 2), + Connection(2, 370), + Connection(370, 94), + Connection(289, 305), + Connection(305, 455), + Connection(455, 289), + Connection(339, 254), + Connection(254, 448), + Connection(448, 339), + Connection(359, 255), + Connection(255, 446), + Connection(446, 359), + Connection(254, 253), + Connection(253, 449), + Connection(449, 254), + Connection(253, 252), + Connection(252, 450), + Connection(450, 253), + Connection(252, 256), + Connection(256, 451), + Connection(451, 252), + Connection(256, 341), + Connection(341, 452), + Connection(452, 256), + Connection(414, 413), + Connection(413, 463), + Connection(463, 414), + Connection(286, 441), + Connection(441, 414), + Connection(414, 286), + Connection(286, 258), + Connection(258, 441), + Connection(441, 286), + Connection(258, 257), + Connection(257, 442), + Connection(442, 258), + Connection(257, 259), + Connection(259, 443), + Connection(443, 257), + Connection(259, 260), + Connection(260, 444), + Connection(444, 259), + Connection(260, 467), + Connection(467, 445), + Connection(445, 260), + Connection(309, 459), + Connection(459, 250), + Connection(250, 309), + Connection(305, 289), + Connection(289, 290), + Connection(290, 305), + Connection(305, 290), + Connection(290, 460), + Connection(460, 305), + Connection(401, 376), + Connection(376, 435), + Connection(435, 401), + Connection(309, 250), + Connection(250, 392), + Connection(392, 309), + Connection(376, 411), + Connection(411, 433), + Connection(433, 376), + Connection(453, 341), + Connection(341, 
464), + Connection(464, 453), + Connection(357, 453), + Connection(453, 465), + Connection(465, 357), + Connection(343, 357), + Connection(357, 412), + Connection(412, 343), + Connection(437, 343), + Connection(343, 399), + Connection(399, 437), + Connection(344, 360), + Connection(360, 440), + Connection(440, 344), + Connection(420, 437), + Connection(437, 456), + Connection(456, 420), + Connection(360, 420), + Connection(420, 363), + Connection(363, 360), + Connection(361, 401), + Connection(401, 288), + Connection(288, 361), + Connection(265, 372), + Connection(372, 353), + Connection(353, 265), + Connection(390, 339), + Connection(339, 249), + Connection(249, 390), + Connection(339, 448), + Connection(448, 255), + Connection(255, 339), + ] + + @dataclasses.dataclass class FaceLandmarkerResult: """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image. diff --git a/mediapipe/tasks/python/vision/face_stylizer.py b/mediapipe/tasks/python/vision/face_stylizer.py index c6470b19f..0b10a2b40 100644 --- a/mediapipe/tasks/python/vision/face_stylizer.py +++ b/mediapipe/tasks/python/vision/face_stylizer.py @@ -176,16 +176,13 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): Only use this method when the FaceStylizer is created with the image running mode. - To ensure that the output image has reasonable quality, the stylized output - image size is the smaller of the model output size and the size of the - `region_of_interest` specified in `image_processing_options`. - Args: image: MediaPipe Image. image_processing_options: Options for image processing. Returns: - The stylized image of the most visible face. None if no face is detected + The stylized image of the most visible face. The stylized output image + size is the same as the model output size. None if no face is detected on the input image. Raises: @@ -217,17 +214,14 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): milliseconds) along with the video frame. The input timestamps should be monotonically increasing for adjacent calls of this method. - To ensure that the output image has reasonable quality, the stylized output - image size is the smaller of the model output size and the size of the - `region_of_interest` specified in `image_processing_options`. - Args: image: MediaPipe Image. timestamp_ms: The timestamp of the input video frame in milliseconds. image_processing_options: Options for image processing. Returns: - The stylized image of the most visible face. None if no face is detected + The stylized image of the most visible face. The stylized output image + size is the same as the model output size. None if no face is detected on the input image. Raises: @@ -266,12 +260,9 @@ class FaceStylizer(base_vision_task_api.BaseVisionTaskApi): images if needed. In other words, it's not guaranteed to have output per input image. - To ensure that the stylized image has reasonable quality, the stylized - output image size is the smaller of the model output size and the size of - the `region_of_interest` specified in `image_processing_options`. - The `result_callback` provides: - - The stylized image of the most visible face. None if no face is detected + - The stylized image of the most visible face. The stylized output image + size is the same as the model output size. None if no face is detected on the input image. - The input image that the face stylizer runs on. - The input timestamp in milliseconds. 
diff --git a/mediapipe/tasks/python/vision/image_segmenter.py b/mediapipe/tasks/python/vision/image_segmenter.py index e50ffbf79..5d9af86ce 100644 --- a/mediapipe/tasks/python/vision/image_segmenter.py +++ b/mediapipe/tasks/python/vision/image_segmenter.py @@ -14,13 +14,13 @@ """MediaPipe image segmenter task.""" import dataclasses -import enum from typing import Callable, List, Mapping, Optional from mediapipe.python import packet_creator from mediapipe.python import packet_getter from mediapipe.python._framework_bindings import image as image_module from mediapipe.python._framework_bindings import packet +from mediapipe.tasks.cc.vision.image_segmenter.calculators import tensors_to_segmentation_calculator_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import image_segmenter_graph_options_pb2 from mediapipe.tasks.cc.vision.image_segmenter.proto import segmenter_options_pb2 from mediapipe.tasks.python.components.containers import rect @@ -31,28 +31,50 @@ from mediapipe.tasks.python.vision.core import base_vision_task_api from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module from mediapipe.tasks.python.vision.core import vision_task_running_mode -ImageSegmenterResult = List[image_module.Image] _NormalizedRect = rect.NormalizedRect _BaseOptions = base_options_module.BaseOptions _SegmenterOptionsProto = segmenter_options_pb2.SegmenterOptions _ImageSegmenterGraphOptionsProto = ( image_segmenter_graph_options_pb2.ImageSegmenterGraphOptions ) +TensorsToSegmentationCalculatorOptionsProto = ( + tensors_to_segmentation_calculator_pb2.TensorsToSegmentationCalculatorOptions +) _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' -_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' +_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' +_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' +_CATEGORY_MASK_STREAM_NAME = 'category_mask' +_CATEGORY_MASK_TAG = 'CATEGORY_MASK' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _IMAGE_TAG = 'IMAGE' _NORM_RECT_STREAM_NAME = 'norm_rect_in' _NORM_RECT_TAG = 'NORM_RECT' +_TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = ( + 'mediapipe.tasks.TensorsToSegmentationCalculator' +) _TASK_GRAPH_NAME = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph' _MICRO_SECONDS_PER_MILLISECOND = 1000 +@dataclasses.dataclass +class ImageSegmenterResult: + """Output result of ImageSegmenter. + + confidence_masks: multiple masks of float image where, for each mask, each + pixel represents the prediction confidence, usually in the [0, 1] range. + + category_mask: a category mask of uint8 image where each pixel represents the + class which the pixel in the original image was predicted to belong to. + """ + + confidence_masks: Optional[List[image_module.Image]] = None + category_mask: Optional[image_module.Image] = None + + @dataclasses.dataclass class ImageSegmenterOptions: """Options for the image segmenter task. @@ -64,28 +86,17 @@ class ImageSegmenterOptions: objects on single image inputs. 2) The video mode for segmenting objects on the decoded frames of a video. 3) The live stream mode for segmenting objects on a live stream of input data, such as from camera. - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. - activation: Activation function to apply to input tensor. 
+ output_confidence_masks: Whether to output confidence masks. + output_category_mask: Whether to output category mask. result_callback: The user-defined result callback for processing live stream data. The result callback should only be specified when the running mode is set to the live stream mode. """ - class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - - class Activation(enum.Enum): - NONE = 0 - SIGMOID = 1 - SOFTMAX = 2 - base_options: _BaseOptions running_mode: _RunningMode = _RunningMode.IMAGE - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK - activation: Optional[Activation] = Activation.NONE + output_confidence_masks: bool = True + output_category_mask: bool = False result_callback: Optional[ Callable[[ImageSegmenterResult, image_module.Image, int], None] ] = None @@ -97,9 +108,7 @@ class ImageSegmenterOptions: base_options_proto.use_stream_mode = ( False if self.running_mode == _RunningMode.IMAGE else True ) - segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value, activation=self.activation.value - ) + segmenter_options_proto = _SegmenterOptionsProto() return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto, @@ -122,8 +131,8 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): Output tensors: (kTfLiteUInt8/kTfLiteFloat32) - list of segmented masks. - - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. - - if `output_type` is CONFIDENCE_MASK, float32 Image list of size + - if `output_category_mask` is True, uint8 Image, Image vector of size 1. + - if `output_confidence_masks` is True, float32 Image list of size `channels`. - batch is always 1 @@ -131,6 +140,41 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/2 """ + def __init__(self, graph_config, running_mode, packet_callback) -> None: + """Initializes the `ImageSegmenter` object.""" + super(ImageSegmenter, self).__init__( + graph_config, running_mode, packet_callback + ) + self._populate_labels() + + def _populate_labels(self) -> None: + """Populate the labelmap in TensorsToSegmentationCalculator to labels field. + + Raises: + ValueError if there is an error during finding + TensorsToSegmentationCalculator. + """ + self._labels = [] + graph_config = self._runner.get_graph_config() + found_tensors_to_segmentation = False + + for node in graph_config.node: + if _TENSORS_TO_SEGMENTATION_CALCULATOR_NAME in node.name: + if found_tensors_to_segmentation: + raise ValueError( + 'The graph has more than one ' + f'{_TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}.' + ) + found_tensors_to_segmentation = True + options = node.options.Extensions[ + TensorsToSegmentationCalculatorOptionsProto.ext + ] + if options.label_items: + for i in range(len(options.label_items)): + if i not in options.label_items: + raise ValueError(f'The labelmap has no expected key: {i}.') + self._labels.append(options.label_items[i].name) + @classmethod def create_from_model_path(cls, model_path: str) -> 'ImageSegmenter': """Creates an `ImageSegmenter` object from a TensorFlow Lite model and the default `ImageSegmenterOptions`. 
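A usage sketch for the new boolean output flags and the ImageSegmenterResult dataclass introduced in the hunks above (which replace the old OutputType / Activation enums), assuming the exported vision.ImageSegmenter API; model and image paths are placeholders:

    import mediapipe as mp
    from mediapipe.tasks import python
    from mediapipe.tasks.python import vision

    # Placeholder paths; any segmentation model compatible with ImageSegmenterGraph works.
    options = vision.ImageSegmenterOptions(
        base_options=python.BaseOptions(model_asset_path='deeplab_v3.tflite'),
        output_confidence_masks=True,
        output_category_mask=True,
    )
    segmenter = vision.ImageSegmenter.create_from_options(options)
    result = segmenter.segment(mp.Image.create_from_file('scene.png'))

    # The result is now a dataclass instead of a bare List[Image].
    if result.category_mask is not None:
        class_ids = result.category_mask.numpy_view()  # uint8, one class id per pixel
        print('Category mask shape:', class_ids.shape)
    if result.confidence_masks is not None:
        print('Confidence masks:', len(result.confidence_masks))
    segmenter.close()

Because the later hunks register the CONFIDENCE_MASKS and CATEGORY_MASK output streams conditionally on these flags, a caller that enables only one of them should expect only that field of the result to be populated.
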
@@ -177,27 +221,48 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): def packets_callback(output_packets: Mapping[str, packet.Packet]): if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): return - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + + segmentation_result = ImageSegmenterResult() + + if options.output_confidence_masks: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if options.output_category_mask: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) - timestamp = output_packets[_SEGMENTATION_OUT_STREAM_NAME].timestamp + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp options.result_callback( segmentation_result, image, timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, ) + output_streams = [ + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_confidence_masks: + output_streams.append( + ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) + ) + + if options.output_category_mask: + output_streams.append( + ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) + ) + task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, input_streams=[ ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], - output_streams=[ - ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ], + output_streams=output_streams, task_options=options, ) return cls( @@ -240,9 +305,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = ImageSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result def segment_for_video( @@ -285,9 +359,18 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = ImageSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result def segment_async( @@ -334,3 +417,17 @@ class ImageSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), }) + + @property + def labels(self) -> List[str]: + """Get the category label list the ImageSegmenter can recognize. + + For CATEGORY_MASK type, the index in the category mask corresponds to the + category in the label list. 
+ For CONFIDENCE_MASK type, the output mask list at index corresponds to the + category in the label list. + + If there is no label map provided in the model file, empty label list is + returned. + """ + return self._labels diff --git a/mediapipe/tasks/python/vision/interactive_segmenter.py b/mediapipe/tasks/python/vision/interactive_segmenter.py index 12a30b6ef..ad93c798c 100644 --- a/mediapipe/tasks/python/vision/interactive_segmenter.py +++ b/mediapipe/tasks/python/vision/interactive_segmenter.py @@ -41,8 +41,10 @@ _RunningMode = vision_task_running_mode.VisionTaskRunningMode _ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions _TaskInfo = task_info_module.TaskInfo -_SEGMENTATION_OUT_STREAM_NAME = 'segmented_mask_out' -_SEGMENTATION_TAG = 'GROUPED_SEGMENTATION' +_CONFIDENCE_MASKS_STREAM_NAME = 'confidence_masks' +_CONFIDENCE_MASKS_TAG = 'CONFIDENCE_MASKS' +_CATEGORY_MASK_STREAM_NAME = 'category_mask' +_CATEGORY_MASK_TAG = 'CATEGORY_MASK' _IMAGE_IN_STREAM_NAME = 'image_in' _IMAGE_OUT_STREAM_NAME = 'image_out' _ROI_STREAM_NAME = 'roi_in' @@ -55,32 +57,41 @@ _TASK_GRAPH_NAME = ( ) +@dataclasses.dataclass +class InteractiveSegmenterResult: + """Output result of InteractiveSegmenter. + + confidence_masks: multiple masks of float image where, for each mask, each + pixel represents the prediction confidence, usually in the [0, 1] range. + + category_mask: a category mask of uint8 image where each pixel represents the + class which the pixel in the original image was predicted to belong to. + """ + + confidence_masks: Optional[List[image_module.Image]] = None + category_mask: Optional[image_module.Image] = None + + @dataclasses.dataclass class InteractiveSegmenterOptions: """Options for the interactive segmenter task. Attributes: base_options: Base options for the interactive segmenter task. - output_type: The output mask type allows specifying the type of - post-processing to perform on the raw model results. + output_confidence_masks: Whether to output confidence masks. + output_category_mask: Whether to output category mask. """ - class OutputType(enum.Enum): - UNSPECIFIED = 0 - CATEGORY_MASK = 1 - CONFIDENCE_MASK = 2 - base_options: _BaseOptions - output_type: Optional[OutputType] = OutputType.CATEGORY_MASK + output_confidence_masks: bool = True + output_category_mask: bool = False @doc_controls.do_not_generate_docs def to_pb2(self) -> _ImageSegmenterGraphOptionsProto: """Generates an InteractiveSegmenterOptions protobuf object.""" base_options_proto = self.base_options.to_pb2() base_options_proto.use_stream_mode = False - segmenter_options_proto = _SegmenterOptionsProto( - output_type=self.output_type.value - ) + segmenter_options_proto = _SegmenterOptionsProto() return _ImageSegmenterGraphOptionsProto( base_options=base_options_proto, segmenter_options=segmenter_options_proto, @@ -192,6 +203,20 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): RuntimeError: If other types of error occurred. 
""" + output_streams = [ + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_confidence_masks: + output_streams.append( + ':'.join([_CONFIDENCE_MASKS_TAG, _CONFIDENCE_MASKS_STREAM_NAME]) + ) + + if options.output_category_mask: + output_streams.append( + ':'.join([_CATEGORY_MASK_TAG, _CATEGORY_MASK_STREAM_NAME]) + ) + task_info = _TaskInfo( task_graph=_TASK_GRAPH_NAME, input_streams=[ @@ -199,10 +224,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): ':'.join([_ROI_TAG, _ROI_STREAM_NAME]), ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), ], - output_streams=[ - ':'.join([_SEGMENTATION_TAG, _SEGMENTATION_OUT_STREAM_NAME]), - ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), - ], + output_streams=output_streams, task_options=options, ) return cls( @@ -216,7 +238,7 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): image: image_module.Image, roi: RegionOfInterest, image_processing_options: Optional[_ImageProcessingOptions] = None, - ) -> List[image_module.Image]: + ) -> InteractiveSegmenterResult: """Performs the actual segmentation task on the provided MediaPipe Image. The image can be of any size with format RGB. @@ -248,7 +270,16 @@ class InteractiveSegmenter(base_vision_task_api.BaseVisionTaskApi): normalized_rect.to_pb2() ), }) - segmentation_result = packet_getter.get_image_list( - output_packets[_SEGMENTATION_OUT_STREAM_NAME] - ) + segmentation_result = InteractiveSegmenterResult() + + if _CONFIDENCE_MASKS_STREAM_NAME in output_packets: + segmentation_result.confidence_masks = packet_getter.get_image_list( + output_packets[_CONFIDENCE_MASKS_STREAM_NAME] + ) + + if _CATEGORY_MASK_STREAM_NAME in output_packets: + segmentation_result.category_mask = packet_getter.get_image( + output_packets[_CATEGORY_MASK_STREAM_NAME] + ) + return segmentation_result diff --git a/mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task b/mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task deleted file mode 100644 index c62854c0c..000000000 Binary files a/mediapipe/tasks/testdata/core/dummy_gesture_recognizer.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/face_landmarker.task b/mediapipe/tasks/testdata/vision/face_landmarker.task deleted file mode 100644 index d6b5f8835..000000000 Binary files a/mediapipe/tasks/testdata/vision/face_landmarker.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/face_landmarker_v2.task b/mediapipe/tasks/testdata/vision/face_landmarker_v2.task deleted file mode 100644 index 885f6e31d..000000000 Binary files a/mediapipe/tasks/testdata/vision/face_landmarker_v2.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task b/mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task deleted file mode 100644 index 7749e045c..000000000 Binary files a/mediapipe/tasks/testdata/vision/face_landmarker_v2_with_blendshapes.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task b/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task deleted file mode 100644 index 04adf1841..000000000 Binary files a/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/gesture_recognizer_with_custom_classifier.task b/mediapipe/tasks/testdata/vision/gesture_recognizer_with_custom_classifier.task deleted file mode 100644 index 3c1da7b3d..000000000 Binary files 
a/mediapipe/tasks/testdata/vision/gesture_recognizer_with_custom_classifier.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/hand_gesture_recognizer_with_custom_classifier.task b/mediapipe/tasks/testdata/vision/hand_gesture_recognizer_with_custom_classifier.task deleted file mode 100644 index 1390ca88d..000000000 Binary files a/mediapipe/tasks/testdata/vision/hand_gesture_recognizer_with_custom_classifier.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/hand_landmarker.task b/mediapipe/tasks/testdata/vision/hand_landmarker.task deleted file mode 100644 index 748b2f013..000000000 Binary files a/mediapipe/tasks/testdata/vision/hand_landmarker.task and /dev/null differ diff --git a/mediapipe/tasks/testdata/vision/pose_landmarker.task b/mediapipe/tasks/testdata/vision/pose_landmarker.task deleted file mode 100644 index 598959b84..000000000 Binary files a/mediapipe/tasks/testdata/vision/pose_landmarker.task and /dev/null differ diff --git a/mediapipe/tasks/web/audio/README.md b/mediapipe/tasks/web/audio/README.md index 834785709..ed2543c7a 100644 --- a/mediapipe/tasks/web/audio/README.md +++ b/mediapipe/tasks/web/audio/README.md @@ -13,7 +13,7 @@ const audio = await FilesetResolver.forAudioTasks( const audioClassifier = await AudioClassifier.createFromModelPath(audio, "https://storage.googleapis.com/mediapipe-tasks/audio_classifier/yamnet_audio_classifier_with_metadata.tflite" ); -const classifications = audioClassifier.classifiy(audioData); +const classifications = audioClassifier.classify(audioData); ``` ## Audio Embedding diff --git a/mediapipe/tasks/web/components/processors/BUILD b/mediapipe/tasks/web/components/processors/BUILD index a5f93a147..d81fbc79a 100644 --- a/mediapipe/tasks/web/components/processors/BUILD +++ b/mediapipe/tasks/web/components/processors/BUILD @@ -125,3 +125,27 @@ jasmine_node_test( name = "embedder_options_test", deps = [":embedder_options_test_lib"], ) + +mediapipe_ts_library( + name = "landmark_result", + srcs = [ + "landmark_result.ts", + "landmark_result_test_lib.ts", + ], + deps = [ + "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/components/containers:landmark", + ], +) + +mediapipe_ts_library( + name = "landmark_result_test_lib", + testonly = True, + srcs = ["landmark_result.test.ts"], + deps = [":landmark_result"], +) + +jasmine_node_test( + name = "landmark_result_test", + deps = [":landmark_result_test_lib"], +) diff --git a/mediapipe/tasks/web/components/processors/landmark_result.test.ts b/mediapipe/tasks/web/components/processors/landmark_result.test.ts new file mode 100644 index 000000000..3a2635107 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/landmark_result.test.ts @@ -0,0 +1,52 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import 'jasmine'; + +import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result'; +import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib'; + + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +describe('convertToLandmarks()', () => { + it('transforms custom values', () => { + const landmarkListProto = createLandmarks(0.1, 0.2, 0.3); + const result = convertToLandmarks(landmarkListProto); + expect(result).toEqual([{x: 0.1, y: 0.2, z: 0.3}]); + }); + + it('transforms default values', () => { + const landmarkListProto = createLandmarks(); + const result = convertToLandmarks(landmarkListProto); + expect(result).toEqual([{x: 0, y: 0, z: 0}]); + }); +}); + +describe('convertToWorldLandmarks()', () => { + it('transforms custom values', () => { + const worldLandmarkListProto = createWorldLandmarks(10, 20, 30); + const result = convertToWorldLandmarks(worldLandmarkListProto); + expect(result).toEqual([{x: 10, y: 20, z: 30}]); + }); + + it('transforms default values', () => { + const worldLandmarkListProto = createWorldLandmarks(); + const result = convertToWorldLandmarks(worldLandmarkListProto); + expect(result).toEqual([{x: 0, y: 0, z: 0}]); + }); +}); diff --git a/mediapipe/tasks/web/components/processors/landmark_result.ts b/mediapipe/tasks/web/components/processors/landmark_result.ts new file mode 100644 index 000000000..3a4fa0245 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/landmark_result.ts @@ -0,0 +1,45 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {LandmarkList as LandmarkListProto, NormalizedLandmarkList as NormalizedLandmarkListProto} from '../../../../framework/formats/landmark_pb'; +import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; + +/** Converts raw data into a landmark. */ +export function convertToLandmarks(proto: NormalizedLandmarkListProto): + NormalizedLandmark[] { + const landmarks: NormalizedLandmark[] = []; + for (const landmark of proto.getLandmarkList()) { + landmarks.push({ + x: landmark.getX() ?? 0, + y: landmark.getY() ?? 0, + z: landmark.getZ() ?? 0, + }); + } + return landmarks; +} + +/** Converts raw data into a world landmark. */ +export function convertToWorldLandmarks(proto: LandmarkListProto): Landmark[] { + const worldLandmarks: Landmark[] = []; + for (const worldLandmark of proto.getLandmarkList()) { + worldLandmarks.push({ + x: worldLandmark.getX() ?? 0, + y: worldLandmark.getY() ?? 0, + z: worldLandmark.getZ() ?? 
0, + }); + } + return worldLandmarks; +} diff --git a/mediapipe/tasks/web/components/processors/landmark_result_test_lib.ts b/mediapipe/tasks/web/components/processors/landmark_result_test_lib.ts new file mode 100644 index 000000000..318ab2f63 --- /dev/null +++ b/mediapipe/tasks/web/components/processors/landmark_result_test_lib.ts @@ -0,0 +1,44 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Landmark as LandmarkProto, LandmarkList as LandmarkListProto, NormalizedLandmark as NormalizedLandmarkProto, NormalizedLandmarkList as NormalizedLandmarkListProto} from '../../../../framework/formats/landmark_pb'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** Creates a normalized landmark list with one entry. */ +export function createLandmarks( + x?: number, y?: number, z?: number): NormalizedLandmarkListProto { + const landmarksProto = new NormalizedLandmarkListProto(); + const landmark = new NormalizedLandmarkProto(); + if (x !== undefined) landmark.setX(x); + if (y !== undefined) landmark.setY(y); + if (z !== undefined) landmark.setZ(z); + landmarksProto.addLandmark(landmark); + return landmarksProto; +} + +/** Creates a world landmark list with one entry. */ +export function createWorldLandmarks( + x?: number, y?: number, z?: number): LandmarkListProto { + const worldLandmarksProto = new LandmarkListProto(); + const landmark = new LandmarkProto(); + if (x !== undefined) landmark.setX(x); + if (y !== undefined) landmark.setY(y); + if (z !== undefined) landmark.setZ(z); + worldLandmarksProto.addLandmark(landmark); + return worldLandmarksProto; +} diff --git a/mediapipe/tasks/web/text/README.md b/mediapipe/tasks/web/text/README.md index 089894653..4a26f5b9d 100644 --- a/mediapipe/tasks/web/text/README.md +++ b/mediapipe/tasks/web/text/README.md @@ -28,7 +28,7 @@ const text = await FilesetResolver.forTextTasks( const textClassifier = await TextClassifier.createFromModelPath(text, "https://storage.googleapis.com/mediapipe-tasks/text_classifier/bert_text_classifier.tflite" ); -const classifications = textClassifier.classifiy(textData); +const classifications = textClassifier.classify(textData); ``` For more information, refer to the [Text Classification](https://developers.google.com/mediapipe/solutions/text/text_classifier/web_js) documentation.
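The `landmark_result` helpers above give the web tasks a single place to turn landmark protos into the public container types. A minimal consumption sketch follows, assuming a caller that lives under mediapipe/tasks/web/vision; the `handleLandmarksPacket` wrapper and its `binaryProto` input are illustrative only and not part of this change, while the deserialize-then-convert pattern mirrors the face_landmarker.ts update later in this diff.

import {NormalizedLandmarkList as NormalizedLandmarkListProto} from '../../../../framework/formats/landmark_pb';
import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark';
import {convertToLandmarks} from '../../../../tasks/web/components/processors/landmark_result';

// Illustrative listener body: `binaryProto` is a serialized
// NormalizedLandmarkList read from a graph output stream.
function handleLandmarksPacket(binaryProto: Uint8Array): NormalizedLandmark[] {
  const landmarkListProto =
      NormalizedLandmarkListProto.deserializeBinary(binaryProto);
  // Unset x/y/z fields are defaulted to 0 by the converter.
  return convertToLandmarks(landmarkListProto);
}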
diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index c86801955..fa1ed32da 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -20,6 +20,7 @@ mediapipe_files(srcs = [ VISION_LIBS = [ "//mediapipe/tasks/web/core:fileset_resolver", "//mediapipe/tasks/web/vision/core:drawing_utils", + "//mediapipe/tasks/web/vision/core:image", "//mediapipe/tasks/web/vision/face_detector", "//mediapipe/tasks/web/vision/face_landmarker", "//mediapipe/tasks/web/vision/face_stylizer", diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index d5109142b..6423807fc 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -12,7 +12,7 @@ const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); const faceDetector = await FaceDetector.createFromModelPath(vision, - "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" + "https://storage.googleapis.com/mediapipe-tasks/face_detector/face_detection_short_range.tflite" ); const image = document.getElementById("image") as HTMLImageElement; const detections = faceDetector.detect(image); @@ -29,7 +29,7 @@ const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); const faceLandmarker = await FaceLandmarker.createFromModelPath(vision, - "model.task" + "https://storage.googleapis.com/mediapipe-tasks/face_landmarker/face_landmarker.task" ); const image = document.getElementById("image") as HTMLImageElement; const landmarks = faceLandmarker.detect(image); @@ -44,7 +44,7 @@ const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); const faceStylizer = await FaceStylizer.createFromModelPath(vision, - "model.tflite" + "https://storage.googleapis.com/mediapipe-tasks/face_stylizer/face_stylizer_with_metadata.tflite" ); const image = document.getElementById("image") as HTMLImageElement; const stylizedImage = faceStylizer.stylize(image); @@ -115,7 +115,7 @@ const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); const imageSegmenter = await ImageSegmenter.createFromModelPath(vision, - "model.tflite" + "https://storage.googleapis.com/mediapipe-tasks/image_segmenter/selfie_segmentation.tflite" ); const image = document.getElementById("image") as HTMLImageElement; imageSegmenter.segment(image, (masks, width, height) => { @@ -133,7 +133,8 @@ const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); const interactiveSegmenter = await InteractiveSegmenter.createFromModelPath( - vision, "model.tflite" + vision, + "https://storage.googleapis.com/mediapipe-tasks/interactive_segmenter/ptm_512_hdt_ptm_woid.tflite" ); const image = document.getElementById("image") as HTMLImageElement; interactiveSegmenter.segment(image, { keypoint: { x: 0.1, y: 0.2 } }, diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index 8f53dc2cb..f010a8bdd 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -39,6 +39,23 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "image", + srcs = ["image.ts"], +) + +mediapipe_ts_library( + name = "image_test_lib", + testonly = True, + srcs = ["image.test.ts"], + deps = [":image"], +) +
+jasmine_node_test( + name = "image_test", + deps = [":image_test_lib"], +) + mediapipe_ts_library( name = "vision_task_runner", srcs = ["vision_task_runner.ts"], diff --git a/mediapipe/tasks/web/vision/core/image.test.ts b/mediapipe/tasks/web/vision/core/image.test.ts new file mode 100644 index 000000000..7373ea385 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/image.test.ts @@ -0,0 +1,287 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +import {MPImage, MPImageShaderContext, MPImageStorageType} from './image'; + +const WIDTH = 2; +const HEIGHT = 2; + +const skip = typeof document === 'undefined'; +if (skip) { + console.log('These tests must be run in a browser.'); +} + +/** The image types supported by MPImage. */ +type ImageType = ImageData|ImageBitmap|WebGLTexture; + +async function createTestData( + gl: WebGL2RenderingContext, data: number[], width: number, + height: number): Promise<[ImageData, ImageBitmap, WebGLTexture]> { + const imageData = new ImageData(new Uint8ClampedArray(data), width, height); + const imageBitmap = await createImageBitmap(imageData); + const webGlTexture = gl.createTexture()!; + + gl.bindTexture(gl.TEXTURE_2D, webGlTexture); + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, imageBitmap); + gl.bindTexture(gl.TEXTURE_2D, null); + + return [imageData, imageBitmap, webGlTexture]; +} + +(skip ? 
xdescribe : describe)('MPImage', () => { + let canvas: OffscreenCanvas; + let gl: WebGL2RenderingContext; + let imageData: ImageData; + let imageBitmap: ImageBitmap; + let webGlTexture: WebGLTexture; + + beforeEach(async () => { + canvas = new OffscreenCanvas(WIDTH, HEIGHT); + gl = canvas.getContext('webgl2') as WebGL2RenderingContext; + + const images = await createTestData( + gl, [1, 0, 0, 255, 2, 0, 0, 255, 3, 0, 0, 255, 4, 0, 0, 255], WIDTH, + HEIGHT); + imageData = images[0]; + imageBitmap = images[1]; + webGlTexture = images[2]; + }); + + afterEach(() => { + gl.deleteTexture(webGlTexture); + imageBitmap.close(); + }); + + function readPixelsFromImageBitmap(imageBitmap: ImageBitmap): ImageData { + const canvas = new OffscreenCanvas(imageBitmap.width, imageBitmap.height); + const ctx = canvas.getContext('2d') as OffscreenCanvasRenderingContext2D; + ctx.drawImage(imageBitmap, 0, 0); + return ctx.getImageData(0, 0, imageBitmap.width, imageBitmap.height); + } + + function readPixelsFromWebGLTexture(texture: WebGLTexture): Uint8Array { + const pixels = new Uint8Array(WIDTH * WIDTH * 4); + + gl.bindTexture(gl.TEXTURE_2D, texture); + + const framebuffer = gl.createFramebuffer()!; + gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); + gl.framebufferTexture2D( + gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); + gl.readPixels(0, 0, WIDTH, HEIGHT, gl.RGBA, gl.UNSIGNED_BYTE, pixels); + gl.bindFramebuffer(gl.FRAMEBUFFER, null); + gl.deleteFramebuffer(framebuffer); + + gl.bindTexture(gl.TEXTURE_2D, null); + + return pixels; + } + + function assertEquality(image: MPImage, expected: ImageType): void { + if (expected instanceof ImageData) { + const result = image.getImage(MPImageStorageType.IMAGE_DATA); + expect(result).toEqual(expected); + } else if (expected instanceof ImageBitmap) { + const result = image.getImage(MPImageStorageType.IMAGE_BITMAP); + expect(readPixelsFromImageBitmap(result)) + .toEqual(readPixelsFromImageBitmap(expected)); + } else { // WebGLTexture + const result = image.getImage(MPImageStorageType.WEBGL_TEXTURE); + expect(readPixelsFromWebGLTexture(result)) + .toEqual(readPixelsFromWebGLTexture(expected)); + } + } + + function createImage( + shaderContext: MPImageShaderContext, input: ImageType, width: number, + height: number): MPImage { + return new MPImage( + input instanceof ImageData ? input : null, + input instanceof ImageBitmap ? input : null, + input instanceof WebGLTexture ? 
input : null, + /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, canvas, + shaderContext, width, height); + } + + function runConversionTest( + input: ImageType, output: ImageType, width = WIDTH, + height = HEIGHT): void { + const shaderContext = new MPImageShaderContext(); + const image = createImage(shaderContext, input, width, height); + assertEquality(image, output); + image.close(); + shaderContext.close(); + } + + function runCloneTest(input: ImageType): void { + const shaderContext = new MPImageShaderContext(); + const image = createImage(shaderContext, input, WIDTH, HEIGHT); + const clone = image.clone(); + assertEquality(clone, input); + clone.close(); + shaderContext.close(); + } + + it(`converts from ImageData to ImageData`, () => { + runConversionTest(imageData, imageData); + }); + + it(`converts from ImageData to ImageBitmap`, () => { + runConversionTest(imageData, imageBitmap); + }); + + it(`converts from ImageData to WebGLTexture`, () => { + runConversionTest(imageData, webGlTexture); + }); + + it(`converts from ImageBitmap to ImageData`, () => { + runConversionTest(imageBitmap, imageData); + }); + + it(`converts from ImageBitmap to ImageBitmap`, () => { + runConversionTest(imageBitmap, imageBitmap); + }); + + it(`converts from ImageBitmap to WebGLTexture`, () => { + runConversionTest(imageBitmap, webGlTexture); + }); + + it(`converts from WebGLTexture to ImageData`, () => { + runConversionTest(webGlTexture, imageData); + }); + + it(`converts from WebGLTexture to ImageBitmap`, () => { + runConversionTest(webGlTexture, imageBitmap); + }); + + it(`converts from WebGLTexture to WebGLTexture`, () => { + runConversionTest(webGlTexture, webGlTexture); + }); + + it(`clones ImageData`, () => { + runCloneTest(imageData); + }); + + it(`clones ImageBitmap`, () => { + runCloneTest(imageBitmap); + }); + + it(`clones WebGLTextures`, () => { + runCloneTest(webGlTexture); + }); + + it(`does not flip textures twice`, async () => { + const [imageData, , webGlTexture] = await createTestData( + gl, [1, 0, 0, 255, 2, 0, 0, 255, 3, 0, 0, 255, 4, 0, 0, 255], WIDTH, + HEIGHT); + + const shaderContext = new MPImageShaderContext(); + const image = new MPImage( + /* imageData= */ null, /* imageBitmap= */ null, webGlTexture, + /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, canvas, + shaderContext, WIDTH, HEIGHT); + + const result = image.clone().getImage(MPImageStorageType.IMAGE_DATA); + expect(result).toEqual(imageData); + + gl.deleteTexture(webGlTexture); + shaderContext.close(); + }); + + it(`can clone and get image`, async () => { + const [imageData, , webGlTexture] = await createTestData( + gl, [1, 0, 0, 255, 2, 0, 0, 255, 3, 0, 0, 255, 4, 0, 0, 255], WIDTH, + HEIGHT); + + const shaderContext = new MPImageShaderContext(); + const image = new MPImage( + /* imageData= */ null, /* imageBitmap= */ null, webGlTexture, + /* ownsImageBitmap= */ false, /* ownsWebGLTexture= */ false, canvas, + shaderContext, WIDTH, HEIGHT); + + // Verify that we can mix the different shader modes by running them out of + // order. 
+ let result = image.getImage(MPImageStorageType.IMAGE_DATA); + expect(result).toEqual(imageData); + + result = image.clone().getImage(MPImageStorageType.IMAGE_DATA); + expect(result).toEqual(imageData); + + result = image.getImage(MPImageStorageType.IMAGE_DATA); + expect(result).toEqual(imageData); + + gl.deleteTexture(webGlTexture); + shaderContext.close(); + }); + + it('supports hasType()', async () => { + const shaderContext = new MPImageShaderContext(); + const image = createImage(shaderContext, imageData, WIDTH, HEIGHT); + + expect(image.hasType(MPImageStorageType.IMAGE_DATA)).toBe(true); + expect(image.hasType(MPImageStorageType.WEBGL_TEXTURE)).toBe(false); + expect(image.hasType(MPImageStorageType.IMAGE_BITMAP)).toBe(false); + + image.getImage(MPImageStorageType.WEBGL_TEXTURE); + + expect(image.hasType(MPImageStorageType.IMAGE_DATA)).toBe(true); + expect(image.hasType(MPImageStorageType.WEBGL_TEXTURE)).toBe(true); + expect(image.hasType(MPImageStorageType.IMAGE_BITMAP)).toBe(false); + + await image.getImage(MPImageStorageType.IMAGE_BITMAP); + + expect(image.hasType(MPImageStorageType.IMAGE_DATA)).toBe(true); + expect(image.hasType(MPImageStorageType.WEBGL_TEXTURE)).toBe(true); + expect(image.hasType(MPImageStorageType.IMAGE_BITMAP)).toBe(true); + + image.close(); + shaderContext.close(); + }); + + it('supports image that is smaller than the canvas', async () => { + const [imageData, imageBitmap, webGlTexture] = await createTestData( + gl, [1, 0, 0, 255, 2, 0, 0, 255], /* width= */ 2, /* height= */ 1); + + runConversionTest(imageData, webGlTexture, /* width= */ 2, /* height= */ 1); + runConversionTest( + webGlTexture, imageBitmap, /* width= */ 2, /* height= */ 1); + runConversionTest(imageBitmap, imageData, /* width= */ 2, /* height= */ 1); + + gl.deleteTexture(webGlTexture); + imageBitmap.close(); + }); + + it('supports image that is larger than the canvas', async () => { + const [imageData, imageBitmap, webGlTexture] = await createTestData( + gl, + [ + 1, 0, 0, 255, 2, 0, 0, 255, 3, 0, 0, 255, + 4, 0, 0, 255, 5, 0, 0, 255, 6, 0, 0, 255 + ], + /* width= */ 2, /* height= */ 3); + + runConversionTest(imageData, webGlTexture, /* width= */ 2, /* height= */ 3); + runConversionTest( + webGlTexture, imageBitmap, /* width= */ 2, /* height= */ 3); + runConversionTest(imageBitmap, imageData, /* width= */ 2, /* height= */ 3); + + gl.deleteTexture(webGlTexture); + imageBitmap.close(); + }); +}); diff --git a/mediapipe/tasks/web/vision/core/image.ts b/mediapipe/tasks/web/vision/core/image.ts new file mode 100644 index 000000000..a4bbdfe1e --- /dev/null +++ b/mediapipe/tasks/web/vision/core/image.ts @@ -0,0 +1,595 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The underlying type of the image. */ +export enum MPImageStorageType { + /** Represents the native `ImageData` type. */ + IMAGE_DATA, + /** Represents the native `ImageBitmap` type. */ + IMAGE_BITMAP, + /** Represents the native `WebGLTexture` type. 
*/ + WEBGL_TEXTURE +} + +type MPImageNativeContainer = ImageData|ImageBitmap|WebGLTexture; + +const VERTEX_SHADER = ` + attribute vec2 aVertex; + attribute vec2 aTex; + varying vec2 vTex; + void main(void) { + gl_Position = vec4(aVertex, 0.0, 1.0); + vTex = aTex; + }`; + +const FRAGMENT_SHADER = ` + precision mediump float; + varying vec2 vTex; + uniform sampler2D inputTexture; + void main() { + gl_FragColor = texture2D(inputTexture, vTex); + } + `; + +function assertNotNull(value: T|null, msg: string): T { + if (value === null) { + throw new Error(`Unable to obtain required WebGL resource: ${msg}`); + } + return value; +} + +/** + * Utility class that encapsulates the buffers used by `MPImageShaderContext`. + */ +class MPImageShaderBuffers { + constructor( + private readonly gl: WebGL2RenderingContext, + private readonly vertexArrayObject: WebGLVertexArrayObject, + private readonly vertexBuffer: WebGLBuffer, + private readonly textureBuffer: WebGLBuffer) {} + + bind() { + this.gl.bindVertexArray(this.vertexArrayObject); + } + + unbind() { + this.gl.bindVertexArray(null); + } + + close() { + this.gl.deleteVertexArray(this.vertexArrayObject); + this.gl.deleteBuffer(this.vertexBuffer); + this.gl.deleteBuffer(this.textureBuffer); + } +} + +/** + * A class that encapsulates the shaders used by an MPImage. Can be re-used + * across MPImages that use the same WebGL2Rendering context. + */ +export class MPImageShaderContext { + private gl?: WebGL2RenderingContext; + private framebuffer?: WebGLFramebuffer; + private program?: WebGLProgram; + private vertexShader?: WebGLShader; + private fragmentShader?: WebGLShader; + private aVertex?: GLint; + private aTex?: GLint; + + /** + * The shader buffers used for passthrough renders that don't modify the + * input texture. + */ + private shaderBuffersPassthrough?: MPImageShaderBuffers; + + /** + * The shader buffers used for passthrough renders that flip the input texture + * vertically before conversion to a different type. This is used to flip the + * texture to the expected orientation for drawing in the browser. 
+ */ + private shaderBuffersFlipVertically?: MPImageShaderBuffers; + + private compileShader(source: string, type: number): WebGLShader { + const gl = this.gl!; + const shader = + assertNotNull(gl.createShader(type), 'Failed to create WebGL shader'); + gl.shaderSource(shader, source); + gl.compileShader(shader); + if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) { + const info = gl.getShaderInfoLog(shader); + throw new Error(`Could not compile WebGL shader: ${info}`); + } + gl.attachShader(this.program!, shader); + return shader; + } + + private setupShaders(): void { + const gl = this.gl!; + this.program = + assertNotNull(gl.createProgram()!, 'Failed to create WebGL program'); + + this.vertexShader = this.compileShader(VERTEX_SHADER, gl.VERTEX_SHADER); + this.fragmentShader = + this.compileShader(FRAGMENT_SHADER, gl.FRAGMENT_SHADER); + + gl.linkProgram(this.program); + const linked = gl.getProgramParameter(this.program, gl.LINK_STATUS); + if (!linked) { + const info = gl.getProgramInfoLog(this.program); + throw new Error(`Error during program linking: ${info}`); + } + + this.aVertex = gl.getAttribLocation(this.program, 'aVertex'); + this.aTex = gl.getAttribLocation(this.program, 'aTex'); + } + + private createBuffers(flipVertically: boolean): MPImageShaderBuffers { + const gl = this.gl!; + const vertexArrayObject = + assertNotNull(gl.createVertexArray(), 'Failed to create vertex array'); + gl.bindVertexArray(vertexArrayObject); + + const vertexBuffer = + assertNotNull(gl.createBuffer(), 'Failed to create buffer'); + gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); + gl.enableVertexAttribArray(this.aVertex!); + gl.vertexAttribPointer(this.aVertex!, 2, gl.FLOAT, false, 0, 0); + gl.bufferData( + gl.ARRAY_BUFFER, new Float32Array([-1, -1, -1, 1, 1, 1, 1, -1]), + gl.STATIC_DRAW); + + const textureBuffer = + assertNotNull(gl.createBuffer(), 'Failed to create buffer'); + gl.bindBuffer(gl.ARRAY_BUFFER, textureBuffer); + gl.enableVertexAttribArray(this.aTex!); + gl.vertexAttribPointer(this.aTex!, 2, gl.FLOAT, false, 0, 0); + + const bufferData = + flipVertically ? [0, 1, 0, 0, 1, 0, 1, 1] : [0, 0, 0, 1, 1, 1, 1, 0]; + gl.bufferData( + gl.ARRAY_BUFFER, new Float32Array(bufferData), gl.STATIC_DRAW); + + gl.bindBuffer(gl.ARRAY_BUFFER, null); + gl.bindVertexArray(null); + + return new MPImageShaderBuffers( + gl, vertexArrayObject, vertexBuffer, textureBuffer); + } + + private getShaderBuffers(flipVertically: boolean): MPImageShaderBuffers { + if (flipVertically) { + if (!this.shaderBuffersFlipVertically) { + this.shaderBuffersFlipVertically = + this.createBuffers(/* flipVertically= */ true); + } + return this.shaderBuffersFlipVertically; + } else { + if (!this.shaderBuffersPassthrough) { + this.shaderBuffersPassthrough = + this.createBuffers(/* flipVertically= */ false); + } + return this.shaderBuffersPassthrough; + } + } + + private maybeInitGL(gl: WebGL2RenderingContext): void { + if (!this.gl) { + this.gl = gl; + } else if (gl !== this.gl) { + throw new Error('Cannot change GL context once initialized'); + } + } + + /** Runs the callback using the shader. */ + run( + gl: WebGL2RenderingContext, flipVertically: boolean, + callback: () => T): T { + this.maybeInitGL(gl); + + if (!this.program) { + this.setupShaders(); + } + + const shaderBuffers = this.getShaderBuffers(flipVertically); + gl.useProgram(this.program!); + shaderBuffers.bind(); + const result = callback(); + shaderBuffers.unbind(); + + return result; + } + /** + * Binds a framebuffer to the canvas. 
If the framebuffer does not yet exist, + * creates it first. Binds the provided texture to the framebuffer. + */ + bindFramebuffer(gl: WebGL2RenderingContext, texture: WebGLTexture): void { + this.maybeInitGL(gl); + if (!this.framebuffer) { + this.framebuffer = + assertNotNull(gl.createFramebuffer(), 'Failed to create framebuffer.'); + } + gl.bindFramebuffer(gl.FRAMEBUFFER, this.framebuffer); + gl.framebufferTexture2D( + gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); + } + + unbindFramebuffer(): void { + this.gl?.bindFramebuffer(this.gl.FRAMEBUFFER, null); + } + + close() { + if (this.program) { + const gl = this.gl!; + gl.deleteProgram(this.program); + gl.deleteShader(this.vertexShader!); + gl.deleteShader(this.fragmentShader!); + } + if (this.framebuffer) { + this.gl!.deleteFramebuffer(this.framebuffer); + } + if (this.shaderBuffersPassthrough) { + this.shaderBuffersPassthrough.close(); + } + if (this.shaderBuffersFlipVertically) { + this.shaderBuffersFlipVertically.close(); + } + } +} + +/** + * The wrapper class for MediaPipe Image objects. + * + * Images are stored as `ImageData`, `ImageBitmap` or `WebGLTexture` objects. + * You can convert the underlying type to any other type by passing the + * desired type to `getImage()`. As type conversions can be expensive, it is + * recommended to limit these conversions. You can verify what underlying + * types are already available by invoking `hasType()`. + * + * Images that are returned from a MediaPipe Task are owned by the + * underlying C++ Task. If you need to extend the lifetime of these objects, + * you can invoke the `clone()` method. To free up the resources obtained + * during any clone or type conversion operation, it is important to invoke + * `close()` on the `MPImage` instance. + * + * Converting to and from ImageBitmap requires that the MediaPipe task is + * initialized with an `OffscreenCanvas`. As we require WebGL2 support, this + * places some limitations on Browser support as outlined here: + * https://developer.mozilla.org/en-US/docs/Web/API/OffscreenCanvas/getContext + */ +export class MPImage { + private gl?: WebGL2RenderingContext; + + /** @hideconstructor */ + constructor( + private imageData: ImageData|null, + private imageBitmap: ImageBitmap|null, + private webGLTexture: WebGLTexture|null, + private ownsImageBitmap: boolean, + private ownsWebGLTexture: boolean, + /** Returns the canvas element that the image is bound to. */ + readonly canvas: HTMLCanvasElement|OffscreenCanvas|undefined, + private shaderContext: MPImageShaderContext|undefined, + /** Returns the width of the image. */ + readonly width: number, + /** Returns the height of the image. */ + readonly height: number, + ) {} + + /** + * Returns whether this `MPImage` stores the image in the desired format. + * This method can be called to reduce expensive conversion before invoking + * `getImage()`. + */ + hasType(type: MPImageStorageType): boolean { + if (type === MPImageStorageType.IMAGE_DATA) { + return !!this.imageData; + } else if (type === MPImageStorageType.IMAGE_BITMAP) { + return !!this.imageBitmap; + } else if (type === MPImageStorageType.WEBGL_TEXTURE) { + return !!this.webGLTexture; + } else { + throw new Error(`Type is not supported: ${type}`); + } + } + + /** + * Returns the underlying image as an `ImageData` object. Note that this + * involves an expensive GPU to CPU transfer if the current image is only + * available as an `ImageBitmap` or `WebGLTexture`. + * + * @return The current image as an ImageData object.
+ */ + getImage(type: MPImageStorageType.IMAGE_DATA): ImageData; + /** + * Returns the underlying image as an `ImageBitmap`. Note that + * conversions to `ImageBitmap` are expensive, especially if the data + * currently resides on CPU. + * + * Processing with `ImageBitmap`s requires that the MediaPipe Task was + * initialized with an `OffscreenCanvas` with WebGL2 support. See + * https://developer.mozilla.org/en-US/docs/Web/API/OffscreenCanvas/getContext + * for a list of supported platforms. + * + * @return The current image as an ImageBitmap object. + */ + getImage(type: MPImageStorageType.IMAGE_BITMAP): ImageBitmap; + /** + * Returns the underlying image as a `WebGLTexture` object. Note that this + * involves a CPU to GPU transfer if the current image is only available as + * an `ImageData` object. The returned texture is bound to the current + * canvas (see `.canvas`). + * + * @return The current image as a WebGLTexture. + */ + getImage(type: MPImageStorageType.WEBGL_TEXTURE): WebGLTexture; + getImage(type?: MPImageStorageType): MPImageNativeContainer { + if (type === MPImageStorageType.IMAGE_DATA) { + return this.convertToImageData(); + } else if (type === MPImageStorageType.IMAGE_BITMAP) { + return this.convertToImageBitmap(); + } else if (type === MPImageStorageType.WEBGL_TEXTURE) { + return this.convertToWebGLTexture(); + } else { + throw new Error(`Type is not supported: ${type}`); + } + } + + /** + * Creates a copy of the resources stored in this `MPImage`. You can invoke + * this method to extend the lifetime of an image returned by a MediaPipe + * Task. Note that performance critical applications should aim to only use + * the `MPImage` within the MediaPipe Task callback so that copies can be + * avoided. + */ + clone(): MPImage { + // TODO: We might only want to clone one backing datastructure + // even if multiple are defined. 
+ let destinationImageData: ImageData|null = null; + let destinationImageBitmap: ImageBitmap|null = null; + let destinationWebGLTexture: WebGLTexture|null = null; + + if (this.imageData) { + destinationImageData = + new ImageData(this.imageData.data, this.width, this.height); + } + + if (this.webGLTexture) { + const gl = this.getGL(); + const shaderContext = this.getShaderContext(); + + // Create a new texture and use it to back a framebuffer + gl.activeTexture(gl.TEXTURE1); + destinationWebGLTexture = + assertNotNull(gl.createTexture(), 'Failed to create texture'); + gl.bindTexture(gl.TEXTURE_2D, destinationWebGLTexture); + + gl.texImage2D( + gl.TEXTURE_2D, 0, gl.RGBA, this.width, this.height, 0, gl.RGBA, + gl.UNSIGNED_BYTE, null); + + shaderContext.bindFramebuffer(gl, destinationWebGLTexture); + shaderContext.run(gl, /* flipVertically= */ false, () => { + this.bindTexture(); // This activates gl.TEXTURE0 + gl.clearColor(0, 0, 0, 0); + gl.clear(gl.COLOR_BUFFER_BIT); + gl.drawArrays(gl.TRIANGLE_FAN, 0, 4); + this.unbindTexture(); + }); + shaderContext.unbindFramebuffer(); + + this.unbindTexture(); + } + + if (this.imageBitmap) { + this.convertToWebGLTexture(); + this.bindTexture(); + destinationImageBitmap = this.copyTextureToBitmap(); + this.unbindTexture(); + } + + return new MPImage( + destinationImageData, destinationImageBitmap, destinationWebGLTexture, + !!destinationImageBitmap, !!destinationWebGLTexture, this.canvas, + this.shaderContext, this.width, this.height); + } + + + private getOffscreenCanvas(): OffscreenCanvas { + if (!(this.canvas instanceof OffscreenCanvas)) { + throw new Error( + 'Conversion to ImageBitmap requires that the MediaPipe Task is ' + + 'initialized with an OffscreenCanvas'); + } + return this.canvas; + } + + private getGL(): WebGL2RenderingContext { + if (!this.canvas) { + throw new Error( + 'Conversion to different image formats requires that a canvas ' + + 'is passed when initializing the image.'); + } + if (!this.gl) { + this.gl = assertNotNull( + this.canvas.getContext('webgl2') as WebGL2RenderingContext | null, + 'You cannot use a canvas that is already bound to a different ' + + 'type of rendering context.'); + } + return this.gl; + } + + private getShaderContext(): MPImageShaderContext { + if (!this.shaderContext) { + this.shaderContext = new MPImageShaderContext(); + } + return this.shaderContext; + } + + private convertToImageBitmap(): ImageBitmap { + if (!this.imageBitmap) { + if (!this.webGLTexture) { + this.webGLTexture = this.convertToWebGLTexture(); + } + this.imageBitmap = this.convertWebGLTextureToImageBitmap(); + this.ownsImageBitmap = true; + } + + return this.imageBitmap; + } + + private convertToImageData(): ImageData { + if (!this.imageData) { + const gl = this.getGL(); + const shaderContext = this.getShaderContext(); + const pixels = new Uint8Array(this.width * this.height * 4); + + // Create texture if needed + this.convertToWebGLTexture(); + + // Create a framebuffer from the texture and read back pixels + shaderContext.bindFramebuffer(gl, this.webGLTexture!); + gl.readPixels( + 0, 0, this.width, this.height, gl.RGBA, gl.UNSIGNED_BYTE, pixels); + shaderContext.unbindFramebuffer(); + + this.imageData = new ImageData( + new Uint8ClampedArray(pixels.buffer), this.width, this.height); + } + + return this.imageData; + } + + private convertToWebGLTexture(): WebGLTexture { + if (!this.webGLTexture) { + const gl = this.getGL(); + this.bindTexture(); + const source = (this.imageBitmap || this.imageData)!; + gl.texImage2D( + gl.TEXTURE_2D, 0,
gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, source); + this.unbindTexture(); + } + + return this.webGLTexture!; + } + + /** + * Binds the backing texture to the canvas. If the texture does not yet + * exist, creates it first. + */ + private bindTexture() { + const gl = this.getGL(); + + gl.viewport(0, 0, this.width, this.height); + + gl.activeTexture(gl.TEXTURE0); + if (!this.webGLTexture) { + this.webGLTexture = + assertNotNull(gl.createTexture(), 'Failed to create texture'); + this.ownsWebGLTexture = true; + } + + gl.bindTexture(gl.TEXTURE_2D, this.webGLTexture); + // TODO: Ideally, we would only set these once per texture and + // not once every frame. + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR); + gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.LINEAR); + } + + private unbindTexture(): void { + this.gl!.bindTexture(this.gl!.TEXTURE_2D, null); + } + + /** + * Invokes a shader to render the current texture and return it as an + * ImageBitmap + */ + private copyTextureToBitmap(): ImageBitmap { + const gl = this.getGL(); + const shaderContext = this.getShaderContext(); + + return shaderContext.run(gl, /* flipVertically= */ true, () => { + return this.runWithResizedCanvas(() => { + // Unbind any framebuffer that may be bound since + // `transferToImageBitmap()` requires rendering into the display (null) + // framebuffer. + gl.bindFramebuffer(gl.FRAMEBUFFER, null); + + gl.clearColor(0, 0, 0, 0); + gl.clear(gl.COLOR_BUFFER_BIT); + gl.drawArrays(gl.TRIANGLE_FAN, 0, 4); + return this.getOffscreenCanvas().transferToImageBitmap(); + }); + }); + } + + private convertWebGLTextureToImageBitmap(): ImageBitmap { + this.bindTexture(); + const result = this.copyTextureToBitmap(); + this.unbindTexture(); + return result; + } + + /** + * Temporarily resizes the underlying canvas to match the dimensions of the + * image. Runs the provided callback on the resized canvas. + * + * Note that while resizing is an expensive operation, it allows us to use + * the synchronous `transferToImageBitmap()` API. + */ + private runWithResizedCanvas(callback: () => T): T { + const canvas = this.canvas!; + + if (canvas.width === this.width && canvas.height === this.height) { + return callback(); + } + + const originalWidth = canvas.width; + const originalHeight = canvas.height; + canvas.width = this.width; + canvas.height = this.height; + + const result = callback(); + + canvas.width = originalWidth; + canvas.height = originalHeight; + + return result; + } + + /** + * Frees up any resources owned by this `MPImage` instance. + * + * Note that this method does not free images that are owned by the C++ + * Task, as these are freed automatically once you leave the MediaPipe + * callback. Additionally, some shared state is freed only once you invoke the + * Task's `close()` method. 
+ */ + close(): void { + if (this.ownsImageBitmap) { + this.imageBitmap!.close(); + } + + if (!this.gl) { + return; + } + + if (this.ownsWebGLTexture) { + this.gl.deleteTexture(this.webGLTexture!); + } + } +} diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts index 879e23010..903d789f5 100644 --- a/mediapipe/tasks/web/vision/core/render_utils.ts +++ b/mediapipe/tasks/web/vision/core/render_utils.ts @@ -59,13 +59,12 @@ export function drawCategoryMask( const isFloatArray = image instanceof Float32Array; for (let i = 0; i < image.length; i++) { const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; - const color = COLOR_MAP[colorIndex]; + let color = COLOR_MAP[colorIndex % COLOR_MAP.length]; - // When we're given a confidence mask by accident, we just log and return. - // TODO: We should fix this. if (!color) { + // TODO: We should fix this. console.warn('No color for ', colorIndex); - return; + color = COLOR_MAP[colorIndex % COLOR_MAP.length]; } rgbaArray[4 * i] = color[0]; diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts index 5699126b9..344d4db85 100644 --- a/mediapipe/tasks/web/vision/core/types.d.ts +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -25,17 +25,6 @@ import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/ke */ export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture; -/** - * A callback that receives the computed masks from the segmentation tasks. The - * callback either receives a single element array with a category mask (as a - * `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`). - * The returned data is only valid for the duration of the callback. If - * asynchronous processing is needed, all data needs to be copied before the - * callback returns. - */ -export type SegmentationMaskCallback = - (masks: SegmentationMask[], width: number, height: number) => void; - /** * A callback that receives an `ImageData` object from a Vision task. The * lifetime of the underlying data is limited to the duration of the callback. 
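For orientation, here is a minimal sketch of how the `MPImage` wrapper defined above is meant to be consumed. No task in this change hands out `MPImage` instances yet, so the `onTaskImage` callback and the way the image arrives are assumptions; the `hasType()`, `getImage()`, `clone()` and `close()` calls follow the API as written in image.ts.

// Import path assumes a caller under mediapipe/tasks/web/vision/<task>/.
import {MPImage, MPImageStorageType} from '../../../../tasks/web/vision/core/image';

// Illustrative callback: `image` is owned by the task and is only valid for
// the duration of the callback unless it is cloned.
function onTaskImage(image: MPImage): MPImage {
  // Cheap check: is a CPU copy of the pixels already available?
  if (image.hasType(MPImageStorageType.IMAGE_DATA)) {
    const imageData = image.getImage(MPImageStorageType.IMAGE_DATA);
    console.log(`CPU copy available: ${imageData.width}x${imageData.height}`);
  }

  // Conversions happen lazily and may move data between CPU and GPU; the
  // returned texture is bound to the canvas the image was created with.
  const texture = image.getImage(MPImageStorageType.WEBGL_TEXTURE);
  console.log('texture:', texture, 'canvas:', image.canvas);

  // clone() copies the backing storage so the result outlives the callback;
  // the caller now owns the clone and must eventually call close() on it.
  return image.clone();
}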
diff --git a/mediapipe/tasks/web/vision/face_landmarker/BUILD b/mediapipe/tasks/web/vision/face_landmarker/BUILD index 19108be3a..01f26bdad 100644 --- a/mediapipe/tasks/web/vision/face_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/face_landmarker/BUILD @@ -31,6 +31,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/components/containers:landmark", "//mediapipe/tasks/web/components/containers:matrix", "//mediapipe/tasks/web/components/processors:classifier_result", + "//mediapipe/tasks/web/components/processors:landmark_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", @@ -73,9 +74,9 @@ mediapipe_ts_library( ":face_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", - "//mediapipe/framework/formats:landmark_jspb_proto", "//mediapipe/framework/formats:matrix_data_jspb_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_jspb_proto", + "//mediapipe/tasks/web/components/processors:landmark_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/vision/core:vision_task_runner", diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts index 2e6ec5d10..2a30f0606 100644 --- a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker.ts @@ -23,8 +23,8 @@ import {FaceDetectorGraphOptions} from '../../../../tasks/cc/vision/face_detecto import {FaceGeometry as FaceGeometryProto} from '../../../../tasks/cc/vision/face_geometry/proto/face_geometry_pb'; import {FaceLandmarkerGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options_pb'; import {FaceLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options_pb'; -import {NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; import {convertFromClassifications} from '../../../../tasks/web/components/processors/classifier_result'; +import {convertToLandmarks} from '../../../../tasks/web/components/processors/landmark_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -243,15 +243,7 @@ export class FaceLandmarker extends VisionTaskRunner { for (const binaryProto of data) { const faceLandmarksProto = NormalizedLandmarkListProto.deserializeBinary(binaryProto); - const landmarks: NormalizedLandmark[] = []; - for (const faceLandmarkProto of faceLandmarksProto.getLandmarkList()) { - landmarks.push({ - x: faceLandmarkProto.getX() ?? 0, - y: faceLandmarkProto.getY() ?? 0, - z: faceLandmarkProto.getZ() ?? 
0, - }); - } - this.result.faceLandmarks.push(landmarks); + this.result.faceLandmarks.push(convertToLandmarks(faceLandmarksProto)); } } diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts index 92012a6f3..b590b4a4a 100644 --- a/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarker_test.ts @@ -17,9 +17,9 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; -import {NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; import {MatrixData as MatrixDataProto} from '../../../../framework/formats/matrix_data_pb'; import {FaceGeometry as FaceGeometryProto} from '../../../../tasks/cc/vision/face_geometry/proto/face_geometry_pb'; +import {createLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -31,7 +31,7 @@ import {FaceLandmarkerOptions} from './face_landmarker_options'; type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); -function createBlendshapes(): Uint8Array[] { +function createBlendshapes(): ClassificationList { const blendshapesProto = new ClassificationList(); const classification = new Classification(); classification.setScore(0.1); @@ -39,27 +39,17 @@ function createBlendshapes(): Uint8Array[] { classification.setLabel('face_label'); classification.setDisplayName('face_display_name'); blendshapesProto.addClassification(classification); - return [blendshapesProto.serializeBinary()]; + return blendshapesProto; } -function createFacialTransformationMatrixes(): Uint8Array[] { +function createFacialTransformationMatrixes(): FaceGeometryProto { const faceGeometryProto = new FaceGeometryProto(); const posteTransformationMatrix = new MatrixDataProto(); posteTransformationMatrix.setRows(1); posteTransformationMatrix.setCols(1); posteTransformationMatrix.setPackedDataList([1.0]); faceGeometryProto.setPoseTransformMatrix(posteTransformationMatrix); - return [faceGeometryProto.serializeBinary()]; -} - -function createLandmarks(): Uint8Array[] { - const faceLandmarksProto = new NormalizedLandmarkList(); - const landmark = new NormalizedLandmark(); - landmark.setX(0.3); - landmark.setY(0.4); - landmark.setZ(0.5); - faceLandmarksProto.addLandmark(landmark); - return [faceLandmarksProto.serializeBinary()]; + return faceGeometryProto; } class FaceLandmarkerFake extends FaceLandmarker implements MediapipeTasksFake { @@ -243,13 +233,17 @@ describe('FaceLandmarker', () => { }); it('transforms results', async () => { + const landmarksProto = [createLandmarks().serializeBinary()]; + const blendshapesProto = [createBlendshapes().serializeBinary()]; + const faceGeometryProto = + [createFacialTransformationMatrixes().serializeBinary()]; + // Pass the test data to our listener faceLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(faceLandmarker); - faceLandmarker.listeners.get('face_landmarks')!(createLandmarks(), 1337); - 
faceLandmarker.listeners.get('blendshapes')!(createBlendshapes(), 1337); - faceLandmarker.listeners.get('face_geometry')! - (createFacialTransformationMatrixes(), 1337); + faceLandmarker.listeners.get('face_landmarks')!(landmarksProto, 1337); + faceLandmarker.listeners.get('blendshapes')!(blendshapesProto, 1337); + faceLandmarker.listeners.get('face_geometry')!(faceGeometryProto, 1337); }); await faceLandmarker.setOptions({ @@ -266,7 +260,7 @@ describe('FaceLandmarker', () => { expect(faceLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(landmarks).toEqual({ - faceLandmarks: [[{x: 0.3, y: 0.4, z: 0.5}]], + faceLandmarks: [[{x: 0, y: 0, z: 0}]], faceBlendshapes: [{ categories: [{ index: 1, @@ -282,12 +276,16 @@ describe('FaceLandmarker', () => { }); it('clears results between invoations', async () => { + const landmarksProto = [createLandmarks().serializeBinary()]; + const blendshapesProto = [createBlendshapes().serializeBinary()]; + const faceGeometryProto = + [createFacialTransformationMatrixes().serializeBinary()]; + // Pass the test data to our listener faceLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { - faceLandmarker.listeners.get('face_landmarks')!(createLandmarks(), 1337); - faceLandmarker.listeners.get('blendshapes')!(createBlendshapes(), 1337); - faceLandmarker.listeners.get('face_geometry')! - (createFacialTransformationMatrixes(), 1337); + faceLandmarker.listeners.get('face_landmarks')!(landmarksProto, 1337); + faceLandmarker.listeners.get('blendshapes')!(blendshapesProto, 1337); + faceLandmarker.listeners.get('face_geometry')!(faceGeometryProto, 1337); }); await faceLandmarker.setOptions({ diff --git a/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts b/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts index 978324750..337f663e3 100644 --- a/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts +++ b/mediapipe/tasks/web/vision/face_landmarker/face_landmarks_connections.ts @@ -19,7 +19,7 @@ import {Connection} from '../../../../tasks/web/vision/core/types'; // tslint:disable:class-as-namespace Using for easier import by 3P users /** - * A class containing the Pairs of landmark indices to be rendered with + * A class containing the pairs of landmark indices to be rendered with * connections. */ export class FaceLandmarksConnections { diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts index 34067aaba..dfce03030 100644 --- a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts @@ -129,10 +129,6 @@ export class FaceStylizer extends VisionTaskRunner { * synchronously once the callback returns. Only use this method when the * FaceStylizer is created with the image running mode. * - * The input image can be of any size. To ensure that the output image has - * reasonable quality, the stylized output image size is determined by the - * model output size. - * * @param image An image to process. * @param callback The callback that is invoked with the stylized image. The * lifetime of the returned data is only guaranteed for the duration of the @@ -153,11 +149,6 @@ export class FaceStylizer extends VisionTaskRunner { * If both are specified, the crop around the region-of-interest is extracted * first, then the specified rotation is applied to the crop. * - * The input image can be of any size. 
To ensure that the output image has - * reasonable quality, the stylized output image size is the smaller of the - * model output size and the size of the 'regionOfInterest' specified in - * 'imageProcessingOptions'. - * * @param image An image to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. @@ -192,9 +183,6 @@ export class FaceStylizer extends VisionTaskRunner { * frame's timestamp (in milliseconds). The input timestamps must be * monotonically increasing. * - * To ensure that the output image has reasonable quality, the stylized - * output image size is determined by the model output size. - * * @param videoFrame A video frame to process. * @param timestamp The timestamp of the current frame, in ms. * @param callback The callback that is invoked with the stylized image. The @@ -221,10 +209,6 @@ export class FaceStylizer extends VisionTaskRunner { * frame's timestamp (in milliseconds). The input timestamps must be * monotonically increasing. * - * To ensure that the output image has reasonable quality, the stylized - * output image size is the smaller of the model output size and the size of - * the 'regionOfInterest' specified in 'imageProcessingOptions'. - * * @param videoFrame A video frame to process. * @param imageProcessingOptions the `ImageProcessingOptions` specifying how * to process the input image before running inference. @@ -278,8 +262,12 @@ export class FaceStylizer extends VisionTaskRunner { this.graphRunner.attachImageListener( STYLIZED_IMAGE_STREAM, (image, timestamp) => { - const imageData = this.convertToImageData(image); - this.userCallback(imageData, image.width, image.height); + if (image.data instanceof WebGLTexture) { + this.userCallback(image.data, image.width, image.height); + } else { + const imageData = this.convertToImageData(image); + this.userCallback(imageData, image.width, image.height); + } this.setLatestOutputTimestamp(timestamp); }); this.graphRunner.attachEmptyPacketListener( diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD index 9156e89b7..a3a630e90 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/BUILD +++ b/mediapipe/tasks/web/vision/gesture_recognizer/BUILD @@ -34,6 +34,7 @@ mediapipe_ts_library( "//mediapipe/tasks/web/core:classifier_options", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/tasks/web/vision/hand_landmarker:hand_landmarks_connections", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts index 74d37cb63..df9c91282 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer.ts @@ -31,6 +31,7 @@ import {convertClassifierOptionsToProto} from '../../../../tasks/web/components/ import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // 
Placeholder for internal dependency on trusted resource url @@ -72,6 +73,12 @@ export class GestureRecognizer extends VisionTaskRunner { private readonly handGestureRecognizerGraphOptions: HandGestureRecognizerGraphOptions; + /** + * An array containing the pairs of hand landmark indices to be rendered with + * connections. + */ + static HAND_CONNECTIONS = HAND_CONNECTIONS; + /** * Initializes the Wasm runtime and creates a new gesture recognizer from the * provided options. diff --git a/mediapipe/tasks/web/vision/hand_landmarker/BUILD b/mediapipe/tasks/web/vision/hand_landmarker/BUILD index c5687ee2f..cd6f39f7d 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/BUILD +++ b/mediapipe/tasks/web/vision/hand_landmarker/BUILD @@ -16,6 +16,7 @@ mediapipe_ts_library( visibility = ["//visibility:public"], deps = [ ":hand_landmarker_types", + ":hand_landmarks_connections", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", @@ -26,6 +27,7 @@ mediapipe_ts_library( "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_jspb_proto", "//mediapipe/tasks/web/components/containers:category", "//mediapipe/tasks/web/components/containers:landmark", + "//mediapipe/tasks/web/components/processors:landmark_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:image_processing_options", "//mediapipe/tasks/web/vision/core:vision_task_runner", @@ -60,7 +62,7 @@ mediapipe_ts_library( ":hand_landmarker_types", "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework/formats:classification_jspb_proto", - "//mediapipe/framework/formats:landmark_jspb_proto", + "//mediapipe/tasks/web/components/processors:landmark_result", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:task_runner_test_utils", "//mediapipe/tasks/web/vision/core:vision_task_runner", @@ -72,3 +74,9 @@ jasmine_node_test( tags = ["nomsan"], deps = [":hand_landmarker_test_lib"], ) + +mediapipe_ts_library( + name = "hand_landmarks_connections", + srcs = ["hand_landmarks_connections.ts"], + deps = ["//mediapipe/tasks/web/vision/core:types"], +) diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts index 1978bb061..2d2d05f9f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker.ts @@ -24,9 +24,11 @@ import {HandLandmarkerGraphOptions} from '../../../../tasks/cc/vision/hand_landm import {HandLandmarksDetectorGraphOptions} from '../../../../tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options_pb'; import {Category} from '../../../../tasks/web/components/containers/category'; import {Landmark, NormalizedLandmark} from '../../../../tasks/web/components/containers/landmark'; +import {convertToLandmarks, convertToWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {HAND_CONNECTIONS} from '../../../../tasks/web/vision/hand_landmarker/hand_landmarks_connections'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal 
dependency on trusted resource url @@ -63,6 +65,12 @@ export class HandLandmarker extends VisionTaskRunner { HandLandmarksDetectorGraphOptions; private readonly handDetectorGraphOptions: HandDetectorGraphOptions; + /** + * An array containing the pairs of hand landmark indices to be rendered with + * connections. + */ + static HAND_CONNECTIONS = HAND_CONNECTIONS; + /** * Initializes the Wasm runtime and creates a new `HandLandmarker` from the * provided options. @@ -252,15 +260,7 @@ export class HandLandmarker extends VisionTaskRunner { for (const binaryProto of data) { const handLandmarksProto = NormalizedLandmarkList.deserializeBinary(binaryProto); - const landmarks: NormalizedLandmark[] = []; - for (const handLandmarkProto of handLandmarksProto.getLandmarkList()) { - landmarks.push({ - x: handLandmarkProto.getX() ?? 0, - y: handLandmarkProto.getY() ?? 0, - z: handLandmarkProto.getZ() ?? 0, - }); - } - this.landmarks.push(landmarks); + this.landmarks.push(convertToLandmarks(handLandmarksProto)); } } @@ -272,16 +272,8 @@ export class HandLandmarker extends VisionTaskRunner { for (const binaryProto of data) { const handWorldLandmarksProto = LandmarkList.deserializeBinary(binaryProto); - const worldLandmarks: Landmark[] = []; - for (const handWorldLandmarkProto of - handWorldLandmarksProto.getLandmarkList()) { - worldLandmarks.push({ - x: handWorldLandmarkProto.getX() ?? 0, - y: handWorldLandmarkProto.getY() ?? 0, - z: handWorldLandmarkProto.getZ() ?? 0, - }); - } - this.worldLandmarks.push(worldLandmarks); + this.worldLandmarks.push( + convertToWorldLandmarks(handWorldLandmarksProto)); } } diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts index 8a6d9bfa6..dc0f2fe0f 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_result.d.ts @@ -26,7 +26,7 @@ export declare interface HandLandmarkerResult { /** Hand landmarks of detected hands. */ landmarks: NormalizedLandmark[][]; - /** Hand landmarks in world coordniates of detected hands. */ + /** Hand landmarks in world coordinates of detected hands. */ worldLandmarks: Landmark[][]; /** Handedness of detected hands. 
*/ diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts index 5fd493424..f439e66e6 100644 --- a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarker_test.ts @@ -17,7 +17,7 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {Classification, ClassificationList} from '../../../../framework/formats/classification_pb'; -import {Landmark, LandmarkList, NormalizedLandmark, NormalizedLandmarkList} from '../../../../framework/formats/landmark_pb'; +import {createLandmarks, createWorldLandmarks} from '../../../../tasks/web/components/processors/landmark_result_test_lib'; import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; import {VisionGraphRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; @@ -30,7 +30,7 @@ import {HandLandmarkerOptions} from './hand_landmarker_options'; type ProtoListener = ((binaryProtos: Uint8Array[], timestamp: number) => void); -function createHandednesses(): Uint8Array[] { +function createHandednesses(): ClassificationList { const handsProto = new ClassificationList(); const classification = new Classification(); classification.setScore(0.1); @@ -38,27 +38,7 @@ function createHandednesses(): Uint8Array[] { classification.setLabel('handedness_label'); classification.setDisplayName('handedness_display_name'); handsProto.addClassification(classification); - return [handsProto.serializeBinary()]; -} - -function createLandmarks(): Uint8Array[] { - const handLandmarksProto = new NormalizedLandmarkList(); - const landmark = new NormalizedLandmark(); - landmark.setX(0.3); - landmark.setY(0.4); - landmark.setZ(0.5); - handLandmarksProto.addLandmark(landmark); - return [handLandmarksProto.serializeBinary()]; -} - -function createWorldLandmarks(): Uint8Array[] { - const handLandmarksProto = new LandmarkList(); - const landmark = new Landmark(); - landmark.setX(21); - landmark.setY(22); - landmark.setZ(23); - handLandmarksProto.addLandmark(landmark); - return [handLandmarksProto.serializeBinary()]; + return handsProto; } class HandLandmarkerFake extends HandLandmarker implements MediapipeTasksFake { @@ -212,13 +192,17 @@ describe('HandLandmarker', () => { }); it('transforms results', async () => { + const landmarksProto = [createLandmarks().serializeBinary()]; + const worldLandmarksProto = [createWorldLandmarks().serializeBinary()]; + const handednessProto = [createHandednesses().serializeBinary()]; + // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(handLandmarker); - handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); + handLandmarker.listeners.get('hand_landmarks')!(landmarksProto, 1337); handLandmarker.listeners.get('world_hand_landmarks')! 
- (createWorldLandmarks(), 1337); - handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); + (worldLandmarksProto, 1337); + handLandmarker.listeners.get('handedness')!(handednessProto, 1337); }); // Invoke the hand landmarker @@ -230,8 +214,8 @@ describe('HandLandmarker', () => { expect(handLandmarker.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); expect(landmarks).toEqual({ - 'landmarks': [[{'x': 0.3, 'y': 0.4, 'z': 0.5}]], - 'worldLandmarks': [[{'x': 21, 'y': 22, 'z': 23}]], + 'landmarks': [[{'x': 0, 'y': 0, 'z': 0}]], + 'worldLandmarks': [[{'x': 0, 'y': 0, 'z': 0}]], 'handednesses': [[{ 'score': 0.1, 'index': 1, @@ -242,12 +226,16 @@ describe('HandLandmarker', () => { }); it('clears results between invoations', async () => { + const landmarks = [createLandmarks().serializeBinary()]; + const worldLandmarks = [createWorldLandmarks().serializeBinary()]; + const handedness = [createHandednesses().serializeBinary()]; + // Pass the test data to our listener handLandmarker.fakeWasmModule._waitUntilIdle.and.callFake(() => { - handLandmarker.listeners.get('hand_landmarks')!(createLandmarks(), 1337); + handLandmarker.listeners.get('hand_landmarks')!(landmarks, 1337); handLandmarker.listeners.get('world_hand_landmarks')! - (createWorldLandmarks(), 1337); - handLandmarker.listeners.get('handedness')!(createHandednesses(), 1337); + (worldLandmarks, 1337); + handLandmarker.listeners.get('handedness')!(handedness, 1337); }); // Invoke the hand landmarker twice diff --git a/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarks_connections.ts b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarks_connections.ts new file mode 100644 index 000000000..edb789c8f --- /dev/null +++ b/mediapipe/tasks/web/vision/hand_landmarker/hand_landmarks_connections.ts @@ -0,0 +1,31 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {Connection} from '../../../../tasks/web/vision/core/types'; + +/** + * An array containing the pairs of hand landmark indices to be rendered with + * connections. 
+ */ +export const HAND_CONNECTIONS: Connection[] = [ + {start: 0, end: 1}, {start: 1, end: 2}, {start: 2, end: 3}, + {start: 3, end: 4}, {start: 0, end: 5}, {start: 5, end: 6}, + {start: 6, end: 7}, {start: 7, end: 8}, {start: 5, end: 9}, + {start: 9, end: 10}, {start: 10, end: 11}, {start: 11, end: 12}, + {start: 9, end: 13}, {start: 13, end: 14}, {start: 14, end: 15}, + {start: 15, end: 16}, {start: 13, end: 17}, {start: 0, end: 17}, + {start: 17, end: 18}, {start: 18, end: 19}, {start: 19, end: 20} +]; diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD index a4b9008dd..3db15641f 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -29,7 +29,10 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "image_segmenter_types", - srcs = ["image_segmenter_options.d.ts"], + srcs = [ + "image_segmenter_options.d.ts", + "image_segmenter_result.d.ts", + ], deps = [ "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index 3690fd855..c32423e12 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -22,33 +22,48 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; -import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {SegmentationMask} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {LabelMapItem} from '../../../../util/label_map_pb'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageSegmenterOptions} from './image_segmenter_options'; +import {ImageSegmenterResult} from './image_segmenter_result'; export * from './image_segmenter_options'; -export {SegmentationMask, SegmentationMaskCallback}; +export * from './image_segmenter_result'; +export {SegmentationMask}; export {ImageSource}; // Used in the public API const IMAGE_STREAM = 'image_in'; const NORM_RECT_STREAM = 'norm_rect'; -const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; +const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; +const CATEGORY_MASK_STREAM = 'category_mask'; const IMAGE_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = 'mediapipe.tasks.TensorsToSegmentationCalculator'; +const DEFAULT_OUTPUT_CATEGORY_MASK = false; +const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern +/** + * A callback that receives the computed masks from the image segmenter. The + * returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. 
+ */ +export type ImageSegmenterCallback = (result: ImageSegmenterResult) => void; + /** Performs image segmentation on images. */ export class ImageSegmenter extends VisionTaskRunner { - private userCallback: SegmentationMaskCallback = () => {}; + private result: ImageSegmenterResult = {width: 0, height: 0}; private labels: string[] = []; + private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; + private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -109,7 +124,6 @@ export class ImageSegmenter extends VisionTaskRunner { this.options.setBaseOptions(new BaseOptionsProto()); } - protected override get baseOptions(): BaseOptionsProto { return this.options.getBaseOptions()!; } @@ -137,12 +151,14 @@ export class ImageSegmenter extends VisionTaskRunner { this.options.clearDisplayNamesLocale(); } - if (options.outputType === 'CONFIDENCE_MASK') { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); - } else { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CATEGORY_MASK); + if ('outputCategoryMask' in options) { + this.outputCategoryMask = + options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK; + } + + if ('outputConfidenceMasks' in options) { + this.outputConfidenceMasks = + options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS; } return super.applyOptions(options); @@ -192,7 +208,7 @@ export class ImageSegmenter extends VisionTaskRunner { * lifetime of the returned data is only guaranteed for the duration of the * callback. */ - segment(image: ImageSource, callback: SegmentationMaskCallback): void; + segment(image: ImageSource, callback: ImageSegmenterCallback): void; /** * Performs image segmentation on the provided single image and invokes the * callback with the response. The method returns synchronously once the @@ -208,22 +224,77 @@ export class ImageSegmenter extends VisionTaskRunner { */ segment( image: ImageSource, imageProcessingOptions: ImageProcessingOptions, - callback: SegmentationMaskCallback): void; + callback: ImageSegmenterCallback): void; segment( image: ImageSource, imageProcessingOptionsOrCallback: ImageProcessingOptions| - SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { + ImageSegmenterCallback, + callback?: ImageSegmenterCallback): void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; - - this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + const userCallback = + typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : callback!; + + this.reset(); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + userCallback(this.result); + } + + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. 
+ */ + segmentForVideo( + videoFrame: ImageSource, timestamp: number, + callback: ImageSegmenterCallback): void; + /** + * Performs image segmentation on the provided video frame and invokes the + * callback with the response. The method returns synchronously once the + * callback returns. Only use this method when the ImageSegmenter is + * created with running mode `video`. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the segmented masks. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + segmentForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number, callback: ImageSegmenterCallback): void; + segmentForVideo( + videoFrame: ImageSource, + timestampOrImageProcessingOptions: number|ImageProcessingOptions, + timestampOrCallback: number|ImageSegmenterCallback, + callback?: ImageSegmenterCallback): void { + const imageProcessingOptions = + typeof timestampOrImageProcessingOptions !== 'number' ? + timestampOrImageProcessingOptions : + {}; + const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? + timestampOrImageProcessingOptions : + timestampOrCallback as number; + const userCallback = typeof timestampOrCallback === 'function' ? + timestampOrCallback : + callback!; + + this.reset(); + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + userCallback(this.result); } /** @@ -241,56 +312,8 @@ export class ImageSegmenter extends VisionTaskRunner { return this.labels; } - /** - * Performs image segmentation on the provided video frame and invokes the - * callback with the response. The method returns synchronously once the - * callback returns. Only use this method when the ImageSegmenter is - * created with running mode `video`. - * - * @param videoFrame A video frame to process. - * @param timestamp The timestamp of the current frame, in ms. - * @param callback The callback that is invoked with the segmented masks. The - * lifetime of the returned data is only guaranteed for the duration of the - * callback. - */ - segmentForVideo( - videoFrame: ImageSource, timestamp: number, - callback: SegmentationMaskCallback): void; - /** - * Performs image segmentation on the provided video frame and invokes the - * callback with the response. The method returns synchronously once the - * callback returns. Only use this method when the ImageSegmenter is - * created with running mode `video`. - * - * @param videoFrame A video frame to process. - * @param imageProcessingOptions the `ImageProcessingOptions` specifying how - * to process the input image before running inference. - * @param timestamp The timestamp of the current frame, in ms. - * @param callback The callback that is invoked with the segmented masks. The - * lifetime of the returned data is only guaranteed for the duration of the - * callback. 
- */ - segmentForVideo( - videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, - timestamp: number, callback: SegmentationMaskCallback): void; - segmentForVideo( - videoFrame: ImageSource, - timestampOrImageProcessingOptions: number|ImageProcessingOptions, - timestampOrCallback: number|SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { - const imageProcessingOptions = - typeof timestampOrImageProcessingOptions !== 'number' ? - timestampOrImageProcessingOptions : - {}; - const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? - timestampOrImageProcessingOptions : - timestampOrCallback as number; - - this.userCallback = typeof timestampOrCallback === 'function' ? - timestampOrCallback : - callback!; - this.processVideoData(videoFrame, imageProcessingOptions, timestamp); - this.userCallback = () => {}; + private reset(): void { + this.result = {width: 0, height: 0}; } /** Updates the MediaPipe graph configuration. */ @@ -298,7 +321,6 @@ export class ImageSegmenter extends VisionTaskRunner { const graphConfig = new CalculatorGraphConfig(); graphConfig.addInputStream(IMAGE_STREAM); graphConfig.addInputStream(NORM_RECT_STREAM); - graphConfig.addOutputStream(GROUPED_SEGMENTATIONS_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -308,26 +330,47 @@ export class ImageSegmenter extends VisionTaskRunner { segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH); segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); - segmenterNode.addOutputStream( - 'GROUPED_SEGMENTATION:' + GROUPED_SEGMENTATIONS_STREAM); segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); - this.graphRunner.attachImageVectorListener( - GROUPED_SEGMENTATIONS_STREAM, (masks, timestamp) => { - if (masks.length === 0) { - this.userCallback([], 0, 0); - } else { - this.userCallback( - masks.map(m => m.data), masks[0].width, masks[0].height); - } - this.setLatestOutputTimestamp(timestamp); - }); - this.graphRunner.attachEmptyPacketListener( - GROUPED_SEGMENTATIONS_STREAM, timestamp => { - this.setLatestOutputTimestamp(timestamp); - }); + if (this.outputConfidenceMasks) { + graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); + segmenterNode.addOutputStream( + 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + + this.graphRunner.attachImageVectorListener( + CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { + this.result.confidenceMasks = masks.map(m => m.data); + if (masks.length >= 0) { + this.result.width = masks[0].width; + this.result.height = masks[0].height; + } + + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CONFIDENCE_MASKS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + if (this.outputCategoryMask) { + graphConfig.addOutputStream(CATEGORY_MASK_STREAM); + segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + + this.graphRunner.attachImageListener( + CATEGORY_MASK_STREAM, (mask, timestamp) => { + this.result.categoryMask = mask.data; + this.result.width = mask.width; + this.result.height = mask.height; + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CATEGORY_MASK_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git 
a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts index c17e7e421..f80a792a5 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_options.d.ts @@ -24,18 +24,9 @@ export interface ImageSegmenterOptions extends VisionTaskOptions { */ displayNamesLocale?: string|undefined; - /** - * The output type of segmentation results. - * - * The two supported modes are: - * - Category Mask: Gives a single output mask where each pixel represents - * the class which the pixel in the original image was - * predicted to belong to. - * - Confidence Mask: Gives a list of output masks (one for each class). For - * each mask, the pixel represents the prediction - * confidence, usually in the [0.0, 0.1] range. - * - * Defaults to `CATEGORY_MASK`. - */ - outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; + /** Whether to output confidence masks. Defaults to true. */ + outputConfidenceMasks?: boolean|undefined; + + /** Whether to output the category masks. Defaults to false. */ + outputCategoryMask?: boolean|undefined; } diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts new file mode 100644 index 000000000..be082d516 --- /dev/null +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_result.d.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The output result of ImageSegmenter. */ +export declare interface ImageSegmenterResult { + /** + * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each + * pixel represents the prediction confidence, usually in the [0, 1] range. + */ + confidenceMasks?: Float32Array[]|WebGLTexture[]; + + /** + * A category mask as a Uint8ClampedArray or WebGLTexture where each + * pixel represents the class which the pixel in the original image was + * predicted to belong to. + */ + categoryMask?: Uint8ClampedArray|WebGLTexture; + + /** The width of the masks. */ + width: number; + + /** The height of the masks. 
*/ + height: number; +} diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts index 4cf27b9a5..6b5c90080 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts @@ -18,7 +18,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; -import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {ImageSegmenter} from './image_segmenter'; @@ -30,7 +30,9 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - imageVectorListener: + categoryMaskListener: + ((images: WasmImage, timestamp: number) => void)|undefined; + confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; constructor() { @@ -38,11 +40,16 @@ class ImageSegmenterFake extends ImageSegmenter implements MediapipeTasksFake { this.fakeWasmModule = this.graphRunner.wasmModule as unknown as SpyWasmModule; - this.attachListenerSpies[0] = + this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('category_mask'); + this.categoryMaskListener = listener; + }); + this.attachListenerSpies[1] = spyOn(this.graphRunner, 'attachImageVectorListener') .and.callFake((stream, listener) => { - expect(stream).toEqual('segmented_masks'); - this.imageVectorListener = listener; + expect(stream).toEqual('confidence_masks'); + this.confidenceMasksListener = listener; }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -63,17 +70,18 @@ describe('ImageSegmenter', () => { it('initializes graph', async () => { verifyGraph(imageSegmenter); - verifyListenersRegistered(imageSegmenter); + + // Verify default options + expect(imageSegmenter.categoryMaskListener).not.toBeDefined(); + expect(imageSegmenter.confidenceMasksListener).toBeDefined(); }); it('reloads graph when settings are changed', async () => { await imageSegmenter.setOptions({displayNamesLocale: 'en'}); verifyGraph(imageSegmenter, ['displayNamesLocale', 'en']); - verifyListenersRegistered(imageSegmenter); await imageSegmenter.setOptions({displayNamesLocale: 'de'}); verifyGraph(imageSegmenter, ['displayNamesLocale', 'de']); - verifyListenersRegistered(imageSegmenter); }); it('can use custom models', async () => { @@ -100,9 +108,11 @@ describe('ImageSegmenter', () => { }); it('merges options', async () => { - await imageSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + await imageSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); await imageSegmenter.setOptions({displayNamesLocale: 'en'}); - verifyGraph(imageSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyGraph( + imageSegmenter, [['baseOptions', 'modelAsset', 'fileContent'], '']); verifyGraph(imageSegmenter, ['displayNamesLocale', 
'en']); }); @@ -115,22 +125,13 @@ describe('ImageSegmenter', () => { defaultValue: unknown; } - const testCases: TestCase[] = [ - { - optionName: 'displayNamesLocale', - fieldPath: ['displayNamesLocale'], - userValue: 'en', - graphValue: 'en', - defaultValue: 'en' - }, - { - optionName: 'outputType', - fieldPath: ['segmenterOptions', 'outputType'], - userValue: 'CONFIDENCE_MASK', - graphValue: 2, - defaultValue: 1 - }, - ]; + const testCases: TestCase[] = [{ + optionName: 'displayNamesLocale', + fieldPath: ['displayNamesLocale'], + userValue: 'en', + graphValue: 'en', + defaultValue: 'en' + }]; for (const testCase of testCases) { it(`can set ${testCase.optionName}`, async () => { @@ -158,27 +159,31 @@ describe('ImageSegmenter', () => { }).toThrowError('This task doesn\'t support region-of-interest.'); }); - it('supports category masks', (done) => { + it('supports category mask', async () => { const mask = new Uint8ClampedArray([1, 2, 3, 4]); + await imageSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: false}); + // Pass the test data to our listener imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(imageSegmenter); - imageSegmenter.imageVectorListener!( - [ - {data: mask, width: 2, height: 2}, - ], - /* timestamp= */ 1337); + expect(imageSegmenter.categoryMaskListener).toBeDefined(); + imageSegmenter.categoryMaskListener! + ({data: mask, width: 2, height: 2}, + /* timestamp= */ 1337); }); // Invoke the image segmenter - imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { - expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(masks).toHaveSize(1); - expect(masks[0]).toEqual(mask); - expect(width).toEqual(2); - expect(height).toEqual(2); - done(); + + return new Promise(resolve => { + imageSegmenter.segment({} as HTMLImageElement, result => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result.categoryMask).toEqual(mask); + expect(result.confidenceMasks).not.toBeDefined(); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); }); }); @@ -186,12 +191,13 @@ describe('ImageSegmenter', () => { const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); - await imageSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + await imageSegmenter.setOptions( + {outputCategoryMask: false, outputConfidenceMasks: true}); // Pass the test data to our listener imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(imageSegmenter); - imageSegmenter.imageVectorListener!( + expect(imageSegmenter.confidenceMasksListener).toBeDefined(); + imageSegmenter.confidenceMasksListener!( [ {data: mask1, width: 2, height: 2}, {data: mask2, width: 2, height: 2}, @@ -201,13 +207,49 @@ describe('ImageSegmenter', () => { return new Promise(resolve => { // Invoke the image segmenter - imageSegmenter.segment({} as HTMLImageElement, (masks, width, height) => { + imageSegmenter.segment({} as HTMLImageElement, result => { expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); - expect(masks).toHaveSize(2); - expect(masks[0]).toEqual(mask1); - expect(masks[1]).toEqual(mask2); - expect(width).toEqual(2); - expect(height).toEqual(2); + expect(result.categoryMask).not.toBeDefined(); + expect(result.confidenceMasks).toEqual([mask1, mask2]); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); + }); + 
}); + + it('supports combined category and confidence masks', async () => { + const categoryMask = new Uint8ClampedArray([1, 0]); + const confidenceMask1 = new Float32Array([0.0, 1.0]); + const confidenceMask2 = new Float32Array([1.0, 0.0]); + + await imageSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: true}); + + // Pass the test data to our listener + imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(imageSegmenter.categoryMaskListener).toBeDefined(); + expect(imageSegmenter.confidenceMasksListener).toBeDefined(); + imageSegmenter.categoryMaskListener! + ({data: categoryMask, width: 1, height: 1}, 1337); + imageSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask1, width: 1, height: 1}, + {data: confidenceMask2, width: 1, height: 1}, + ], + 1337); + }); + + return new Promise(resolve => { + // Invoke the image segmenter + imageSegmenter.segment({} as HTMLImageElement, result => { + expect(imageSegmenter.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(result.categoryMask).toEqual(categoryMask); + expect(result.confidenceMasks).toEqual([ + confidenceMask1, confidenceMask2 + ]); + expect(result.width).toEqual(1); + expect(result.height).toEqual(1); resolve(); }); }); diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts index 3b3757bbd..c4adab7e6 100644 --- a/mediapipe/tasks/web/vision/index.ts +++ b/mediapipe/tasks/web/vision/index.ts @@ -16,6 +16,7 @@ import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver'; import {DrawingUtils as DrawingUtilsImpl} from '../../../tasks/web/vision/core/drawing_utils'; +import {MPImage as MPImageImpl} from '../../../tasks/web/vision/core/image'; import {FaceDetector as FaceDetectorImpl} from '../../../tasks/web/vision/face_detector/face_detector'; import {FaceLandmarker as FaceLandmarkerImpl, FaceLandmarksConnections as FaceLandmarksConnectionsImpl} from '../../../tasks/web/vision/face_landmarker/face_landmarker'; import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer'; @@ -31,6 +32,7 @@ import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/ob // as exports. 
const DrawingUtils = DrawingUtilsImpl; const FilesetResolver = FilesetResolverImpl; +const MPImage = MPImageImpl; const FaceDetector = FaceDetectorImpl; const FaceLandmarker = FaceLandmarkerImpl; const FaceLandmarksConnections = FaceLandmarksConnectionsImpl; @@ -46,6 +48,7 @@ const ObjectDetector = ObjectDetectorImpl; export { DrawingUtils, FilesetResolver, + MPImage, FaceDetector, FaceLandmarker, FaceLandmarksConnections, diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD index a4a3f27c9..ead85d38a 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD @@ -30,7 +30,10 @@ mediapipe_ts_library( mediapipe_ts_declaration( name = "interactive_segmenter_types", - srcs = ["interactive_segmenter_options.d.ts"], + srcs = [ + "interactive_segmenter_options.d.ts", + "interactive_segmenter_result.d.ts", + ], deps = [ "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/core:classifier_options", diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts index ddcc7e592..df00b2cee 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -21,7 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; -import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {RegionOfInterest, SegmentationMask} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; import {Color as ColorProto} from '../../../../util/color_pb'; import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; @@ -29,21 +29,35 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner // Placeholder for internal dependency on trusted resource url import {InteractiveSegmenterOptions} from './interactive_segmenter_options'; +import {InteractiveSegmenterResult} from './interactive_segmenter_result'; export * from './interactive_segmenter_options'; -export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest}; +export * from './interactive_segmenter_result'; +export {SegmentationMask, RegionOfInterest}; export {ImageSource}; const IMAGE_IN_STREAM = 'image_in'; const NORM_RECT_IN_STREAM = 'norm_rect_in'; const ROI_IN_STREAM = 'roi_in'; -const IMAGE_OUT_STREAM = 'image_out'; +const CONFIDENCE_MASKS_STREAM = 'confidence_masks'; +const CATEGORY_MASK_STREAM = 'category_mask'; const IMAGEA_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; +const DEFAULT_OUTPUT_CATEGORY_MASK = false; +const DEFAULT_OUTPUT_CONFIDENCE_MASKS = true; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern +/** + * A callback that receives the computed masks from the interactive segmenter. 
+ * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. + */ +export type InteractiveSegmenterCallback = + (result: InteractiveSegmenterResult) => void; + /** * Performs interactive segmentation on images. * @@ -69,7 +83,9 @@ const IMAGEA_SEGMENTER_GRAPH = * - batch is always 1 */ export class InteractiveSegmenter extends VisionTaskRunner { - private userCallback: SegmentationMaskCallback = () => {}; + private result: InteractiveSegmenterResult = {width: 0, height: 0}; + private outputCategoryMask = DEFAULT_OUTPUT_CATEGORY_MASK; + private outputConfidenceMasks = DEFAULT_OUTPUT_CONFIDENCE_MASKS; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -154,12 +170,14 @@ export class InteractiveSegmenter extends VisionTaskRunner { * @return A Promise that resolves when the settings have been applied. */ override setOptions(options: InteractiveSegmenterOptions): Promise { - if (options.outputType === 'CONFIDENCE_MASK') { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CONFIDENCE_MASK); - } else { - this.segmenterOptions.setOutputType( - SegmenterOptionsProto.OutputType.CATEGORY_MASK); + if ('outputCategoryMask' in options) { + this.outputCategoryMask = + options.outputCategoryMask ?? DEFAULT_OUTPUT_CATEGORY_MASK; + } + + if ('outputConfidenceMasks' in options) { + this.outputConfidenceMasks = + options.outputConfidenceMasks ?? DEFAULT_OUTPUT_CONFIDENCE_MASKS; } return super.applyOptions(options); @@ -184,7 +202,7 @@ export class InteractiveSegmenter extends VisionTaskRunner { */ segment( image: ImageSource, roi: RegionOfInterest, - callback: SegmentationMaskCallback): void; + callback: InteractiveSegmenterCallback): void; /** * Performs interactive segmentation on the provided single image and invokes * the callback with the response. The `roi` parameter is used to represent a @@ -213,24 +231,29 @@ export class InteractiveSegmenter extends VisionTaskRunner { segment( image: ImageSource, roi: RegionOfInterest, imageProcessingOptions: ImageProcessingOptions, - callback: SegmentationMaskCallback): void; + callback: InteractiveSegmenterCallback): void; segment( image: ImageSource, roi: RegionOfInterest, imageProcessingOptionsOrCallback: ImageProcessingOptions| - SegmentationMaskCallback, - callback?: SegmentationMaskCallback): void { + InteractiveSegmenterCallback, + callback?: InteractiveSegmenterCallback): void { const imageProcessingOptions = typeof imageProcessingOptionsOrCallback !== 'function' ? imageProcessingOptionsOrCallback : {}; - - this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + const userCallback = + typeof imageProcessingOptionsOrCallback === 'function' ? imageProcessingOptionsOrCallback : callback!; + this.reset(); this.processRenderData(roi, this.getSynctheticTimestamp()); this.processImageData(image, imageProcessingOptions); - this.userCallback = () => {}; + userCallback(this.result); + } + + private reset(): void { + this.result = {width: 0, height: 0}; } /** Updates the MediaPipe graph configuration. 
*/ @@ -239,7 +262,6 @@ export class InteractiveSegmenter extends VisionTaskRunner { graphConfig.addInputStream(IMAGE_IN_STREAM); graphConfig.addInputStream(ROI_IN_STREAM); graphConfig.addInputStream(NORM_RECT_IN_STREAM); - graphConfig.addOutputStream(IMAGE_OUT_STREAM); const calculatorOptions = new CalculatorOptions(); calculatorOptions.setExtension( @@ -250,24 +272,47 @@ export class InteractiveSegmenter extends VisionTaskRunner { segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM); segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM); segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM); - segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM); segmenterNode.setOptions(calculatorOptions); graphConfig.addNode(segmenterNode); - this.graphRunner.attachImageVectorListener( - IMAGE_OUT_STREAM, (masks, timestamp) => { - if (masks.length === 0) { - this.userCallback([], 0, 0); - } else { - this.userCallback( - masks.map(m => m.data), masks[0].width, masks[0].height); - } - this.setLatestOutputTimestamp(timestamp); - }); - this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => { - this.setLatestOutputTimestamp(timestamp); - }); + if (this.outputConfidenceMasks) { + graphConfig.addOutputStream(CONFIDENCE_MASKS_STREAM); + segmenterNode.addOutputStream( + 'CONFIDENCE_MASKS:' + CONFIDENCE_MASKS_STREAM); + + this.graphRunner.attachImageVectorListener( + CONFIDENCE_MASKS_STREAM, (masks, timestamp) => { + this.result.confidenceMasks = masks.map(m => m.data); + if (masks.length >= 0) { + this.result.width = masks[0].width; + this.result.height = masks[0].height; + } + + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CONFIDENCE_MASKS_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } + + if (this.outputCategoryMask) { + graphConfig.addOutputStream(CATEGORY_MASK_STREAM); + segmenterNode.addOutputStream('CATEGORY_MASK:' + CATEGORY_MASK_STREAM); + + this.graphRunner.attachImageListener( + CATEGORY_MASK_STREAM, (mask, timestamp) => { + this.result.categoryMask = mask.data; + this.result.width = mask.width; + this.result.height = mask.height; + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + CATEGORY_MASK_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + } const binaryGraph = graphConfig.serializeBinary(); this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts index beb43cd81..269403d97 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts @@ -19,18 +19,9 @@ import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options' /** Options to configure the MediaPipe Interactive Segmenter Task */ export interface InteractiveSegmenterOptions extends TaskRunnerOptions { - /** - * The output type of segmentation results. - * - * The two supported modes are: - * - Category Mask: Gives a single output mask where each pixel represents - * the class which the pixel in the original image was - * predicted to belong to. - * - Confidence Mask: Gives a list of output masks (one for each class). 
For - * each mask, the pixel represents the prediction - * confidence, usually in the [0.0, 0.1] range. - * - * Defaults to `CATEGORY_MASK`. - */ - outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined; + /** Whether to output confidence masks. Defaults to true. */ + outputConfidenceMasks?: boolean|undefined; + + /** Whether to output the category masks. Defaults to false. */ + outputCategoryMask?: boolean|undefined; } diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts new file mode 100644 index 000000000..f7e1f3a19 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_result.d.ts @@ -0,0 +1,37 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The output result of InteractiveSegmenter. */ +export declare interface InteractiveSegmenterResult { + /** + * Multiple masks as Float32Arrays or WebGLTextures where, for each mask, each + * pixel represents the prediction confidence, usually in the [0, 1] range. + */ + confidenceMasks?: Float32Array[]|WebGLTexture[]; + + /** + * A category mask as a Uint8ClampedArray or WebGLTexture where each + * pixel represents the class which the pixel in the original image was + * predicted to belong to. + */ + categoryMask?: Uint8ClampedArray|WebGLTexture; + + /** The width of the masks. */ + width: number; + + /** The height of the masks. 
*/ + height: number; +} diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts index d6e3a97a5..884be032d 100644 --- a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts @@ -18,7 +18,7 @@ import 'jasmine'; // Placeholder for internal dependency on encodeByteArray import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; -import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph} from '../../../../tasks/web/core/task_runner_test_utils'; import {RenderData as RenderDataProto} from '../../../../util/render_data_pb'; import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; @@ -37,7 +37,9 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements graph: CalculatorGraphConfig|undefined; fakeWasmModule: SpyWasmModule; - imageVectorListener: + categoryMaskListener: + ((images: WasmImage, timestamp: number) => void)|undefined; + confidenceMasksListener: ((images: WasmImage[], timestamp: number) => void)|undefined; lastRoi?: RenderDataProto; @@ -46,11 +48,16 @@ class InteractiveSegmenterFake extends InteractiveSegmenter implements this.fakeWasmModule = this.graphRunner.wasmModule as unknown as SpyWasmModule; - this.attachListenerSpies[0] = + this.attachListenerSpies[0] = spyOn(this.graphRunner, 'attachImageListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('category_mask'); + this.categoryMaskListener = listener; + }); + this.attachListenerSpies[1] = spyOn(this.graphRunner, 'attachImageVectorListener') .and.callFake((stream, listener) => { - expect(stream).toEqual('image_out'); - this.imageVectorListener = listener; + expect(stream).toEqual('confidence_masks'); + this.confidenceMasksListener = listener; }); spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); @@ -79,17 +86,21 @@ describe('InteractiveSegmenter', () => { it('initializes graph', async () => { verifyGraph(interactiveSegmenter); - verifyListenersRegistered(interactiveSegmenter); + + // Verify default options + expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); }); it('reloads graph when settings are changed', async () => { - await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); - verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]); - verifyListenersRegistered(interactiveSegmenter); + await interactiveSegmenter.setOptions( + {outputConfidenceMasks: true, outputCategoryMask: false}); + expect(interactiveSegmenter.categoryMaskListener).not.toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]); - verifyListenersRegistered(interactiveSegmenter); + await interactiveSegmenter.setOptions( + {outputConfidenceMasks: false, outputCategoryMask: true}); + 
expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); }); it('can use custom models', async () => { @@ -115,23 +126,6 @@ describe('InteractiveSegmenter', () => { ]); }); - - describe('setOptions()', () => { - const fieldPath = ['segmenterOptions', 'outputType']; - - it(`can set outputType`, async () => { - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [fieldPath, 2]); - }); - - it(`can clear outputType`, async () => { - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); - verifyGraph(interactiveSegmenter, [fieldPath, 2]); - await interactiveSegmenter.setOptions({outputType: undefined}); - verifyGraph(interactiveSegmenter, [fieldPath, 1]); - }); - }); - it('doesn\'t support region of interest', () => { expect(() => { interactiveSegmenter.segment( @@ -153,60 +147,99 @@ describe('InteractiveSegmenter', () => { interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); }); - it('supports category masks', (done) => { + it('supports category mask', async () => { const mask = new Uint8ClampedArray([1, 2, 3, 4]); + await interactiveSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: false}); + // Pass the test data to our listener interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(interactiveSegmenter); - interactiveSegmenter.imageVectorListener!( - [ - {data: mask, width: 2, height: 2}, - ], - /* timestamp= */ 1337); + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); + interactiveSegmenter.categoryMaskListener! + ({data: mask, width: 2, height: 2}, + /* timestamp= */ 1337); }); // Invoke the image segmenter - interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, (masks, width, height) => { - expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) - .toHaveBeenCalled(); - expect(masks).toHaveSize(1); - expect(masks[0]).toEqual(mask); - expect(width).toEqual(2); - expect(height).toEqual(2); - done(); - }); + return new Promise(resolve => { + interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(result.categoryMask).toEqual(mask); + expect(result.confidenceMasks).not.toBeDefined(); + expect(result.width).toEqual(2); + expect(result.height).toEqual(2); + resolve(); + }); + }); }); it('supports confidence masks', async () => { const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); - await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + await interactiveSegmenter.setOptions( + {outputCategoryMask: false, outputConfidenceMasks: true}); // Pass the test data to our listener interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { - verifyListenersRegistered(interactiveSegmenter); - interactiveSegmenter.imageVectorListener!( + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); + interactiveSegmenter.confidenceMasksListener!( [ {data: mask1, width: 2, height: 2}, {data: mask2, width: 2, height: 2}, ], 1337); }); + return new Promise(resolve => { + // Invoke the image segmenter + interactiveSegmenter.segment({} as HTMLImageElement, ROI, result => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(result.categoryMask).not.toBeDefined(); + expect(result.confidenceMasks).toEqual([mask1, mask2]); + expect(result.width).toEqual(2); + 
expect(result.height).toEqual(2); + resolve(); + }); + }); + }); + + it('supports combined category and confidence masks', async () => { + const categoryMask = new Uint8ClampedArray([1, 0]); + const confidenceMask1 = new Float32Array([0.0, 1.0]); + const confidenceMask2 = new Float32Array([1.0, 0.0]); + + await interactiveSegmenter.setOptions( + {outputCategoryMask: true, outputConfidenceMasks: true}); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(interactiveSegmenter.categoryMaskListener).toBeDefined(); + expect(interactiveSegmenter.confidenceMasksListener).toBeDefined(); + interactiveSegmenter.categoryMaskListener! + ({data: categoryMask, width: 1, height: 1}, 1337); + interactiveSegmenter.confidenceMasksListener!( + [ + {data: confidenceMask1, width: 1, height: 1}, + {data: confidenceMask2, width: 1, height: 1}, + ], + 1337); + }); return new Promise(resolve => { // Invoke the image segmenter interactiveSegmenter.segment( - {} as HTMLImageElement, ROI, (masks, width, height) => { + {} as HTMLImageElement, ROI, result => { expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) .toHaveBeenCalled(); - expect(masks).toHaveSize(2); - expect(masks[0]).toEqual(mask1); - expect(masks[1]).toEqual(mask2); - expect(width).toEqual(2); - expect(height).toEqual(2); + expect(result.categoryMask).toEqual(categoryMask); + expect(result.confidenceMasks).toEqual([ + confidenceMask1, confidenceMask2 + ]); + expect(result.width).toEqual(1); + expect(result.height).toEqual(1); resolve(); }); }); diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index 3db579f59..92cae43fb 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -16,6 +16,7 @@ export * from '../../../tasks/web/core/fileset_resolver'; export * from '../../../tasks/web/vision/core/drawing_utils'; +export * from '../../../tasks/web/vision/core/image'; export * from '../../../tasks/web/vision/face_detector/face_detector'; export * from '../../../tasks/web/vision/face_landmarker/face_landmarker'; export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; diff --git a/mediapipe/util/annotation_renderer.cc b/mediapipe/util/annotation_renderer.cc index 671f47505..5188da896 100644 --- a/mediapipe/util/annotation_renderer.cc +++ b/mediapipe/util/annotation_renderer.cc @@ -56,8 +56,8 @@ bool NormalizedtoPixelCoordinates(double normalized_x, double normalized_y, VLOG(1) << "Normalized coordinates must be between 0.0 and 1.0"; } - *x_px = static_cast(round(normalized_x * image_width)); - *y_px = static_cast(round(normalized_y * image_height)); + *x_px = static_cast(round(normalized_x * image_width)); + *y_px = static_cast(round(normalized_y * image_height)); return true; } diff --git a/mediapipe/util/cpu_util.cc b/mediapipe/util/cpu_util.cc index c1be9793b..052eabb85 100644 --- a/mediapipe/util/cpu_util.cc +++ b/mediapipe/util/cpu_util.cc @@ -43,7 +43,7 @@ ABSL_FLAG(std::string, system_cpu_max_freq_file, namespace mediapipe { namespace { -constexpr uint32 kBufferLength = 64; +constexpr uint32_t kBufferLength = 64; absl::StatusOr GetFilePath(int cpu) { if (!absl::StrContains(absl::GetFlag(FLAGS_system_cpu_max_freq_file), "$0")) { @@ -54,7 +54,7 @@ absl::StatusOr GetFilePath(int cpu) { return absl::Substitute(absl::GetFlag(FLAGS_system_cpu_max_freq_file), cpu); } -absl::StatusOr GetCpuMaxFrequency(int cpu) { +absl::StatusOr GetCpuMaxFrequency(int cpu) { auto path_or_status = GetFilePath(cpu); 
if (!path_or_status.ok()) { return path_or_status.status(); @@ -65,7 +65,7 @@ absl::StatusOr GetCpuMaxFrequency(int cpu) { char buffer[kBufferLength]; file.getline(buffer, kBufferLength); file.close(); - uint64 frequency; + uint64_t frequency; if (absl::SimpleAtoi(buffer, &frequency)) { return frequency; } else { @@ -79,7 +79,7 @@ absl::StatusOr GetCpuMaxFrequency(int cpu) { } std::set InferLowerOrHigherCoreIds(bool lower) { - std::vector> cpu_freq_pairs; + std::vector> cpu_freq_pairs; for (int cpu = 0; cpu < NumCPUCores(); ++cpu) { auto freq_or_status = GetCpuMaxFrequency(cpu); if (freq_or_status.ok()) { @@ -90,12 +90,12 @@ std::set InferLowerOrHigherCoreIds(bool lower) { return {}; } - absl::c_sort(cpu_freq_pairs, [lower](const std::pair& left, - const std::pair& right) { + absl::c_sort(cpu_freq_pairs, [lower](const std::pair& left, + const std::pair& right) { return (lower && left.second < right.second) || (!lower && left.second > right.second); }); - uint64 edge_freq = cpu_freq_pairs[0].second; + uint64_t edge_freq = cpu_freq_pairs[0].second; std::set inferred_cores; for (const auto& cpu_freq_pair : cpu_freq_pairs) { diff --git a/mediapipe/util/image_frame_util.cc b/mediapipe/util/image_frame_util.cc index a3a038b00..bf2773fdc 100644 --- a/mediapipe/util/image_frame_util.cc +++ b/mediapipe/util/image_frame_util.cc @@ -89,12 +89,12 @@ void ImageFrameToYUVImage(const ImageFrame& image_frame, YUVImage* yuv_image) { const int uv_stride = (uv_width + 15) & ~15; const int y_size = y_stride * height; const int uv_size = uv_stride * uv_height; - uint8* data = - reinterpret_cast(aligned_malloc(y_size + uv_size * 2, 16)); + uint8_t* data = + reinterpret_cast(aligned_malloc(y_size + uv_size * 2, 16)); std::function deallocate = [data]() { aligned_free(data); }; - uint8* y = data; - uint8* u = y + y_size; - uint8* v = u + uv_size; + uint8_t* y = data; + uint8_t* u = y + y_size; + uint8_t* v = u + uv_size; yuv_image->Initialize(libyuv::FOURCC_I420, deallocate, // y, y_stride, // u, uv_stride, // @@ -123,10 +123,11 @@ void ImageFrameToYUVNV12Image(const ImageFrame& image_frame, const int uv_stride = y_stride; const int uv_height = (height + 1) / 2; const int uv_size = uv_stride * uv_height; - uint8* data = reinterpret_cast(aligned_malloc(y_size + uv_size, 16)); + uint8_t* data = + reinterpret_cast(aligned_malloc(y_size + uv_size, 16)); std::function deallocate = [data] { aligned_free(data); }; - uint8* y = data; - uint8* uv = y + y_size; + uint8_t* y = data; + uint8_t* uv = y + y_size; yuv_nv12_image->Initialize(libyuv::FOURCC_NV12, deallocate, y, y_stride, uv, uv_stride, nullptr, 0, width, height); const int rv = libyuv::I420ToNV12( @@ -210,44 +211,44 @@ void YUVImageToImageFrameFromFormat(const YUVImage& yuv_image, } } -void SrgbToMpegYCbCr(const uint8 r, const uint8 g, const uint8 b, // - uint8* y, uint8* cb, uint8* cr) { +void SrgbToMpegYCbCr(const uint8_t r, const uint8_t g, const uint8_t b, // + uint8_t* y, uint8_t* cb, uint8_t* cr) { // ITU-R BT.601 conversion from sRGB to YCbCr. // FastIntRound is used rather than SafeRound since the possible // range of values is [16,235] for Y and [16,240] for Cb and Cr and we // don't care about the rounding direction for values exactly between // two integers. 
- *y = static_cast( + *y = static_cast( mediapipe::MathUtil::FastIntRound(16.0 + // 65.481 * r / 255.0 + // 128.553 * g / 255.0 + // 24.966 * b / 255.0)); - *cb = static_cast( + *cb = static_cast( mediapipe::MathUtil::FastIntRound(128.0 + // -37.797 * r / 255.0 + // -74.203 * g / 255.0 + // 112.0 * b / 255.0)); - *cr = static_cast( + *cr = static_cast( mediapipe::MathUtil::FastIntRound(128.0 + // 112.0 * r / 255.0 + // -93.786 * g / 255.0 + // -18.214 * b / 255.0)); } -void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, // - uint8* r, uint8* g, uint8* b) { +void MpegYCbCrToSrgb(const uint8_t y, const uint8_t cb, const uint8_t cr, // + uint8_t* r, uint8_t* g, uint8_t* b) { // ITU-R BT.601 conversion from YCbCr to sRGB // Use SafeRound since many MPEG YCbCr values do not correspond directly // to an sRGB value. - *r = mediapipe::MathUtil::SafeRound( // - 255.0 / 219.0 * (y - 16.0) + // + *r = mediapipe::MathUtil::SafeRound( // + 255.0 / 219.0 * (y - 16.0) + // 255.0 / 112.0 * 0.701 * (cr - 128.0)); - *g = mediapipe::MathUtil::SafeRound( + *g = mediapipe::MathUtil::SafeRound( 255.0 / 219.0 * (y - 16.0) - // 255.0 / 112.0 * 0.886 * 0.114 / 0.587 * (cb - 128.0) - // 255.0 / 112.0 * 0.701 * 0.299 / 0.587 * (cr - 128.0)); - *b = mediapipe::MathUtil::SafeRound( // - 255.0 / 219.0 * (y - 16.0) + // + *b = mediapipe::MathUtil::SafeRound( // + 255.0 / 219.0 * (y - 16.0) + // 255.0 / 112.0 * 0.886 * (cb - 128.0)); } @@ -260,15 +261,15 @@ void MpegYCbCrToSrgb(const uint8 y, const uint8 cb, const uint8 cr, // cv::Mat GetSrgbToLinearRgb16Lut() { cv::Mat lut(1, 256, CV_16UC1); - uint16* ptr = lut.ptr(); + uint16_t* ptr = lut.ptr(); constexpr double kUint8Max = 255.0; constexpr double kUint16Max = 65535.0; for (int i = 0; i < 256; ++i) { if (i < 0.04045 * kUint8Max) { - ptr[i] = static_cast( + ptr[i] = static_cast( (static_cast(i) / kUint8Max / 12.92) * kUint16Max + .5); } else { - ptr[i] = static_cast( + ptr[i] = static_cast( pow((static_cast(i) / kUint8Max + 0.055) / 1.055, 2.4) * kUint16Max + .5); @@ -279,15 +280,15 @@ cv::Mat GetSrgbToLinearRgb16Lut() { cv::Mat GetLinearRgb16ToSrgbLut() { cv::Mat lut(1, 65536, CV_8UC1); - uint8* ptr = lut.ptr(); + uint8_t* ptr = lut.ptr(); constexpr double kUint8Max = 255.0; constexpr double kUint16Max = 65535.0; for (int i = 0; i < 65536; ++i) { if (i < 0.0031308 * kUint16Max) { - ptr[i] = static_cast( + ptr[i] = static_cast( (static_cast(i) / kUint16Max * 12.92) * kUint8Max + .5); } else { - ptr[i] = static_cast( + ptr[i] = static_cast( (1.055 * pow(static_cast(i) / kUint16Max, 1.0 / 2.4) - .055) * kUint8Max + .5); @@ -306,13 +307,13 @@ void LinearRgb16ToSrgb(const cv::Mat& source, cv::Mat* destination) { destination->create(source.size(), CV_8UC(source.channels())); static const cv::Mat kLut = GetLinearRgb16ToSrgbLut(); - const uint8* lookup_table_ptr = kLut.ptr(); + const uint8_t* lookup_table_ptr = kLut.ptr(); const int num_channels = source.channels(); for (int row = 0; row < source.rows; ++row) { for (int col = 0; col < source.cols; ++col) { for (int channel = 0; channel < num_channels; ++channel) { - uint8* ptr = destination->ptr(row); - const uint16* ptr16 = source.ptr(row); + uint8_t* ptr = destination->ptr(row); + const uint16_t* ptr16 = source.ptr(row); ptr[col * num_channels + channel] = lookup_table_ptr[ptr16[col * num_channels + channel]]; } diff --git a/mediapipe/util/image_test_utils.cc b/mediapipe/util/image_test_utils.cc index 815666985..77b755953 100644 --- a/mediapipe/util/image_test_utils.cc +++ b/mediapipe/util/image_test_utils.cc 
@@ -43,14 +43,14 @@ mediapipe::ImageFormat::Format GetImageFormat(int image_channels) { Packet MakeImageFramePacket(cv::Mat input, int timestamp) { ImageFrame input_image(GetImageFormat(input.channels()), input.cols, - input.rows, input.step, input.data, [](uint8*) {}); + input.rows, input.step, input.data, [](uint8_t*) {}); return MakePacket(std::move(input_image)).At(Timestamp(0)); } Packet MakeImagePacket(cv::Mat input, int timestamp) { mediapipe::Image input_image(std::make_shared( GetImageFormat(input.channels()), input.cols, input.rows, input.step, - input.data, [](uint8*) {})); + input.data, [](uint8_t*) {})); return MakePacket(std::move(input_image)).At(Timestamp(0)); } diff --git a/mediapipe/util/label_map_util.cc b/mediapipe/util/label_map_util.cc index 914a2ba76..eb909349d 100644 --- a/mediapipe/util/label_map_util.cc +++ b/mediapipe/util/label_map_util.cc @@ -25,7 +25,7 @@ namespace mediapipe { -absl::StatusOr> BuildLabelMapFromFiles( +absl::StatusOr> BuildLabelMapFromFiles( absl::string_view labels_file_contents, absl::string_view display_names_file) { if (labels_file_contents.empty()) { @@ -68,7 +68,7 @@ absl::StatusOr> BuildLabelMapFromFiles( label_map_items[i].set_display_name(display_names[i]); } } - proto_ns::Map label_map; + proto_ns::Map label_map; for (int i = 0; i < label_map_items.size(); ++i) { label_map[i] = label_map_items[i]; } diff --git a/mediapipe/util/pose_util.cc b/mediapipe/util/pose_util.cc index a9d2e6158..3a9c1e97b 100644 --- a/mediapipe/util/pose_util.cc +++ b/mediapipe/util/pose_util.cc @@ -107,6 +107,10 @@ const int kFaceMeshFaceOval[36][2] = { {152, 148}, {148, 176}, {176, 149}, {149, 150}, {150, 136}, {136, 172}, {172, 58}, {58, 132}, {132, 93}, {93, 234}, {234, 127}, {127, 162}, {162, 21}, {21, 54}, {54, 103}, {103, 67}, {67, 109}, {109, 10}}; + +const cv::Scalar kRightEyeColor = cv::Scalar(255.0, 48.0, 48.0); +const cv::Scalar kLeftEyeColor = cv::Scalar(48.0, 255.0, 48.0); +const cv::Scalar kFaceContourColor = cv::Scalar(224.0, 224.0, 224.0); } // namespace namespace mediapipe { @@ -180,49 +184,48 @@ void DrawFace(const mediapipe::NormalizedLandmarkList& face, bool flip_y, constexpr int draw_line_width = 2; for (int j = 0; j < 36; ++j) { cv::line(*image, landmarks[kFaceMeshFaceOval[j][0]], - landmarks[kFaceMeshFaceOval[j][1]], cv::Scalar(224, 224, 224), + landmarks[kFaceMeshFaceOval[j][1]], kFaceContourColor, draw_line_width); } for (int j = 0; j < 40; ++j) { cv::line(*image, landmarks[kFaceMeshLips[j][0]], - landmarks[kFaceMeshLips[j][1]], cv::Scalar(224, 224, 224), + landmarks[kFaceMeshLips[j][1]], kFaceContourColor, draw_line_width); } for (int j = 0; j < 16; ++j) { cv::line(*image, landmarks[kFaceMeshLeftEye[j][0]], - landmarks[kFaceMeshLeftEye[j][1]], cv::Scalar(48, 255, 48), - draw_line_width); + landmarks[kFaceMeshLeftEye[j][1]], kLeftEyeColor, draw_line_width); } for (int j = 0; j < 8; ++j) { cv::line(*image, landmarks[kFaceMeshLeftEyebrow[j][0]], - landmarks[kFaceMeshLeftEyebrow[j][1]], cv::Scalar(48, 255, 48), + landmarks[kFaceMeshLeftEyebrow[j][1]], kLeftEyeColor, draw_line_width); } for (int j = 0; j < 4; ++j) { cv::line(*image, landmarks[kFaceMeshLeftIris[j][0]], - landmarks[kFaceMeshLeftIris[j][1]], cv::Scalar(48, 255, 48), + landmarks[kFaceMeshLeftIris[j][1]], kLeftEyeColor, draw_line_width); } for (int j = 0; j < 16; ++j) { cv::line(*image, landmarks[kFaceMeshRightEye[j][0]], - landmarks[kFaceMeshRightEye[j][1]], cv::Scalar(48, 48, 255), + landmarks[kFaceMeshRightEye[j][1]], kRightEyeColor, draw_line_width); } for (int j = 0; j < 
8; ++j) { cv::line(*image, landmarks[kFaceMeshRightEyebrow[j][0]], - landmarks[kFaceMeshRightEyebrow[j][1]], cv::Scalar(48, 48, 255), + landmarks[kFaceMeshRightEyebrow[j][1]], kRightEyeColor, draw_line_width); } for (int j = 0; j < 4; ++j) { cv::line(*image, landmarks[kFaceMeshRightIris[j][0]], - landmarks[kFaceMeshRightIris[j][1]], cv::Scalar(48, 48, 255), + landmarks[kFaceMeshRightIris[j][1]], kRightEyeColor, draw_line_width); } } diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index a0bbdbf3f..59663c9ba 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -125,7 +125,7 @@ cc_library_with_tflite( srcs = ["tflite_model_loader.cc"], hdrs = ["tflite_model_loader.h"], tflite_deps = [ - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", + "@org_tensorflow//tensorflow/lite:framework_stable", ], visibility = ["//visibility:public"], deps = [ diff --git a/mediapipe/util/tflite/tflite_model_loader.cc b/mediapipe/util/tflite/tflite_model_loader.cc index 86fc260bb..a2a3cc2be 100644 --- a/mediapipe/util/tflite/tflite_model_loader.cc +++ b/mediapipe/util/tflite/tflite_model_loader.cc @@ -19,7 +19,7 @@ namespace mediapipe { -using FlatBufferModel = ::tflite_shims::FlatBufferModel; +using FlatBufferModel = ::tflite::FlatBufferModel; absl::StatusOr> TfLiteModelLoader::LoadFromPath( const std::string& path) { diff --git a/mediapipe/util/tflite/tflite_model_loader.h b/mediapipe/util/tflite/tflite_model_loader.h index 8c630ec8d..65bd9ba72 100644 --- a/mediapipe/util/tflite/tflite_model_loader.h +++ b/mediapipe/util/tflite/tflite_model_loader.h @@ -22,13 +22,13 @@ #include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" -#include "tensorflow/lite/core/shims/cc/model.h" +#include "tensorflow/lite/model.h" namespace mediapipe { // Represents a TfLite model as a FlatBuffer. using TfLiteModelPtr = - std::unique_ptr>; + std::unique_ptr>; class TfLiteModelLoader { public: diff --git a/mediapipe/util/tracking/tracked_detection_manager.cc b/mediapipe/util/tracking/tracked_detection_manager.cc index 7da207682..77aab3107 100644 --- a/mediapipe/util/tracking/tracked_detection_manager.cc +++ b/mediapipe/util/tracking/tracked_detection_manager.cc @@ -21,7 +21,6 @@ namespace { -using ::mediapipe::NormalizedRect; using mediapipe::TrackedDetection; // Checks if a point is out of view. diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 578577cf0..0115312b4 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -14,6 +14,7 @@ import {isWebKit} from '../../web/graph_runner/platform_utils'; */ export declare interface FileLocator { locateFile: (filename: string) => string; + mainScriptUrlOrBlob?: string; } /** @@ -1222,7 +1223,11 @@ export async function createMediaPipeLib( // self.Module and a fileLocator, we manually merge them into self.Module and // use that. TODO: Remove this when asset scripts are fixed. 
if (self.Module && fileLocator) { - (self.Module as FileLocator).locateFile = fileLocator.locateFile; + const moduleFileLocator = self.Module as FileLocator; + moduleFileLocator.locateFile = fileLocator.locateFile; + if (fileLocator.mainScriptUrlOrBlob) { + moduleFileLocator.mainScriptUrlOrBlob = fileLocator.mainScriptUrlOrBlob; + } } // TODO: Ensure that fileLocator is passed in by all users // and make it required diff --git a/third_party/apple_frameworks/BUILD b/third_party/apple_frameworks/BUILD index 62f91b515..3f91b7232 100644 --- a/third_party/apple_frameworks/BUILD +++ b/third_party/apple_frameworks/BUILD @@ -76,3 +76,13 @@ cc_library( name = "QuartzCore", linkopts = ["-framework QuartzCore"], ) + +cc_library( + name = "CoreAudio", + linkopts = ["-framework CoreAudio"], +) + +cc_library( + name = "MediaToolbox", + linkopts = ["-framework MediaToolbox"], +) diff --git a/third_party/darts_clone.BUILD b/third_party/darts_clone.BUILD new file mode 100644 index 000000000..a15c2d68d --- /dev/null +++ b/third_party/darts_clone.BUILD @@ -0,0 +1,29 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Description: +# Darts-clone is a clone of Darts (Double-ARray Trie System). + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "darts_clone", + hdrs = [ + "include/darts.h", + ], +) diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index 2d9cfc1fe..7264c1b1c 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -307,7 +307,7 @@ def external_files(): http_file( name = "com_google_mediapipe_expected_pose_landmarks_prototxt", sha256 = "eed8dfa169b0abee60cde01496599b0bc75d91a82594a1bdf59be2f76f45d7f5", - urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681244232522990"], + urls = ["https://storage.googleapis.com/mediapipe-assets/expected_pose_landmarks.prototxt?generation=1681244235071100"], ) http_file( @@ -996,8 +996,8 @@ def external_files(): http_file( name = "com_google_mediapipe_pose_landmarks_pbtxt", - sha256 = "305a71fbff83e270a5dbd81fb7cf65203f56e0b1caba8ea42edc16c6e8a2ba18", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681244254964356"], + sha256 = "69c79cdf3964d7819776eab1172e47e70684139d72a6d7edcbdd62dbb2ca5527", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_landmarks.pbtxt?generation=1681425322701589"], ) http_file( diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel index 53cecc734..d5264a026 100644 --- a/third_party/flatbuffers/BUILD.bazel +++ b/third_party/flatbuffers/BUILD.bazel @@ -45,12 +45,16 @@ filegroup( "include/flatbuffers/bfbs_generator.h", "include/flatbuffers/buffer.h", "include/flatbuffers/buffer_ref.h", + "include/flatbuffers/code_generator.h", "include/flatbuffers/code_generators.h",
"include/flatbuffers/default_allocator.h", "include/flatbuffers/detached_buffer.h", "include/flatbuffers/flatbuffer_builder.h", "include/flatbuffers/flatbuffers.h", + "include/flatbuffers/flatc.h", + "include/flatbuffers/flex_flat_util.h", "include/flatbuffers/flexbuffers.h", + "include/flatbuffers/grpc.h", "include/flatbuffers/hash.h", "include/flatbuffers/idl.h", "include/flatbuffers/minireflect.h", diff --git a/third_party/flatbuffers/workspace.bzl b/third_party/flatbuffers/workspace.bzl index 703cb0536..02247268b 100644 --- a/third_party/flatbuffers/workspace.bzl +++ b/third_party/flatbuffers/workspace.bzl @@ -5,11 +5,11 @@ load("//third_party:repo.bzl", "third_party_http_archive") def repo(): third_party_http_archive( name = "flatbuffers", - strip_prefix = "flatbuffers-2.0.6", - sha256 = "e2dc24985a85b278dd06313481a9ca051d048f9474e0f199e372fea3ea4248c9", + strip_prefix = "flatbuffers-23.1.21", + sha256 = "d84cb25686514348e615163b458ae0767001b24b42325f426fd56406fd384238", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v2.0.6.tar.gz", - "https://github.com/google/flatbuffers/archive/v2.0.6.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/flatbuffers/archive/v23.1.21.tar.gz", + "https://github.com/google/flatbuffers/archive/v23.1.21.tar.gz", ], build_file = "//third_party/flatbuffers:BUILD.bazel", delete = ["build_defs.bzl", "BUILD.bazel"], diff --git a/third_party/wasm_files.bzl b/third_party/wasm_files.bzl index 148b5970f..a484d2f82 100644 --- a/third_party/wasm_files.bzl +++ b/third_party/wasm_files.bzl @@ -12,72 +12,72 @@ def wasm_files(): http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_js", - sha256 = "0eca68e2291a548b734bcab5db4c9e6b997e852ea7e19228003b9e2a78c7c646", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681328323089931"], + sha256 = "b810de53d7ccf991b9c70fcdf7e88b5c3f2942ae766436f22be48159b6a7e687", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.js?generation=1681849488227617"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_internal_wasm", - sha256 = "69bc95af5b783b510ec1842d6fb9594254907d8e1334799c5753164878a7dcac", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681328325829340"], + sha256 = "26d91147e5c6c8a92e0a4ebf59599068a3cff6108847b793ef33ac23e98eddb9", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_internal.wasm?generation=1681849491546937"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_js", - sha256 = "88a0176cc80d6a1eb175a5105df705cf8b8684cf13f6db0a264af0b67b65a22a", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681328328330829"], + sha256 = "b38e37b3024692558eaaba159921fedd3297d1a09bba1c16a06fed327845b0bd", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.js?generation=1681849494099698"], ) http_file( name = "com_google_mediapipe_wasm_audio_wasm_nosimd_internal_wasm", - sha256 = "1cc0c3db7d252801be4b090d8bbba61f308cc3dd5efe197319581d3af29495c7", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681328331085637"], + sha256 = "6a8e73d2e926565046e16adf1748f0f8ec5135fafe7eb8b9c83892e64c1a449a", + urls = 
["https://storage.googleapis.com/mediapipe-assets/wasm/audio_wasm_nosimd_internal.wasm?generation=1681849496451970"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_js", - sha256 = "d9cd100b6d330d36f7749fe5fc64a2cdd0abb947a0376e6140784cfb0361a4e2", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681328333442454"], + sha256 = "785cba67b623b1dc66dc3621e97fd6b30edccbb408184a3094d0aa68ddd5becb", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.js?generation=1681849498746265"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_internal_wasm", - sha256 = "30a2fcca630bdad6e99173ea7d0d8c5d7086aedf393d0159fa05bf9d08d4ff65", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681328335803336"], + sha256 = "a858b8a2e8b40e9c936b66566c5aefd396536c4e936459ab9ae7e239621adc14", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_internal.wasm?generation=1681849501370461"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_js", - sha256 = "70ca2bd15c56e0ce7bb10ff2188b4a1f9eafbb657eb9424e4cab8d7b29179871", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681328338162884"], + sha256 = "5292f1442d5e5c037e7cffb78a8c2d71255348ca2c3bd759b314bdbedd5590c2", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.js?generation=1681849503379116"], ) http_file( name = "com_google_mediapipe_wasm_text_wasm_nosimd_internal_wasm", - sha256 = "8221b385905f36a769d7731a0adbe18b681bcb873561890429ca84278c67c3fd", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681328340808115"], + sha256 = "e44b48ab29ee1d8befec804e9a63445c56266b679d19fb476d556ca621f0e493", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/text_wasm_nosimd_internal.wasm?generation=1681849505997020"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_js", - sha256 = "07692acd8202adafebd35dbcd7e2b8e88a76d4a0e6b9229cb3cad59503eeddc7", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681328343147709"], + sha256 = "205855eba70464a92b9d00e90acac15c51a9f76192f900e697304ac6dea8f714", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.js?generation=1681849508414277"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_internal_wasm", - sha256 = "03bf553fa6a768b0d70103a5e7d835b6b37371ff44e201c3392f22e0879737c3", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681328345605574"], + sha256 = "c0cbd0df3adb2a9cd1331d14f522d2bae9f8adc9f1b35f92cbbc4b782b190cef", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_internal.wasm?generation=1681849510936608"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_js", - sha256 = "36697be14f921985eac15d1447ec8a260817b05ade1c9bb3ca7e906e0f047ec0", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681328348025082"], + sha256 = "0969812de4d3573198fa2eba4f5b0a7e97e98f97bd4215d876543f4925e57b84", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.js?generation=1681849513292639"], ) http_file( name = "com_google_mediapipe_wasm_vision_wasm_nosimd_internal_wasm", - sha256 = 
"103fb145438d61cfecb2e8db3f06b43a5d77a7e3fcea940437fe272227cf2592", - urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681328350709881"], + sha256 = "f2ab62c3f8dabab0a573dadf5c105ff81a03c29c70f091f8cf273ae030c0a86f", + urls = ["https://storage.googleapis.com/mediapipe-assets/wasm/vision_wasm_nosimd_internal.wasm?generation=1681849515999000"], )