Merge branch 'google:master' into face-stylizer-python-add-tests
This commit is contained in:
		
						commit
						a5716c9225
					
				
							
								
								
									
										10
									
								
								WORKSPACE
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								WORKSPACE
									
									
									
									
									
								
							|  | @ -239,6 +239,16 @@ http_archive( | |||
|     repo_mapping = {"@com_google_glog" : "@com_github_glog_glog_no_gflags"}, | ||||
| ) | ||||
| 
 | ||||
| http_archive( | ||||
|     name = "darts_clone", | ||||
|     build_file = "@//third_party:darts_clone.BUILD", | ||||
|     sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c", | ||||
|     strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983", | ||||
|     urls = [ | ||||
|         "https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| http_archive( | ||||
|     name = "org_tensorflow_text", | ||||
|     sha256 = "f64647276f7288d1b1fe4c89581d51404d0ce4ae97f2bcc4c19bd667549adca8", | ||||
|  |  | |||
|  | @ -78,7 +78,7 @@ class ConstantSidePacketCalculator : public CalculatorBase { | |||
|       } else if (packet_options.has_string_value()) { | ||||
|         packet.Set<std::string>(); | ||||
|       } else if (packet_options.has_uint64_value()) { | ||||
|         packet.Set<uint64>(); | ||||
|         packet.Set<uint64_t>(); | ||||
|       } else if (packet_options.has_classification_list_value()) { | ||||
|         packet.Set<ClassificationList>(); | ||||
|       } else if (packet_options.has_landmark_list_value()) { | ||||
|  | @ -112,7 +112,7 @@ class ConstantSidePacketCalculator : public CalculatorBase { | |||
|       } else if (packet_options.has_string_value()) { | ||||
|         packet.Set(MakePacket<std::string>(packet_options.string_value())); | ||||
|       } else if (packet_options.has_uint64_value()) { | ||||
|         packet.Set(MakePacket<uint64>(packet_options.uint64_value())); | ||||
|         packet.Set(MakePacket<uint64_t>(packet_options.uint64_value())); | ||||
|       } else if (packet_options.has_classification_list_value()) { | ||||
|         packet.Set(MakePacket<ClassificationList>( | ||||
|             packet_options.classification_list_value())); | ||||
|  |  | |||
|  | @ -35,14 +35,14 @@ class GateCalculatorTest : public ::testing::Test { | |||
|   } | ||||
| 
 | ||||
|   // Use this when ALLOW/DISALLOW input is provided as a side packet.
 | ||||
|   void RunTimeStep(int64 timestamp, bool stream_payload) { | ||||
|   void RunTimeStep(int64_t timestamp, bool stream_payload) { | ||||
|     runner_->MutableInputs()->Get("", 0).packets.push_back( | ||||
|         MakePacket<bool>(stream_payload).At(Timestamp(timestamp))); | ||||
|     MP_ASSERT_OK(runner_->Run()) << "Calculator execution failed."; | ||||
|   } | ||||
| 
 | ||||
|   // Use this when ALLOW/DISALLOW input is provided as an input stream.
 | ||||
|   void RunTimeStep(int64 timestamp, const std::string& control_tag, | ||||
|   void RunTimeStep(int64_t timestamp, const std::string& control_tag, | ||||
|                    bool control) { | ||||
|     runner_->MutableInputs()->Get("", 0).packets.push_back( | ||||
|         MakePacket<bool>(true).At(Timestamp(timestamp))); | ||||
|  | @ -134,9 +134,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWOptionToTrue) { | |||
|         } | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -159,9 +159,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionSetToFalse) { | |||
|         } | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -175,9 +175,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWOptionNotSet) { | |||
|         output_stream: "test_output" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -193,9 +193,9 @@ TEST_F(GateCalculatorTest, AllowByALLOWSidePacketSetToTrue) { | |||
|   )"); | ||||
|   runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(true)); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -215,9 +215,9 @@ TEST_F(GateCalculatorTest, AllowByDisallowSidePacketSetToFalse) { | |||
|   )"); | ||||
|   runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(false)); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -237,9 +237,9 @@ TEST_F(GateCalculatorTest, DisallowByALLOWSidePacketSetToFalse) { | |||
|   )"); | ||||
|   runner()->MutableSidePackets()->Tag(kAllowTag) = Adopt(new bool(false)); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -255,9 +255,9 @@ TEST_F(GateCalculatorTest, DisallowByDISALLOWSidePacketSetToTrue) { | |||
|   )"); | ||||
|   runner()->MutableSidePackets()->Tag(kDisallowTag) = Adopt(new bool(true)); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -272,13 +272,13 @@ TEST_F(GateCalculatorTest, Allow) { | |||
|         output_stream: "test_output" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, "ALLOW", true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, "ALLOW", false); | ||||
|   constexpr int64 kTimestampValue2 = 44; | ||||
|   constexpr int64_t kTimestampValue2 = 44; | ||||
|   RunTimeStep(kTimestampValue2, "ALLOW", true); | ||||
|   constexpr int64 kTimestampValue3 = 45; | ||||
|   constexpr int64_t kTimestampValue3 = 45; | ||||
|   RunTimeStep(kTimestampValue3, "ALLOW", false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -297,13 +297,13 @@ TEST_F(GateCalculatorTest, Disallow) { | |||
|         output_stream: "test_output" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, "DISALLOW", true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, "DISALLOW", false); | ||||
|   constexpr int64 kTimestampValue2 = 44; | ||||
|   constexpr int64_t kTimestampValue2 = 44; | ||||
|   RunTimeStep(kTimestampValue2, "DISALLOW", true); | ||||
|   constexpr int64 kTimestampValue3 = 45; | ||||
|   constexpr int64_t kTimestampValue3 = 45; | ||||
|   RunTimeStep(kTimestampValue3, "DISALLOW", false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = runner()->Outputs().Get("", 0).packets; | ||||
|  | @ -323,13 +323,13 @@ TEST_F(GateCalculatorTest, AllowWithStateChange) { | |||
|         output_stream: "STATE_CHANGE:state_changed" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, "ALLOW", false); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, "ALLOW", true); | ||||
|   constexpr int64 kTimestampValue2 = 44; | ||||
|   constexpr int64_t kTimestampValue2 = 44; | ||||
|   RunTimeStep(kTimestampValue2, "ALLOW", true); | ||||
|   constexpr int64 kTimestampValue3 = 45; | ||||
|   constexpr int64_t kTimestampValue3 = 45; | ||||
|   RunTimeStep(kTimestampValue3, "ALLOW", false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = | ||||
|  | @ -379,13 +379,13 @@ TEST_F(GateCalculatorTest, DisallowWithStateChange) { | |||
|         output_stream: "STATE_CHANGE:state_changed" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, "DISALLOW", true); | ||||
|   constexpr int64 kTimestampValue1 = 43; | ||||
|   constexpr int64_t kTimestampValue1 = 43; | ||||
|   RunTimeStep(kTimestampValue1, "DISALLOW", false); | ||||
|   constexpr int64 kTimestampValue2 = 44; | ||||
|   constexpr int64_t kTimestampValue2 = 44; | ||||
|   RunTimeStep(kTimestampValue2, "DISALLOW", false); | ||||
|   constexpr int64 kTimestampValue3 = 45; | ||||
|   constexpr int64_t kTimestampValue3 = 45; | ||||
|   RunTimeStep(kTimestampValue3, "DISALLOW", true); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = | ||||
|  | @ -432,7 +432,7 @@ TEST_F(GateCalculatorTest, DisallowInitialNoStateTransition) { | |||
|         output_stream: "STATE_CHANGE:state_changed" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, "DISALLOW", false); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = | ||||
|  | @ -450,7 +450,7 @@ TEST_F(GateCalculatorTest, AllowInitialNoStateTransition) { | |||
|         output_stream: "STATE_CHANGE:state_changed" | ||||
|   )"); | ||||
| 
 | ||||
|   constexpr int64 kTimestampValue0 = 42; | ||||
|   constexpr int64_t kTimestampValue0 = 42; | ||||
|   RunTimeStep(kTimestampValue0, "ALLOW", true); | ||||
| 
 | ||||
|   const std::vector<Packet>& output = | ||||
|  |  | |||
|  | @ -35,7 +35,7 @@ class MatrixToVectorCalculatorTest | |||
|   void SetUp() override { calculator_name_ = "MatrixToVectorCalculator"; } | ||||
| 
 | ||||
|   void AppendInput(const std::vector<float>& column_major_data, | ||||
|                    int64 timestamp) { | ||||
|                    int64_t timestamp) { | ||||
|     ASSERT_EQ(num_input_samples_ * num_input_channels_, | ||||
|               column_major_data.size()); | ||||
|     Eigen::Map<const Matrix> data_map(&column_major_data[0], | ||||
|  |  | |||
|  | @ -51,9 +51,9 @@ class SimpleRunner : public CalculatorRunner { | |||
| 
 | ||||
|   virtual ~SimpleRunner() {} | ||||
| 
 | ||||
|   void SetInput(const std::vector<int64>& timestamp_list) { | ||||
|   void SetInput(const std::vector<int64_t>& timestamp_list) { | ||||
|     MutableInputs()->Index(0).packets.clear(); | ||||
|     for (const int64 ts : timestamp_list) { | ||||
|     for (const int64_t ts : timestamp_list) { | ||||
|       MutableInputs()->Index(0).packets.push_back( | ||||
|           Adopt(new std::string(absl::StrCat("Frame #", ts))) | ||||
|               .At(Timestamp(ts))); | ||||
|  | @ -72,8 +72,8 @@ class SimpleRunner : public CalculatorRunner { | |||
|   } | ||||
| 
 | ||||
|   void CheckOutputTimestamps( | ||||
|       const std::vector<int64>& expected_frames, | ||||
|       const std::vector<int64>& expected_timestamps) const { | ||||
|       const std::vector<int64_t>& expected_frames, | ||||
|       const std::vector<int64_t>& expected_timestamps) const { | ||||
|     EXPECT_EQ(expected_frames.size(), Outputs().Index(0).packets.size()); | ||||
|     EXPECT_EQ(expected_timestamps.size(), Outputs().Index(0).packets.size()); | ||||
|     int count = 0; | ||||
|  | @ -112,7 +112,7 @@ MATCHER_P2(PacketAtTimestamp, payload, timestamp, | |||
|     *result_listener << "at incorrect timestamp = " << arg.Timestamp().Value(); | ||||
|     return false; | ||||
|   } | ||||
|   int64 actual_payload = arg.template Get<int64>(); | ||||
|   int64_t actual_payload = arg.template Get<int64_t>(); | ||||
|   if (actual_payload != payload) { | ||||
|     *result_listener << "with incorrect payload = " << actual_payload; | ||||
|     return false; | ||||
|  | @ -137,18 +137,18 @@ class ReproducibleJitterWithReflectionStrategyForTesting | |||
|   //
 | ||||
|   // An EXPECT will fail if sequence is less than the number requested during
 | ||||
|   // processing.
 | ||||
|   static std::vector<uint64> random_sequence; | ||||
|   static std::vector<uint64_t> random_sequence; | ||||
| 
 | ||||
|  protected: | ||||
|   virtual uint64 GetNextRandom(uint64 n) { | ||||
|   virtual uint64_t GetNextRandom(uint64_t n) { | ||||
|     EXPECT_LT(sequence_index_, random_sequence.size()); | ||||
|     return random_sequence[sequence_index_++] % n; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   int32 sequence_index_ = 0; | ||||
|   int32_t sequence_index_ = 0; | ||||
| }; | ||||
| std::vector<uint64> | ||||
| std::vector<uint64_t> | ||||
|     ReproducibleJitterWithReflectionStrategyForTesting::random_sequence; | ||||
| 
 | ||||
| // PacketResamplerCalculator child class which injects a specified stream
 | ||||
|  | @ -469,7 +469,7 @@ TEST(PacketResamplerCalculatorTest, SetVideoHeader) { | |||
|     } | ||||
|   )pb")); | ||||
| 
 | ||||
|   for (const int64 ts : {0, 5000, 10010, 15001, 19990}) { | ||||
|   for (const int64_t ts : {0, 5000, 10010, 15001, 19990}) { | ||||
|     runner.MutableInputs()->Tag(kDataTag).packets.push_back( | ||||
|         Adopt(new std::string(absl::StrCat("Frame #", ts))).At(Timestamp(ts))); | ||||
|   } | ||||
|  |  | |||
|  | @ -43,8 +43,8 @@ constexpr char kDisallowTag[] = "DISALLOW"; | |||
| 
 | ||||
| // Returns the timestamp values for a vector of Packets.
 | ||||
| // TODO: puth this kind of test util in a common place.
 | ||||
| std::vector<int64> TimestampValues(const std::vector<Packet>& packets) { | ||||
|   std::vector<int64> result; | ||||
| std::vector<int64_t> TimestampValues(const std::vector<Packet>& packets) { | ||||
|   std::vector<int64_t> result; | ||||
|   for (const Packet& packet : packets) { | ||||
|     result.push_back(packet.Timestamp().Value()); | ||||
|   } | ||||
|  | @ -371,7 +371,7 @@ TEST(PreviousLoopbackCalculator, EmptyLoopForever) { | |||
|   for (int main_ts = 0; main_ts < 50; ++main_ts) { | ||||
|     send_packet("in", main_ts); | ||||
|     MP_EXPECT_OK(graph_.WaitUntilIdle()); | ||||
|     std::vector<int64> ts_values = TimestampValues(outputs); | ||||
|     std::vector<int64_t> ts_values = TimestampValues(outputs); | ||||
|     EXPECT_EQ(ts_values.size(), main_ts + 1); | ||||
|     for (int j = 0; j < main_ts + 1; ++j) { | ||||
|       EXPECT_EQ(ts_values[j], j); | ||||
|  |  | |||
|  | @ -121,7 +121,7 @@ absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) { | |||
|   if (cc->Outputs().HasTag(kTagAtTimestamp)) { | ||||
|     RET_CHECK_EQ(num_entries + 1, cc->InputSidePackets().NumEntries()) | ||||
|         << "For AT_TIMESTAMP tag, 2 input side packets are required."; | ||||
|     cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64>(); | ||||
|     cc->InputSidePackets().Tag(kTagSideInputTimestamp).Set<int64_t>(); | ||||
|   } else { | ||||
|     RET_CHECK_EQ(num_entries, cc->InputSidePackets().NumEntries()) | ||||
|         << "Same number of input side packets and output streams is required."; | ||||
|  | @ -178,8 +178,8 @@ absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { | |||
|           .AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); | ||||
|     } | ||||
|   } else if (cc->Outputs().HasTag(kTagAtTimestamp)) { | ||||
|     int64 timestamp = | ||||
|         cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64>(); | ||||
|     int64_t timestamp = | ||||
|         cc->InputSidePackets().Tag(kTagSideInputTimestamp).Get<int64_t>(); | ||||
|     for (int i = 0; i < cc->Outputs().NumEntries(output_tag_); ++i) { | ||||
|       cc->Outputs() | ||||
|           .Get(output_tag_, i) | ||||
|  |  | |||
|  | @ -64,16 +64,16 @@ REGISTER_CALCULATOR(StringToIntCalculator); | |||
| using StringToUintCalculator = StringToIntCalculatorTemplate<unsigned int>; | ||||
| REGISTER_CALCULATOR(StringToUintCalculator); | ||||
| 
 | ||||
| using StringToInt32Calculator = StringToIntCalculatorTemplate<int32>; | ||||
| using StringToInt32Calculator = StringToIntCalculatorTemplate<int32_t>; | ||||
| REGISTER_CALCULATOR(StringToInt32Calculator); | ||||
| 
 | ||||
| using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32>; | ||||
| using StringToUint32Calculator = StringToIntCalculatorTemplate<uint32_t>; | ||||
| REGISTER_CALCULATOR(StringToUint32Calculator); | ||||
| 
 | ||||
| using StringToInt64Calculator = StringToIntCalculatorTemplate<int64>; | ||||
| using StringToInt64Calculator = StringToIntCalculatorTemplate<int64_t>; | ||||
| REGISTER_CALCULATOR(StringToInt64Calculator); | ||||
| 
 | ||||
| using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64>; | ||||
| using StringToUint64Calculator = StringToIntCalculatorTemplate<uint64_t>; | ||||
| REGISTER_CALCULATOR(StringToUint64Calculator); | ||||
| 
 | ||||
| }  // namespace mediapipe
 | ||||
|  |  | |||
|  | @ -166,7 +166,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> { | |||
|       const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(), | ||||
|                                    frame_ptr->Height(), frame_ptr->WidthStep(), | ||||
|                                    const_cast<uint8_t*>(frame_ptr->PixelData()), | ||||
|                                    [](uint8* data){}); | ||||
|                                    [](uint8_t* data){}); | ||||
|       ASSIGN_OR_RETURN(auto result, | ||||
|                        runner->Run(image_frame, matrix, size, border_mode)); | ||||
|       return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result))); | ||||
|  |  | |||
|  | @ -401,8 +401,8 @@ cc_library_with_tflite( | |||
|     hdrs = ["inference_calculator.h"], | ||||
|     tflite_deps = [ | ||||
|         "//mediapipe/util/tflite:tflite_model_loader", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", | ||||
|         "@org_tensorflow//tensorflow/lite:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":inference_calculator_cc_proto", | ||||
|  | @ -506,7 +506,7 @@ cc_library_with_tflite( | |||
|     name = "tflite_delegate_ptr", | ||||
|     hdrs = ["tflite_delegate_ptr.h"], | ||||
|     tflite_deps = [ | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", | ||||
|         "@org_tensorflow//tensorflow/lite/c:c_api_types", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  | @ -517,8 +517,8 @@ cc_library_with_tflite( | |||
|     tflite_deps = [ | ||||
|         ":tflite_delegate_ptr", | ||||
|         "//mediapipe/util/tflite:tflite_model_loader", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite/c:c_api_types", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":inference_runner", | ||||
|  | @ -546,8 +546,8 @@ cc_library( | |||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:c_api_types", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite/c:c_api_types", | ||||
|         "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", | ||||
|     ] + select({ | ||||
|         "//conditions:default": [], | ||||
|  |  | |||
|  | @ -94,8 +94,8 @@ InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) { | |||
|     return kSideInCustomOpResolver(cc).As<tflite::OpResolver>(); | ||||
|   } | ||||
|   return PacketAdopting<tflite::OpResolver>( | ||||
|       std::make_unique<tflite_shims::ops::builtin:: | ||||
|                            BuiltinOpResolverWithoutDefaultDelegates>()); | ||||
|       std::make_unique< | ||||
|           tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>()); | ||||
| } | ||||
| 
 | ||||
| }  // namespace api2
 | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ | |||
| #include "mediapipe/framework/formats/tensor.h" | ||||
| #include "mediapipe/util/tflite/tflite_model_loader.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/kernels/register.h" | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace api2 { | ||||
|  | @ -97,8 +97,8 @@ class InferenceCalculator : public NodeIntf { | |||
|   // Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
 | ||||
|   // TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
 | ||||
|   // migration.
 | ||||
|   static constexpr SideInput<tflite_shims::ops::builtin::BuiltinOpResolver>:: | ||||
|       Optional kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; | ||||
|   static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional | ||||
|       kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; | ||||
|   static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{ | ||||
|       "OP_RESOLVER"}; | ||||
|   static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"}; | ||||
|  |  | |||
|  | @ -24,7 +24,7 @@ | |||
| #include "mediapipe/calculators/tensor/inference_calculator_utils.h" | ||||
| #include "mediapipe/calculators/tensor/inference_interpreter_delegate_runner.h" | ||||
| #include "mediapipe/calculators/tensor/inference_runner.h" | ||||
| #include "tensorflow/lite/core/shims/cc/interpreter.h" | ||||
| #include "tensorflow/lite/interpreter.h" | ||||
| #if defined(MEDIAPIPE_ANDROID) | ||||
| #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" | ||||
| #endif  // ANDROID
 | ||||
|  |  | |||
|  | @ -22,9 +22,9 @@ | |||
| #include "mediapipe/framework/formats/tensor.h" | ||||
| #include "mediapipe/framework/mediapipe_profiling.h" | ||||
| #include "mediapipe/framework/port/ret_check.h" | ||||
| #include "tensorflow/lite/core/shims/c/c_api_types.h" | ||||
| #include "tensorflow/lite/core/shims/cc/interpreter.h" | ||||
| #include "tensorflow/lite/core/shims/cc/interpreter_builder.h" | ||||
| #include "tensorflow/lite/c/c_api_types.h" | ||||
| #include "tensorflow/lite/interpreter.h" | ||||
| #include "tensorflow/lite/interpreter_builder.h" | ||||
| #include "tensorflow/lite/string_util.h" | ||||
| 
 | ||||
| #define PERFETTO_TRACK_EVENT_NAMESPACE mediapipe | ||||
|  | @ -33,8 +33,8 @@ namespace mediapipe { | |||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using Interpreter = ::tflite_shims::Interpreter; | ||||
| using InterpreterBuilder = ::tflite_shims::InterpreterBuilder; | ||||
| using Interpreter = ::tflite::Interpreter; | ||||
| using InterpreterBuilder = ::tflite::InterpreterBuilder; | ||||
| 
 | ||||
| template <typename T> | ||||
| void CopyTensorBufferToInterpreter(const Tensor& input_tensor, | ||||
|  |  | |||
|  | @ -23,8 +23,8 @@ | |||
| #include "mediapipe/calculators/tensor/tflite_delegate_ptr.h" | ||||
| #include "mediapipe/framework/api2/packet.h" | ||||
| #include "mediapipe/util/tflite/tflite_model_loader.h" | ||||
| #include "tensorflow/lite/c/c_api_types.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/c/c_api_types.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,7 +18,7 @@ | |||
| #include <functional> | ||||
| #include <memory> | ||||
| 
 | ||||
| #include "tensorflow/lite/core/shims/c/c_api_types.h" | ||||
| #include "tensorflow/lite/c/c_api_types.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| 
 | ||||
|  |  | |||
|  | @ -61,12 +61,12 @@ constexpr char kSessionBundleTag[] = "SESSION_BUNDLE"; | |||
| // overload GPU/TPU/...
 | ||||
| class SimpleSemaphore { | ||||
|  public: | ||||
|   explicit SimpleSemaphore(uint32 initial_count) : count_(initial_count) {} | ||||
|   explicit SimpleSemaphore(uint32_t initial_count) : count_(initial_count) {} | ||||
|   SimpleSemaphore(const SimpleSemaphore&) = delete; | ||||
|   SimpleSemaphore(SimpleSemaphore&&) = delete; | ||||
| 
 | ||||
|   // Acquires the semaphore by certain amount.
 | ||||
|   void Acquire(uint32 amount) { | ||||
|   void Acquire(uint32_t amount) { | ||||
|     mutex_.Lock(); | ||||
|     while (count_ < amount) { | ||||
|       cond_.Wait(&mutex_); | ||||
|  | @ -76,7 +76,7 @@ class SimpleSemaphore { | |||
|   } | ||||
| 
 | ||||
|   // Releases the semaphore by certain amount.
 | ||||
|   void Release(uint32 amount) { | ||||
|   void Release(uint32_t amount) { | ||||
|     mutex_.Lock(); | ||||
|     count_ += amount; | ||||
|     cond_.SignalAll(); | ||||
|  | @ -84,7 +84,7 @@ class SimpleSemaphore { | |||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   uint32 count_; | ||||
|   uint32_t count_; | ||||
|   absl::Mutex mutex_; | ||||
|   absl::CondVar cond_; | ||||
| }; | ||||
|  | @ -488,7 +488,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { | |||
|   // necessary.
 | ||||
|   absl::Status OutputBatch(CalculatorContext* cc, | ||||
|                            std::unique_ptr<InferenceState> inference_state) { | ||||
|     const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     const int64_t start_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     std::vector<std::pair<mediapipe::ProtoString, tf::Tensor>> input_tensors; | ||||
| 
 | ||||
|     for (auto& keyed_tensors : inference_state->input_tensor_batches_) { | ||||
|  | @ -544,7 +544,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { | |||
|           get_session_run_throttle(options_.max_concurrent_session_runs()); | ||||
|       session_run_throttle->Acquire(1); | ||||
|     } | ||||
|     const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     const int64_t run_start_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     tf::Status tf_status; | ||||
|     { | ||||
| #if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) | ||||
|  | @ -562,7 +562,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { | |||
|     // informative error message.
 | ||||
|     RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); | ||||
| 
 | ||||
|     const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     const int64_t run_end_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) | ||||
|         ->IncrementBy(run_end_time - run_start_time); | ||||
|     cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); | ||||
|  | @ -611,7 +611,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { | |||
|     } | ||||
| 
 | ||||
|     // Get end time and report.
 | ||||
|     const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     const int64_t end_time = absl::ToUnixMicros(clock_->TimeNow()); | ||||
|     cc->GetCounter(kTotalUsecsCounterSuffix) | ||||
|         ->IncrementBy(end_time - start_time); | ||||
|     cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) | ||||
|  | @ -650,7 +650,7 @@ class TensorFlowInferenceCalculator : public CalculatorBase { | |||
| 
 | ||||
|   // The static singleton semaphore to throttle concurrent session runs.
 | ||||
|   static SimpleSemaphore* get_session_run_throttle( | ||||
|       int32 max_concurrent_session_runs) { | ||||
|       int32_t max_concurrent_session_runs) { | ||||
|     static SimpleSemaphore* session_run_throttle = | ||||
|         new SimpleSemaphore(max_concurrent_session_runs); | ||||
|     return session_run_throttle; | ||||
|  |  | |||
|  | @ -197,15 +197,15 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { | |||
|     // timestamp and the associated feature. This information is used in process
 | ||||
|     // to output batches of packets in order.
 | ||||
|     timestamps_.clear(); | ||||
|     int64 last_timestamp_seen = Timestamp::PreStream().Value(); | ||||
|     int64_t last_timestamp_seen = Timestamp::PreStream().Value(); | ||||
|     first_timestamp_seen_ = Timestamp::OneOverPostStream().Value(); | ||||
|     for (const auto& map_kv : sequence_->feature_lists().feature_list()) { | ||||
|       if (absl::StrContains(map_kv.first, "/timestamp")) { | ||||
|         LOG(INFO) << "Found feature timestamps: " << map_kv.first | ||||
|                   << " with size: " << map_kv.second.feature_size(); | ||||
|         int64 recent_timestamp = Timestamp::PreStream().Value(); | ||||
|         int64_t recent_timestamp = Timestamp::PreStream().Value(); | ||||
|         for (int i = 0; i < map_kv.second.feature_size(); ++i) { | ||||
|           int64 next_timestamp = | ||||
|           int64_t next_timestamp = | ||||
|               mpms::GetInt64sAt(*sequence_, map_kv.first, i).Get(0); | ||||
|           RET_CHECK_GT(next_timestamp, recent_timestamp) | ||||
|               << "Timestamps must be sequential. If you're seeing this message " | ||||
|  | @ -361,8 +361,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { | |||
|     // any particular call to Process(). At the every end, we output the
 | ||||
|     // poststream packets. If we only have poststream packets,
 | ||||
|     // last_timestamp_key_ will be empty.
 | ||||
|     int64 start_timestamp = 0; | ||||
|     int64 end_timestamp = 0; | ||||
|     int64_t start_timestamp = 0; | ||||
|     int64_t end_timestamp = 0; | ||||
|     if (last_timestamp_key_.empty() || process_poststream_) { | ||||
|       process_poststream_ = true; | ||||
|       start_timestamp = Timestamp::PostStream().Value(); | ||||
|  | @ -481,14 +481,14 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { | |||
|   // Store a map from the keys for each stream to the timestamps for each
 | ||||
|   // key. This allows us to identify which packets to output for each stream
 | ||||
|   // for timestamps within a given time window.
 | ||||
|   std::map<std::string, std::vector<int64>> timestamps_; | ||||
|   std::map<std::string, std::vector<int64_t>> timestamps_; | ||||
|   // Store the stream with the latest timestamp in the SequenceExample.
 | ||||
|   std::string last_timestamp_key_; | ||||
|   // Store the index of the current timestamp. Will be less than
 | ||||
|   // timestamps_[last_timestamp_key_].size().
 | ||||
|   int current_timestamp_index_; | ||||
|   // Store the very first timestamp, so we output everything on the first frame.
 | ||||
|   int64 first_timestamp_seen_; | ||||
|   int64_t first_timestamp_seen_; | ||||
|   // List of keypoint names.
 | ||||
|   std::vector<std::string> keypoint_names_; | ||||
|   // Default keypoint location when missing.
 | ||||
|  |  | |||
|  | @ -54,7 +54,7 @@ class VectorToTensorFloatCalculatorTest : public ::testing::Test { | |||
|       } | ||||
|     } | ||||
| 
 | ||||
|     const int64 time = 1234; | ||||
|     const int64_t time = 1234; | ||||
|     runner_->MutableInputs()->Index(0).packets.push_back( | ||||
|         Adopt(input.release()).At(Timestamp(time))); | ||||
| 
 | ||||
|  | @ -91,7 +91,7 @@ TEST_F(VectorToTensorFloatCalculatorTest, ConvertsFromVectorFloat) { | |||
|     // 2^i can be represented exactly in floating point numbers if 'i' is small.
 | ||||
|     input->at(i) = static_cast<float>(1 << i); | ||||
|   } | ||||
|   const int64 time = 1234; | ||||
|   const int64_t time = 1234; | ||||
|   runner_->MutableInputs()->Index(0).packets.push_back( | ||||
|       Adopt(input.release()).At(Timestamp(time))); | ||||
| 
 | ||||
|  |  | |||
|  | @ -28,11 +28,8 @@ | |||
| #include "mediapipe/framework/port/ret_check.h" | ||||
| #include "mediapipe/framework/port/status.h" | ||||
| 
 | ||||
| using mediapipe::Adopt; | ||||
| using mediapipe::CalculatorBase; | ||||
| using mediapipe::ImageFrame; | ||||
| using mediapipe::PacketTypeSet; | ||||
| using mediapipe::autoflip::Border; | ||||
| 
 | ||||
| constexpr char kDetectedBorders[] = "DETECTED_BORDERS"; | ||||
| constexpr int kMinBorderDistance = 5; | ||||
|  |  | |||
|  | @ -28,16 +28,12 @@ | |||
| #include "mediapipe/framework/port/status.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| 
 | ||||
| using mediapipe::Adopt; | ||||
| using mediapipe::CalculatorGraphConfig; | ||||
| using mediapipe::CalculatorRunner; | ||||
| using mediapipe::ImageFormat; | ||||
| using mediapipe::ImageFrame; | ||||
| using mediapipe::Packet; | ||||
| using mediapipe::PacketTypeSet; | ||||
| using mediapipe::ParseTextProtoOrDie; | ||||
| using mediapipe::Timestamp; | ||||
| using mediapipe::autoflip::Border; | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace autoflip { | ||||
|  |  | |||
|  | @ -31,14 +31,11 @@ | |||
| #include "mediapipe/framework/port/status.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| 
 | ||||
| using mediapipe::Adopt; | ||||
| using mediapipe::CalculatorGraphConfig; | ||||
| using mediapipe::CalculatorRunner; | ||||
| using mediapipe::ImageFormat; | ||||
| using mediapipe::ImageFrame; | ||||
| using mediapipe::PacketTypeSet; | ||||
| using mediapipe::ParseTextProtoOrDie; | ||||
| using mediapipe::Timestamp; | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace autoflip { | ||||
|  |  | |||
|  | @ -28,8 +28,6 @@ | |||
| using mediapipe::Packet; | ||||
| using mediapipe::PacketTypeSet; | ||||
| using mediapipe::autoflip::DetectionSet; | ||||
| using mediapipe::autoflip::SalientRegion; | ||||
| using mediapipe::autoflip::SignalType; | ||||
| 
 | ||||
| constexpr char kIsShotBoundaryTag[] = "IS_SHOT_BOUNDARY"; | ||||
| constexpr char kSignalInputsTag[] = "SIGNAL"; | ||||
|  |  | |||
|  | @ -19,8 +19,6 @@ namespace mediapipe { | |||
| namespace api2 { | ||||
| namespace test { | ||||
| 
 | ||||
| using testing::ElementsAre; | ||||
| 
 | ||||
| // Returns the packet values for a vector of Packets.
 | ||||
| template <typename T> | ||||
| std::vector<T> PacketValues(const std::vector<mediapipe::Packet>& packets) { | ||||
|  |  | |||
|  | @ -310,7 +310,7 @@ class Scheduler { | |||
|   absl::Mutex state_mutex_; | ||||
| 
 | ||||
|   // Current state of the scheduler.
 | ||||
|   std::atomic<State> state_ = ATOMIC_VAR_INIT(STATE_NOT_STARTED); | ||||
|   std::atomic<State> state_ = STATE_NOT_STARTED; | ||||
| 
 | ||||
|   // True if all graph input streams are closed.
 | ||||
|   bool graph_input_streams_closed_ ABSL_GUARDED_BY(state_mutex_) = false; | ||||
|  |  | |||
|  | @ -131,7 +131,7 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler { | |||
|       ABSL_EXCLUSIVE_LOCKS_REQUIRED(erase_mutex_) { | ||||
|     // Record the most recent first kept timestamp on any stream.
 | ||||
|     for (const auto& stream : input_stream_managers_) { | ||||
|       int32 queue_size = (stream->QueueSize() >= trigger_queue_size_) | ||||
|       int32_t queue_size = (stream->QueueSize() >= trigger_queue_size_) | ||||
|                                ? target_queue_size_ | ||||
|                                : trigger_queue_size_ - 1; | ||||
|       if (stream->QueueSize() > queue_size) { | ||||
|  | @ -214,8 +214,8 @@ class FixedSizeInputStreamHandler : public DefaultInputStreamHandler { | |||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   int32 trigger_queue_size_; | ||||
|   int32 target_queue_size_; | ||||
|   int32_t trigger_queue_size_; | ||||
|   int32_t target_queue_size_; | ||||
|   bool fixed_min_size_; | ||||
|   // Indicates that GetNodeReadiness has returned kReadyForProcess once, and
 | ||||
|   // the corresponding call to FillInputSet has not yet completed.
 | ||||
|  |  | |||
|  | @ -30,15 +30,15 @@ namespace mediapipe { | |||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| const int64 kMaxPacketId = 100; | ||||
| const int64 kSlowCalculatorRate = 10; | ||||
| const int64_t kMaxPacketId = 100; | ||||
| const int64_t kSlowCalculatorRate = 10; | ||||
| 
 | ||||
| // Rate limiter for TestSlowCalculator.
 | ||||
| ABSL_CONST_INIT absl::Mutex g_source_mutex(absl::kConstInit); | ||||
| int64 g_source_counter ABSL_GUARDED_BY(g_source_mutex); | ||||
| int64_t g_source_counter ABSL_GUARDED_BY(g_source_mutex); | ||||
| 
 | ||||
| // Rate limiter for TestSourceCalculator.
 | ||||
| int64 g_slow_counter ABSL_GUARDED_BY(g_source_mutex); | ||||
| int64_t g_slow_counter ABSL_GUARDED_BY(g_source_mutex); | ||||
| 
 | ||||
| // Flag that indicates that the source is done.
 | ||||
| bool g_source_done ABSL_GUARDED_BY(g_source_mutex); | ||||
|  | @ -47,7 +47,7 @@ class TestSourceCalculator : public CalculatorBase { | |||
|  public: | ||||
|   TestSourceCalculator() : current_packet_id_(0) {} | ||||
|   static absl::Status GetContract(CalculatorContract* cc) { | ||||
|     cc->Outputs().Index(0).Set<int64>(); | ||||
|     cc->Outputs().Index(0).Set<int64_t>(); | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
|   absl::Status Open(CalculatorContext* cc) override { | ||||
|  | @ -62,7 +62,7 @@ class TestSourceCalculator : public CalculatorBase { | |||
|       g_source_done = true; | ||||
|       return tool::StatusStop(); | ||||
|     } | ||||
|     cc->Outputs().Index(0).Add(new int64(0), Timestamp(current_packet_id_)); | ||||
|     cc->Outputs().Index(0).Add(new int64_t(0), Timestamp(current_packet_id_)); | ||||
|     ++current_packet_id_; | ||||
|     { | ||||
|       absl::MutexLock lock(&g_source_mutex); | ||||
|  | @ -78,7 +78,7 @@ class TestSourceCalculator : public CalculatorBase { | |||
|     return g_source_counter <= kSlowCalculatorRate * g_slow_counter || | ||||
|            g_source_counter <= 1; | ||||
|   } | ||||
|   int64 current_packet_id_; | ||||
|   int64_t current_packet_id_; | ||||
| }; | ||||
| 
 | ||||
| REGISTER_CALCULATOR(TestSourceCalculator); | ||||
|  | @ -87,8 +87,8 @@ class TestSlowCalculator : public CalculatorBase { | |||
|  public: | ||||
|   TestSlowCalculator() = default; | ||||
|   static absl::Status GetContract(CalculatorContract* cc) { | ||||
|     cc->Inputs().Index(0).Set<int64>(); | ||||
|     cc->Outputs().Index(0).Set<int64>(); | ||||
|     cc->Inputs().Index(0).Set<int64_t>(); | ||||
|     cc->Outputs().Index(0).Set<int64_t>(); | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
|   absl::Status Open(CalculatorContext* cc) override { | ||||
|  | @ -97,7 +97,7 @@ class TestSlowCalculator : public CalculatorBase { | |||
|     return absl::OkStatus(); | ||||
|   } | ||||
|   absl::Status Process(CalculatorContext* cc) override { | ||||
|     cc->Outputs().Index(0).Add(new int64(0), | ||||
|     cc->Outputs().Index(0).Add(new int64_t(0), | ||||
|                                cc->Inputs().Index(0).Value().Timestamp()); | ||||
|     { | ||||
|       absl::MutexLock lock(&g_source_mutex); | ||||
|  | @ -118,8 +118,9 @@ class TestSlowCalculator : public CalculatorBase { | |||
| REGISTER_CALCULATOR(TestSlowCalculator); | ||||
| 
 | ||||
| // Return the values of the timestamps of a vector of Packets.
 | ||||
| static std::vector<int64> TimestampValues(const std::vector<Packet>& packets) { | ||||
|   std::vector<int64> result; | ||||
| static std::vector<int64_t> TimestampValues( | ||||
|     const std::vector<Packet>& packets) { | ||||
|   std::vector<int64_t> result; | ||||
|   for (const Packet& p : packets) { | ||||
|     result.push_back(p.Timestamp().Value()); | ||||
|   } | ||||
|  | @ -174,7 +175,7 @@ TEST_P(FixedSizeInputStreamHandlerTest, DropsPackets) { | |||
|   // consumed.  In this way, the TestSlowCalculator consumes and outputs only
 | ||||
|   // every tenth packet.
 | ||||
|   EXPECT_EQ(output_packets.size(), 11); | ||||
|   std::vector<int64> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99}; | ||||
|   std::vector<int64_t> expected_ts = {0, 9, 19, 29, 39, 49, 59, 69, 79, 89, 99}; | ||||
|   EXPECT_THAT(TimestampValues(output_packets), | ||||
|               testing::ContainerEq(expected_ts)); | ||||
| } | ||||
|  | @ -344,18 +345,18 @@ TEST_P(FixedSizeInputStreamHandlerTest, LateArrivalDrop) { | |||
| 
 | ||||
|   if (GetParam()) { | ||||
|     EXPECT_THAT(TimestampValues(output_packets[0]), | ||||
|                 testing::ContainerEq(std::vector<int64>{1, 2, 3, 4, 5, 6})); | ||||
|                 testing::ContainerEq(std::vector<int64_t>{1, 2, 3, 4, 5, 6})); | ||||
|     EXPECT_THAT(TimestampValues(output_packets[1]), | ||||
|                 testing::ContainerEq(std::vector<int64>{3, 4, 5, 6, 7})); | ||||
|                 testing::ContainerEq(std::vector<int64_t>{3, 4, 5, 6, 7})); | ||||
|     EXPECT_THAT(TimestampValues(output_packets[2]), | ||||
|                 testing::ContainerEq(std::vector<int64>{4, 5, 6, 7})); | ||||
|                 testing::ContainerEq(std::vector<int64_t>{4, 5, 6, 7})); | ||||
|   } else { | ||||
|     EXPECT_THAT(TimestampValues(output_packets[0]), | ||||
|                 testing::ContainerEq(std::vector<int64>{5, 6})); | ||||
|                 testing::ContainerEq(std::vector<int64_t>{5, 6})); | ||||
|     EXPECT_THAT(TimestampValues(output_packets[1]), | ||||
|                 testing::ContainerEq(std::vector<int64>{5, 6, 7})); | ||||
|                 testing::ContainerEq(std::vector<int64_t>{5, 6, 7})); | ||||
|     EXPECT_THAT(TimestampValues(output_packets[2]), | ||||
|                 testing::ContainerEq(std::vector<int64>{5, 6, 7})); | ||||
|                 testing::ContainerEq(std::vector<int64_t>{5, 6, 7})); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -27,10 +27,6 @@ namespace options_field_util { | |||
| 
 | ||||
| using ::mediapipe::proto_ns::internal::WireFormatLite; | ||||
| using FieldType = WireFormatLite::FieldType; | ||||
| using ::mediapipe::proto_ns::io::ArrayInputStream; | ||||
| using ::mediapipe::proto_ns::io::CodedInputStream; | ||||
| using ::mediapipe::proto_ns::io::CodedOutputStream; | ||||
| using ::mediapipe::proto_ns::io::StringOutputStream; | ||||
| 
 | ||||
| // Utility functions for OptionsFieldUtil.
 | ||||
| namespace { | ||||
|  |  | |||
|  | @ -454,8 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> { | |||
|   // Number of glFinish calls completed on the GL thread.
 | ||||
|   // Changes should be guarded by mutex_. However, we use simple atomic
 | ||||
|   // loads for efficiency on the fast path.
 | ||||
|   std::atomic<int64_t> gl_finish_count_ = ATOMIC_VAR_INIT(0); | ||||
|   std::atomic<int64_t> gl_finish_count_target_ = ATOMIC_VAR_INIT(0); | ||||
|   std::atomic<int64_t> gl_finish_count_ = 0; | ||||
|   std::atomic<int64_t> gl_finish_count_target_ = 0; | ||||
| 
 | ||||
|   GlContext* context_waiting_on_ ABSL_GUARDED_BY(mutex_) = nullptr; | ||||
| 
 | ||||
|  |  | |||
|  | @ -67,53 +67,14 @@ absl::Status GlContext::CreateContextInternal( | |||
|   // TODO: Investigate this option in more detail, esp. on Safari.
 | ||||
|   attrs.preserveDrawingBuffer = 0; | ||||
| 
 | ||||
|   // Since the Emscripten canvas target finding function is visible from here,
 | ||||
|   // we hijack findCanvasEventTarget directly for enforcing old Module.canvas
 | ||||
|   // behavior if the user desires, falling back to the new DOM element CSS
 | ||||
|   // selector behavior next if that is specified, and finally just allowing the
 | ||||
|   // lookup to proceed on a null target.
 | ||||
|   // TODO: Ensure this works with all options (in particular,
 | ||||
|   //   multithreading options, like the special-case combination of USE_PTHREADS
 | ||||
|   //   and OFFSCREEN_FRAMEBUFFER)
 | ||||
|   // clang-format off
 | ||||
|   EM_ASM( | ||||
|     let init_once = true; | ||||
|     if (init_once) { | ||||
|       const cachedFindCanvasEventTarget = findCanvasEventTarget; | ||||
| 
 | ||||
|       if (typeof cachedFindCanvasEventTarget !== 'function') { | ||||
|         if (typeof console !== 'undefined') { | ||||
|           console.error('Expected Emscripten global function ' | ||||
|               + '"findCanvasEventTarget" not found. WebGL context creation ' | ||||
|               + 'may fail.'); | ||||
|         } | ||||
|         return; | ||||
|       } | ||||
| 
 | ||||
|       findCanvasEventTarget = function(target) { | ||||
|         if (target == 0) { | ||||
|           if (Module && Module.canvas) { | ||||
|             return Module.canvas; | ||||
|           } else if (Module && Module.canvasCssSelector) { | ||||
|             return cachedFindCanvasEventTarget(Module.canvasCssSelector); | ||||
|           } | ||||
|           if (typeof console !== 'undefined') { | ||||
|             console.warn('Module properties canvas and canvasCssSelector not ' + | ||||
|                          'found during WebGL context creation.'); | ||||
|           } | ||||
|         } | ||||
|         // We still go through with the find attempt, although for most use
 | ||||
|         // cases it will not succeed, just in case the user does want to fall-
 | ||||
|         // back.
 | ||||
|         return cachedFindCanvasEventTarget(target); | ||||
|       };  // NOLINT: Necessary semicolon.
 | ||||
|       init_once = false; | ||||
|     } | ||||
|   ); | ||||
|   // clang-format on
 | ||||
| 
 | ||||
|   // Quick patch for -s DISABLE_DEPRECATED_FIND_EVENT_TARGET_BEHAVIOR so it also
 | ||||
|   // looks for our #canvas target in Module.canvas, where we expect it to be.
 | ||||
|   // -s OFFSCREENCANVAS_SUPPORT=1 will no longer work with this under the new
 | ||||
|   // event target behavior, but it was never supposed to be tapping into our
 | ||||
|   // canvas anyways. See b/278155946 for more background.
 | ||||
|   EM_ASM({ specialHTMLTargets["#canvas"] = Module.canvas; }); | ||||
|   EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_handle = | ||||
|       emscripten_webgl_create_context(nullptr, &attrs); | ||||
|       emscripten_webgl_create_context("#canvas", &attrs); | ||||
| 
 | ||||
|   // Check for failure
 | ||||
|   if (context_handle <= 0) { | ||||
|  |  | |||
|  | @ -64,7 +64,7 @@ std::unique_ptr<GlTextureBuffer> GlTextureBuffer::Create( | |||
|   int actual_ws = image_frame.WidthStep(); | ||||
|   int alignment = 0; | ||||
|   std::unique_ptr<ImageFrame> temp; | ||||
|   const uint8* data = image_frame.PixelData(); | ||||
|   const uint8_t* data = image_frame.PixelData(); | ||||
| 
 | ||||
|   // Let's see if the pixel data is tightly aligned to one of the alignments
 | ||||
|   // supported by OpenGL, preferring 4 if possible since it's the default.
 | ||||
|  |  | |||
|  | @ -167,7 +167,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, | |||
|                                                    GpuBufferFormat format) { | ||||
|   libyuv::FourCC fourcc = FourCCForGpuBufferFormat(format); | ||||
|   int y_stride = std::ceil(1.0f * width / kDefaultDataAligment); | ||||
|   auto y_data = std::make_unique<uint8[]>(y_stride * height); | ||||
|   auto y_data = std::make_unique<uint8_t[]>(y_stride * height); | ||||
|   switch (fourcc) { | ||||
|     case libyuv::FOURCC_NV12: | ||||
|     case libyuv::FOURCC_NV21: { | ||||
|  | @ -175,7 +175,7 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, | |||
|       int uv_width = 2 * std::ceil(0.5f * width); | ||||
|       int uv_height = std::ceil(0.5f * height); | ||||
|       int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); | ||||
|       auto uv_data = std::make_unique<uint8[]>(uv_stride * uv_height); | ||||
|       auto uv_data = std::make_unique<uint8_t[]>(uv_stride * uv_height); | ||||
|       yuv_image_ = std::make_shared<YUVImage>( | ||||
|           fourcc, std::move(y_data), y_stride, std::move(uv_data), uv_stride, | ||||
|           nullptr, 0, width, height); | ||||
|  | @ -187,8 +187,8 @@ GpuBufferStorageYuvImage::GpuBufferStorageYuvImage(int width, int height, | |||
|       int uv_width = std::ceil(0.5f * width); | ||||
|       int uv_height = std::ceil(0.5f * height); | ||||
|       int uv_stride = std::ceil(1.0f * uv_width / kDefaultDataAligment); | ||||
|       auto u_data = std::make_unique<uint8[]>(uv_stride * uv_height); | ||||
|       auto v_data = std::make_unique<uint8[]>(uv_stride * uv_height); | ||||
|       auto u_data = std::make_unique<uint8_t[]>(uv_stride * uv_height); | ||||
|       auto v_data = std::make_unique<uint8_t[]>(uv_stride * uv_height); | ||||
|       yuv_image_ = std::make_shared<YUVImage>( | ||||
|           fourcc, std::move(y_data), y_stride, std::move(u_data), uv_stride, | ||||
|           std::move(v_data), uv_stride, width, height); | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ import csv | |||
| import filecmp | ||||
| import os | ||||
| import tempfile | ||||
| import unittest | ||||
| from unittest import mock as unittest_mock | ||||
| 
 | ||||
| import tensorflow as tf | ||||
|  | @ -24,6 +25,7 @@ from mediapipe.model_maker.python.text import text_classifier | |||
| from mediapipe.tasks.python.test import test_utils | ||||
| 
 | ||||
| 
 | ||||
| @unittest.skip('b/275624089') | ||||
| class TextClassifierTest(tf.test.TestCase): | ||||
| 
 | ||||
|   _AVERAGE_WORD_EMBEDDING_JSON_FILE = ( | ||||
|  |  | |||
|  | @ -175,11 +175,7 @@ py_test( | |||
|     data = [":testdata"], | ||||
|     tags = ["requires-net:external"], | ||||
|     deps = [ | ||||
|         ":dataset", | ||||
|         ":hyperparameters", | ||||
|         ":model_spec", | ||||
|         ":object_detector", | ||||
|         ":object_detector_options", | ||||
|         ":object_detector_import", | ||||
|         "//mediapipe/tasks/python/test:test_utils", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -19,11 +19,7 @@ from unittest import mock as unittest_mock | |||
| from absl.testing import parameterized | ||||
| import tensorflow as tf | ||||
| 
 | ||||
| from mediapipe.model_maker.python.vision.object_detector import dataset | ||||
| from mediapipe.model_maker.python.vision.object_detector import hyperparameters | ||||
| from mediapipe.model_maker.python.vision.object_detector import model_spec as ms | ||||
| from mediapipe.model_maker.python.vision.object_detector import object_detector | ||||
| from mediapipe.model_maker.python.vision.object_detector import object_detector_options | ||||
| from mediapipe.model_maker.python.vision import object_detector | ||||
| from mediapipe.tasks.python.test import test_utils as task_test_utils | ||||
| 
 | ||||
| 
 | ||||
|  | @ -33,7 +29,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): | |||
|     super().setUp() | ||||
|     dataset_folder = task_test_utils.get_test_data_path('coco_data') | ||||
|     cache_dir = self.create_tempdir() | ||||
|     self.data = dataset.Dataset.from_coco_folder( | ||||
|     self.data = object_detector.Dataset.from_coco_folder( | ||||
|         dataset_folder, cache_dir=cache_dir | ||||
|     ) | ||||
|     # Mock tempfile.gettempdir() to be unique for each test to avoid race | ||||
|  | @ -48,15 +44,16 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): | |||
|     self.addCleanup(mock_gettempdir.stop) | ||||
| 
 | ||||
|   def test_object_detector(self): | ||||
|     hparams = hyperparameters.HParams( | ||||
|     hparams = object_detector.HParams( | ||||
|         epochs=1, | ||||
|         batch_size=2, | ||||
|         learning_rate=0.9, | ||||
|         shuffle=False, | ||||
|         export_dir=self.create_tempdir(), | ||||
|     ) | ||||
|     options = object_detector_options.ObjectDetectorOptions( | ||||
|         supported_model=ms.SupportedModels.MOBILENET_V2, hparams=hparams | ||||
|     options = object_detector.ObjectDetectorOptions( | ||||
|         supported_model=object_detector.SupportedModels.MOBILENET_V2, | ||||
|         hparams=hparams, | ||||
|     ) | ||||
|     # Test `create`` | ||||
|     model = object_detector.ObjectDetector.create( | ||||
|  | @ -79,7 +76,7 @@ class ObjectDetectorTest(tf.test.TestCase, parameterized.TestCase): | |||
|     self.assertGreater(os.path.getsize(output_metadata_file), 0) | ||||
| 
 | ||||
|     # Test `quantization_aware_training` | ||||
|     qat_hparams = hyperparameters.QATHParams( | ||||
|     qat_hparams = object_detector.QATHParams( | ||||
|         learning_rate=0.9, | ||||
|         batch_size=2, | ||||
|         epochs=1, | ||||
|  |  | |||
|  | @ -298,6 +298,7 @@ cc_library( | |||
|         ":tensors_to_objects_calculator_cc_proto", | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework/deps:file_path", | ||||
|         "//mediapipe/framework/formats:tensor", | ||||
|         "//mediapipe/framework/port:opencv_core", | ||||
|         "//mediapipe/framework/port:ret_check", | ||||
|         "@com_google_absl//absl/memory", | ||||
|  |  | |||
|  | @ -91,8 +91,8 @@ absl::Status FrameAnnotationToTimedBoxListCalculator::Process( | |||
|       TimedBoxProto* added_box = output_objects->add_box(); | ||||
|       ComputeBoundingRect(key_points, added_box); | ||||
|       added_box->set_id(annotation.object_id()); | ||||
|       const int64 time_msec = | ||||
|           static_cast<int64>(std::round(frame_annotation.timestamp() / 1000)); | ||||
|       const int64_t time_msec = | ||||
|           static_cast<int64_t>(std::round(frame_annotation.timestamp() / 1000)); | ||||
|       added_box->set_time_msec(time_msec); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -24,8 +24,8 @@ namespace mediapipe { | |||
| 
 | ||||
| void FrameAnnotationTracker::AddDetectionResult( | ||||
|     const FrameAnnotation& frame_annotation) { | ||||
|   const int64 time_us = | ||||
|       static_cast<int64>(std::round(frame_annotation.timestamp())); | ||||
|   const int64_t time_us = | ||||
|       static_cast<int64_t>(std::round(frame_annotation.timestamp())); | ||||
|   for (const auto& object_annotation : frame_annotation.annotations()) { | ||||
|     detected_objects_[time_us + object_annotation.object_id()] = | ||||
|         object_annotation; | ||||
|  | @ -37,7 +37,7 @@ FrameAnnotation FrameAnnotationTracker::ConsolidateTrackingResult( | |||
|     absl::flat_hash_set<int>* cancel_object_ids) { | ||||
|   CHECK(cancel_object_ids != nullptr); | ||||
|   FrameAnnotation frame_annotation; | ||||
|   std::vector<int64> keys_to_be_deleted; | ||||
|   std::vector<int64_t> keys_to_be_deleted; | ||||
|   for (const auto& detected_obj : detected_objects_) { | ||||
|     const int object_id = detected_obj.second.object_id(); | ||||
|     if (cancel_object_ids->contains(object_id)) { | ||||
|  |  | |||
|  | @ -76,7 +76,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase { | |||
|   // In a single MediaPipe session, the IDs are unique.
 | ||||
|   // Also assign timestamp for the FrameAnnotation to be the input packet
 | ||||
|   // timestamp.
 | ||||
|   void AssignObjectIdAndTimestamp(int64 timestamp_us, | ||||
|   void AssignObjectIdAndTimestamp(int64_t timestamp_us, | ||||
|                                   FrameAnnotation* annotation); | ||||
| 
 | ||||
|   int num_classes_ = 0; | ||||
|  | @ -207,7 +207,7 @@ void TfLiteTensorsToObjectsCalculator::Project3DTo2D( | |||
| } | ||||
| 
 | ||||
| void TfLiteTensorsToObjectsCalculator::AssignObjectIdAndTimestamp( | ||||
|     int64 timestamp_us, FrameAnnotation* annotation) { | ||||
|     int64_t timestamp_us, FrameAnnotation* annotation) { | ||||
|   for (auto& ann : *annotation->mutable_annotations()) { | ||||
|     ann.set_object_id(GetNextObjectId()); | ||||
|   } | ||||
|  |  | |||
|  | @ -37,7 +37,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/common.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/category.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -157,7 +157,7 @@ void CheckStreamingModeResults(std::vector<AudioClassifierResult> outputs) { | |||
|   } | ||||
| } | ||||
| 
 | ||||
| class CreateFromOptionsTest : public tflite_shims::testing::Test {}; | ||||
| class CreateFromOptionsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { | ||||
|   auto options = std::make_unique<AudioClassifierOptions>(); | ||||
|  | @ -270,7 +270,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { | |||
|                   MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); | ||||
| } | ||||
| 
 | ||||
| class ClassifyTest : public tflite_shims::testing::Test {}; | ||||
| class ClassifyTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ClassifyTest, Succeeds) { | ||||
|   auto audio_buffer = GetAudioData(k16kTestWavFilename); | ||||
|  | @ -467,7 +467,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { | |||
|   } | ||||
| } | ||||
| 
 | ||||
| class ClassifyAsyncTest : public tflite_shims::testing::Test {}; | ||||
| class ClassifyAsyncTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ClassifyAsyncTest, Succeeds) { | ||||
|   constexpr int kSampleRateHz = 48000; | ||||
|  |  | |||
|  | @ -36,7 +36,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/audio/utils/test_utils.h" | ||||
| #include "mediapipe/tasks/cc/common.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/embedding_result.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -66,7 +66,7 @@ Matrix GetAudioData(absl::string_view filename) { | |||
|   return matrix_mapping.matrix(); | ||||
| } | ||||
| 
 | ||||
| class CreateFromOptionsTest : public tflite_shims::testing::Test {}; | ||||
| class CreateFromOptionsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { | ||||
|   auto audio_embedder = | ||||
|  | @ -124,7 +124,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallbackInAudioStreamMode) { | |||
|                   MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); | ||||
| } | ||||
| 
 | ||||
| class EmbedTest : public tflite_shims::testing::Test {}; | ||||
| class EmbedTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(EmbedTest, SucceedsWithSilentAudio) { | ||||
|   auto options = std::make_unique<AudioEmbedderOptions>(); | ||||
|  | @ -187,7 +187,7 @@ TEST_F(EmbedTest, SucceedsWithDifferentAudios) { | |||
|   MP_EXPECT_OK(audio_embedder->Close()); | ||||
| } | ||||
| 
 | ||||
| class EmbedAsyncTest : public tflite_shims::testing::Test { | ||||
| class EmbedAsyncTest : public tflite::testing::Test { | ||||
|  protected: | ||||
|   void RunAudioEmbedderInStreamMode(std::string audio_file_name, | ||||
|                                     int sample_rate_hz, | ||||
|  |  | |||
|  | @ -47,7 +47,7 @@ cc_test_with_tflite( | |||
|     data = ["//mediapipe/tasks/testdata/audio:test_models"], | ||||
|     tflite_deps = [ | ||||
|         "//mediapipe/tasks/cc/core:model_resources", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":audio_tensor_specs", | ||||
|  |  | |||
|  | @ -34,7 +34,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -52,7 +52,7 @@ constexpr char kModelWithMetadata[] = | |||
|     "yamnet_audio_classifier_with_metadata.tflite"; | ||||
| constexpr char kModelWithoutMetadata[] = "model_without_metadata.tflite"; | ||||
| 
 | ||||
| class AudioTensorSpecsTest : public tflite_shims::testing::Test {}; | ||||
| class AudioTensorSpecsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(AudioTensorSpecsTest, | ||||
|        BuildInputAudioTensorSpecsWithoutMetdataOptionsFails) { | ||||
|  |  | |||
|  | @ -63,7 +63,7 @@ cc_test( | |||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@com_google_absl//absl/strings:str_format", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  | @ -232,6 +232,6 @@ cc_test( | |||
|         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -33,7 +33,7 @@ limitations under the License. | |||
| #include "mediapipe/framework/timestamp.h" | ||||
| #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace { | ||||
|  | @ -66,8 +66,7 @@ ClassificationList MakeClassificationList(int class_index) { | |||
|       class_index)); | ||||
| } | ||||
| 
 | ||||
| class ClassificationAggregationCalculatorTest | ||||
|     : public tflite_shims::testing::Test { | ||||
| class ClassificationAggregationCalculatorTest : public tflite::testing::Test { | ||||
|  protected: | ||||
|   absl::StatusOr<OutputStreamPoller> BuildGraph( | ||||
|       bool connect_timestamps = false) { | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ limitations under the License. | |||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/framework/timestamp.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace { | ||||
|  | @ -52,7 +52,7 @@ constexpr char kTimestampsName[] = "timestamps_in"; | |||
| constexpr char kTimestampedEmbeddingsTag[] = "TIMESTAMPED_EMBEDDINGS"; | ||||
| constexpr char kTimestampedEmbeddingsName[] = "timestamped_embeddings_out"; | ||||
| 
 | ||||
| class EmbeddingAggregationCalculatorTest : public tflite_shims::testing::Test { | ||||
| class EmbeddingAggregationCalculatorTest : public tflite::testing::Test { | ||||
|  protected: | ||||
|   absl::StatusOr<OutputStreamPoller> BuildGraph(bool connect_timestamps) { | ||||
|     Graph graph; | ||||
|  |  | |||
|  | @ -66,7 +66,7 @@ using ::mediapipe::tasks::core::ModelResources; | |||
| using ::mediapipe::tasks::metadata::ModelMetadataExtractor; | ||||
| using ::tflite::ProcessUnit; | ||||
| using ::tflite::TensorMetadata; | ||||
| using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; | ||||
| using LabelItems = mediapipe::proto_ns::Map<int64_t, ::mediapipe::LabelMapItem>; | ||||
| using TensorsSource = mediapipe::api2::builder::Source<std::vector<Tensor>>; | ||||
| 
 | ||||
| constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); | ||||
|  |  | |||
|  | @ -49,7 +49,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "mediapipe/util/label_map.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -101,7 +101,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel( | |||
|                                 std::move(external_file)); | ||||
| } | ||||
| 
 | ||||
| class ConfigureTest : public tflite_shims::testing::Test {}; | ||||
| class ConfigureTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ConfigureTest, FailsWithInvalidMaxResults) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -417,7 +417,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { | |||
|                                )pb"))); | ||||
| } | ||||
| 
 | ||||
| class PostprocessingTest : public tflite_shims::testing::Test { | ||||
| class PostprocessingTest : public tflite::testing::Test { | ||||
|  protected: | ||||
|   absl::StatusOr<OutputStreamPoller> BuildGraph( | ||||
|       absl::string_view model_name, const proto::ClassifierOptions& options, | ||||
|  | @ -520,7 +520,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { | |||
|       auto poller, | ||||
|       BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); | ||||
|   // Build input tensors.
 | ||||
|   std::vector<uint8> tensor(kMobileNetNumClasses, 0); | ||||
|   std::vector<uint8_t> tensor(kMobileNetNumClasses, 0); | ||||
|   tensor[1] = 18; | ||||
|   tensor[2] = 16; | ||||
| 
 | ||||
|  | @ -552,7 +552,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { | |||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); | ||||
|   // Build input tensors.
 | ||||
|   std::vector<uint8> tensor(kMobileNetNumClasses, 0); | ||||
|   std::vector<uint8_t> tensor(kMobileNetNumClasses, 0); | ||||
|   tensor[1] = 12; | ||||
|   tensor[2] = 14; | ||||
|   tensor[3] = 16; | ||||
|  | @ -589,7 +589,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { | |||
|       auto poller, | ||||
|       BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); | ||||
|   // Build input tensors.
 | ||||
|   std::vector<uint8> tensor(kMobileNetNumClasses, 0); | ||||
|   std::vector<uint8_t> tensor(kMobileNetNumClasses, 0); | ||||
|   tensor[1] = 12; | ||||
|   tensor[2] = 14; | ||||
|   tensor[3] = 16; | ||||
|  | @ -677,11 +677,11 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { | |||
|       auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, | ||||
|                               /*connect_timestamps=*/true)); | ||||
|   // Build input tensors.
 | ||||
|   std::vector<uint8> tensor_0(kMobileNetNumClasses, 0); | ||||
|   std::vector<uint8_t> tensor_0(kMobileNetNumClasses, 0); | ||||
|   tensor_0[1] = 12; | ||||
|   tensor_0[2] = 14; | ||||
|   tensor_0[3] = 16; | ||||
|   std::vector<uint8> tensor_1(kMobileNetNumClasses, 0); | ||||
|   std::vector<uint8_t> tensor_1(kMobileNetNumClasses, 0); | ||||
|   tensor_1[5] = 12; | ||||
|   tensor_1[6] = 14; | ||||
|   tensor_1[7] = 16; | ||||
|  |  | |||
|  | @ -39,7 +39,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/components/processors/proto/embedding_postprocessing_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -86,7 +86,7 @@ absl::StatusOr<std::unique_ptr<ModelResources>> CreateModelResourcesForModel( | |||
|                                 std::move(external_file)); | ||||
| } | ||||
| 
 | ||||
| class ConfigureTest : public tflite_shims::testing::Test {}; | ||||
| class ConfigureTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -153,7 +153,7 @@ TEST_F(ConfigureTest, SucceedsWithFloatModelWithMetadata) { | |||
|                    has_quantized_outputs: false)pb"))); | ||||
| } | ||||
| 
 | ||||
| class PostprocessingTest : public tflite_shims::testing::Test { | ||||
| class PostprocessingTest : public tflite::testing::Test { | ||||
|  protected: | ||||
|   absl::StatusOr<OutputStreamPoller> BuildGraph( | ||||
|       absl::string_view model_name, const proto::EmbedderOptions& options, | ||||
|  |  | |||
|  | @ -37,7 +37,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "mediapipe/tasks/cc/core/task_runner.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -125,7 +125,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner( | |||
|   return TaskRunner::Create(graph.GetConfig()); | ||||
| } | ||||
| 
 | ||||
| class ConfigureTest : public tflite_shims::testing::Test {}; | ||||
| class ConfigureTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ConfigureTest, SucceedsWithQuantizedModelWithMetadata) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  |  | |||
|  | @ -78,6 +78,7 @@ cc_library( | |||
|     hdrs = ["mediapipe_builtin_op_resolver.h"], | ||||
|     deps = [ | ||||
|         "//mediapipe/tasks/cc/text/custom_ops/ragged:ragged_tensor_to_tensor_tflite", | ||||
|         "//mediapipe/tasks/cc/text/custom_ops/sentencepiece:sentencepiece_tokenizer_tflite", | ||||
|         "//mediapipe/tasks/cc/text/language_detector/custom_ops:kmeans_embedding_lookup", | ||||
|         "//mediapipe/tasks/cc/text/language_detector/custom_ops:ngram_hash", | ||||
|         "//mediapipe/util/tflite/operations:landmarks_to_transform_matrix", | ||||
|  | @ -128,9 +129,9 @@ cc_library_with_tflite( | |||
|     srcs = ["model_resources.cc"], | ||||
|     hdrs = ["model_resources.h"], | ||||
|     tflite_deps = [ | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:verifier", | ||||
|         "@org_tensorflow//tensorflow/lite:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", | ||||
|         "@org_tensorflow//tensorflow/lite/tools:verifier", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":external_file_handler", | ||||
|  | @ -159,9 +160,9 @@ cc_test_with_tflite( | |||
|     ], | ||||
|     tflite_deps = [ | ||||
|         ":model_resources", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite:framework_stable", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":utils", | ||||
|  | @ -186,7 +187,7 @@ cc_library_with_tflite( | |||
|     hdrs = ["model_resources_cache.h"], | ||||
|     tflite_deps = [ | ||||
|         ":model_resources", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", | ||||
|         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":model_asset_bundle_resources", | ||||
|  | @ -233,7 +234,7 @@ cc_test_with_tflite( | |||
|         ":model_resources", | ||||
|         ":model_resources_cache", | ||||
|         ":model_resources_calculator", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
|     deps = [ | ||||
|         "//mediapipe/framework/port:gtest_main", | ||||
|  | @ -284,7 +285,7 @@ cc_test_with_tflite( | |||
|         ":task_runner", | ||||
|         ":model_resources", | ||||
|         ":model_resources_cache", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
|     deps = [ | ||||
|         "//mediapipe/calculators/core:pass_through_calculator", | ||||
|  | @ -317,6 +318,9 @@ cc_library( | |||
|         ":model_resources", | ||||
|         ":task_runner", | ||||
|         ":utils", | ||||
|         "//mediapipe/framework:calculator_cc_proto", | ||||
|         "//mediapipe/framework/port:requires", | ||||
|         "//mediapipe/framework/port:status", | ||||
|         "//mediapipe/tasks/cc:common", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", | ||||
|  |  | |||
|  | @ -16,6 +16,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/ragged/ragged_tensor_to_tensor_tflite.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/kmeans_embedding_lookup.h" | ||||
| #include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" | ||||
| #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" | ||||
|  | @ -51,6 +52,8 @@ MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() { | |||
|   AddCustom("KmeansEmbeddingLookup", | ||||
|             mediapipe::tflite_operations::Register_KmeansEmbeddingLookup()); | ||||
|   // For the UniversalSentenceEncoder model.
 | ||||
|   AddCustom("TFSentencepieceTokenizeOp", | ||||
|             mediapipe::tflite_operations::Register_SENTENCEPIECE_TOKENIZER()); | ||||
|   AddCustom("RaggedTensorToTensor", | ||||
|             mediapipe::tflite_operations::Register_RAGGED_TENSOR_TO_TENSOR()); | ||||
| } | ||||
|  |  | |||
|  | @ -37,8 +37,8 @@ limitations under the License. | |||
| #include "mediapipe/util/tflite/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/model_builder.h" | ||||
| #include "tensorflow/lite/core/shims/cc/tools/verifier.h" | ||||
| #include "tensorflow/lite/model_builder.h" | ||||
| #include "tensorflow/lite/tools/verifier.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -52,7 +52,7 @@ using ::mediapipe::tasks::metadata::ModelMetadataExtractor; | |||
| 
 | ||||
| bool ModelResources::Verifier::Verify(const char* data, int length, | ||||
|                                       tflite::ErrorReporter* reporter) { | ||||
|   return tflite_shims::Verify(data, length, reporter); | ||||
|   return tflite::Verify(data, length, reporter); | ||||
| } | ||||
| 
 | ||||
| ModelResources::ModelResources(const std::string& tag, | ||||
|  | @ -124,7 +124,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() { | |||
|   // and that it uses only operators that are supported by the OpResolver
 | ||||
|   // that was passed to the ModelResources constructor, and then builds
 | ||||
|   // the model from the buffer.
 | ||||
|   auto model = tflite_shims::FlatBufferModel::VerifyAndBuildFromBuffer( | ||||
|   auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( | ||||
|       buffer_data, buffer_size, &verifier_, &error_reporter_); | ||||
|   if (model == nullptr) { | ||||
|     static constexpr char kInvalidFlatbufferMessage[] = | ||||
|  | @ -151,8 +151,7 @@ absl::Status ModelResources::BuildModelFromExternalFileProto() { | |||
|   } | ||||
| 
 | ||||
|   model_packet_ = MakePacket<ModelPtr>( | ||||
|       model.release(), | ||||
|       [](tflite_shims::FlatBufferModel* model) { delete model; }); | ||||
|       model.release(), [](tflite::FlatBufferModel* model) { delete model; }); | ||||
|   ASSIGN_OR_RETURN(auto model_metadata_extractor, | ||||
|                    metadata::ModelMetadataExtractor::CreateFromModelBuffer( | ||||
|                        buffer_data, buffer_size)); | ||||
|  |  | |||
|  | @ -32,10 +32,10 @@ limitations under the License. | |||
| #include "mediapipe/util/tflite/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/error_reporter.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/kernels/register.h" | ||||
| #include "tensorflow/lite/core/shims/cc/model.h" | ||||
| #include "tensorflow/lite/core/shims/cc/model_builder.h" | ||||
| #include "tensorflow/lite/core/shims/cc/tools/verifier.h" | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| #include "tensorflow/lite/model.h" | ||||
| #include "tensorflow/lite/model_builder.h" | ||||
| #include "tensorflow/lite/tools/verifier.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -51,8 +51,8 @@ class ModelResources { | |||
|  public: | ||||
|   // Represents a TfLite model as a FlatBuffer.
 | ||||
|   using ModelPtr = | ||||
|       std::unique_ptr<tflite_shims::FlatBufferModel, | ||||
|                       std::function<void(tflite_shims::FlatBufferModel*)>>; | ||||
|       std::unique_ptr<tflite::FlatBufferModel, | ||||
|                       std::function<void(tflite::FlatBufferModel*)>>; | ||||
| 
 | ||||
|   // Takes the ownership of the provided ExternalFile proto and creates
 | ||||
|   // ModelResources from the proto and an op resolver object. A non-empty tag
 | ||||
|  | @ -61,7 +61,7 @@ class ModelResources { | |||
|   static absl::StatusOr<std::unique_ptr<ModelResources>> Create( | ||||
|       const std::string& tag, std::unique_ptr<proto::ExternalFile> model_file, | ||||
|       std::unique_ptr<tflite::OpResolver> op_resolver = | ||||
|           absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); | ||||
|           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); | ||||
| 
 | ||||
|   // Takes the ownership of the provided ExternalFile proto and creates
 | ||||
|   // ModelResources from the proto and an op resolver mediapipe packet. A
 | ||||
|  |  | |||
|  | @ -30,7 +30,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/proto/model_resources_calculator.pb.h" | ||||
| #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -124,7 +124,7 @@ void RunGraphWithGraphService(std::unique_ptr<ModelResources> model_resources, | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class ModelResourcesCalculatorTest : public tflite_shims::testing::Test {}; | ||||
| class ModelResourcesCalculatorTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ModelResourcesCalculatorTest, MissingCalculatorOptions) { | ||||
|   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>( | ||||
|  |  | |||
|  | @ -38,9 +38,9 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace ops { | ||||
|  | @ -116,7 +116,7 @@ void CheckModelResourcesPackets(const ModelResources* model_resources) { | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class ModelResourcesTest : public tflite_shims::testing::Test {}; | ||||
| class ModelResourcesTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ModelResourcesTest, CreateFromBinaryContent) { | ||||
|   auto model_file = std::make_unique<proto::ExternalFile>(); | ||||
|  | @ -211,7 +211,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsFromFile) { | |||
|   static constexpr char kCustomOpName[] = "MY_CUSTOM_OP"; | ||||
|   tflite::MutableOpResolver resolver; | ||||
|   resolver.AddBuiltin(::tflite::BuiltinOperator_ADD, | ||||
|                       ::tflite_shims::ops::builtin::Register_ADD()); | ||||
|                       ::tflite::ops::builtin::Register_ADD()); | ||||
|   resolver.AddCustom(kCustomOpName, | ||||
|                      ::tflite::ops::custom::Register_MY_CUSTOM_OP()); | ||||
| 
 | ||||
|  | @ -275,7 +275,7 @@ TEST_F(ModelResourcesTest, CreateSuccessWithCustomOpsPacket) { | |||
|   static constexpr char kCustomOpName[] = "MY_CUSTOM_OP"; | ||||
|   tflite::MutableOpResolver resolver; | ||||
|   resolver.AddBuiltin(::tflite::BuiltinOperator_ADD, | ||||
|                       ::tflite_shims::ops::builtin::Register_ADD()); | ||||
|                       ::tflite::ops::builtin::Register_ADD()); | ||||
|   resolver.AddCustom(kCustomOpName, | ||||
|                      ::tflite::ops::custom::Register_MY_CUSTOM_OP()); | ||||
| 
 | ||||
|  |  | |||
|  | @ -23,7 +23,11 @@ limitations under the License. | |||
| 
 | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "absl/strings/match.h" | ||||
| #include "absl/strings/str_cat.h" | ||||
| #include "mediapipe/framework/calculator.pb.h" | ||||
| #include "mediapipe/framework/port/requires.h" | ||||
| #include "mediapipe/framework/port/status_macros.h" | ||||
| #include "mediapipe/tasks/cc/common.h" | ||||
| #include "mediapipe/tasks/cc/core/base_task_api.h" | ||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
|  | @ -54,6 +58,8 @@ class TaskApiFactory { | |||
|       std::unique_ptr<tflite::OpResolver> resolver, | ||||
|       PacketsCallback packets_callback = nullptr) { | ||||
|     bool found_task_subgraph = false; | ||||
|     // This for-loop ensures there's only one subgraph besides
 | ||||
|     // FlowLimiterCalculator.
 | ||||
|     for (const auto& node : graph_config.node()) { | ||||
|       if (node.calculator() == "FlowLimiterCalculator") { | ||||
|         continue; | ||||
|  | @ -64,13 +70,7 @@ class TaskApiFactory { | |||
|             "Task graph config should only contain one task subgraph node.", | ||||
|             MediaPipeTasksStatus::kInvalidTaskGraphConfigError); | ||||
|       } else { | ||||
|         if (!node.options().HasExtension(Options::ext)) { | ||||
|           return CreateStatusWithPayload( | ||||
|               absl::StatusCode::kInvalidArgument, | ||||
|               absl::StrCat(node.calculator(), | ||||
|                            " is missing the required task options field."), | ||||
|               MediaPipeTasksStatus::kInvalidTaskGraphConfigError); | ||||
|         } | ||||
|         MP_RETURN_IF_ERROR(CheckHasValidOptions<Options>(node)); | ||||
|         found_task_subgraph = true; | ||||
|       } | ||||
|     } | ||||
|  | @ -80,6 +80,35 @@ class TaskApiFactory { | |||
|                                  std::move(packets_callback))); | ||||
|     return std::make_unique<T>(std::move(runner)); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   template <typename Options> | ||||
|   static absl::Status CheckHasValidOptions( | ||||
|       const CalculatorGraphConfig::Node& node) { | ||||
|     if constexpr (mediapipe::Requires<Options>( | ||||
|                       [](auto&& o) -> decltype(o.ext) {})) { | ||||
|       if (node.options().HasExtension(Options::ext)) { | ||||
|         return absl::OkStatus(); | ||||
|       } | ||||
|     } else { | ||||
| #ifndef MEDIAPIPE_PROTO_LITE | ||||
|       for (const auto& option : node.node_options()) { | ||||
|         if (absl::StrContains(option.type_url(), | ||||
|                               Options::descriptor()->full_name())) { | ||||
|           return absl::OkStatus(); | ||||
|         } | ||||
|       } | ||||
| #else   // MEDIAPIPE_PROTO_LITE
 | ||||
|       // Skip the check for proto lite, as Options::descriptor() is unavailable.
 | ||||
|       return absl::OkStatus(); | ||||
| #endif  // MEDIAPIPE_PROTO_LITE
 | ||||
|     } | ||||
|     return CreateStatusWithPayload( | ||||
|         absl::StatusCode::kInvalidArgument, | ||||
|         absl::StrCat(node.calculator(), | ||||
|                      " is missing the required task options field."), | ||||
|         MediaPipeTasksStatus::kInvalidTaskGraphConfigError); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| }  // namespace core
 | ||||
|  |  | |||
|  | @ -32,7 +32,7 @@ limitations under the License. | |||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -112,7 +112,7 @@ CalculatorGraphConfig GetModelSidePacketsToStreamPacketsGraphConfig( | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class TaskRunnerTest : public tflite_shims::testing::Test {}; | ||||
| class TaskRunnerTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(TaskRunnerTest, ConfigWithNoOutputStream) { | ||||
|   CalculatorGraphConfig proto = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||
|  |  | |||
							
								
								
									
										172
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										172
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/BUILD
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,172 @@ | |||
| # Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") | ||||
| 
 | ||||
| package(default_visibility = ["//mediapipe/tasks:internal"]) | ||||
| 
 | ||||
| licenses(["notice"]) | ||||
| 
 | ||||
| filegroup( | ||||
|     name = "testdata", | ||||
|     srcs = glob([ | ||||
|         "testdata/**", | ||||
|     ]), | ||||
| ) | ||||
| 
 | ||||
| filegroup( | ||||
|     name = "config_fbs", | ||||
|     srcs = ["config.fbs"], | ||||
| ) | ||||
| 
 | ||||
| flatbuffer_cc_library( | ||||
|     name = "config", | ||||
|     srcs = [ | ||||
|         "config.fbs", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| flatbuffer_cc_library( | ||||
|     name = "encoder_config", | ||||
|     srcs = [ | ||||
|         "encoder_config.fbs", | ||||
|     ], | ||||
|     includes = [":config_fbs"], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "utils", | ||||
|     hdrs = [ | ||||
|         "utils.h", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "double_array_trie", | ||||
|     hdrs = [ | ||||
|         "double_array_trie.h", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":config", | ||||
|         ":utils", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "double_array_trie_builder", | ||||
|     srcs = [ | ||||
|         "double_array_trie_builder.cc", | ||||
|     ], | ||||
|     hdrs = [ | ||||
|         "double_array_trie_builder.h", | ||||
|     ], | ||||
|     deps = ["@darts_clone"], | ||||
| ) | ||||
| 
 | ||||
| cc_test( | ||||
|     name = "double_array_trie_test", | ||||
|     srcs = [ | ||||
|         "double_array_trie_test.cc", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":double_array_trie", | ||||
|         ":double_array_trie_builder", | ||||
|         ":encoder_config", | ||||
|         ":utils", | ||||
|         "//mediapipe/framework/port:gtest_main", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "sentencepiece_constants", | ||||
|     hdrs = ["sentencepiece_constants.h"], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "model_converter", | ||||
|     srcs = [ | ||||
|         "model_converter.cc", | ||||
|     ], | ||||
|     hdrs = [ | ||||
|         "model_converter.h", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":config", | ||||
|         ":double_array_trie_builder", | ||||
|         ":encoder_config", | ||||
|         ":sentencepiece_constants", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@com_google_sentencepiece//src:sentencepiece_model_cc_proto", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "optimized_encoder", | ||||
|     srcs = [ | ||||
|         "optimized_encoder.cc", | ||||
|     ], | ||||
|     hdrs = [ | ||||
|         "optimized_encoder.h", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":double_array_trie", | ||||
|         ":encoder_config", | ||||
|         ":utils", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "sentencepiece_tokenizer_tflite", | ||||
|     srcs = ["sentencepiece_tokenizer_tflite.cc"], | ||||
|     hdrs = ["sentencepiece_tokenizer_tflite.h"], | ||||
|     visibility = [ | ||||
|         "//visibility:public", | ||||
|     ], | ||||
|     deps = | ||||
|         [ | ||||
|             ":optimized_encoder", | ||||
|             "@flatbuffers", | ||||
|             "@org_tensorflow//tensorflow/lite:framework", | ||||
|             "@org_tensorflow//tensorflow/lite:string_util", | ||||
|             "@org_tensorflow//tensorflow/lite/c:common", | ||||
|             "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", | ||||
|             "@org_tensorflow//tensorflow/lite/kernels:kernel_util", | ||||
|             "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", | ||||
|         ], | ||||
| ) | ||||
| 
 | ||||
| cc_test( | ||||
|     name = "optimized_encoder_test", | ||||
|     srcs = [ | ||||
|         "optimized_encoder_test.cc", | ||||
|     ], | ||||
|     data = [ | ||||
|         ":testdata", | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":double_array_trie_builder", | ||||
|         ":encoder_config", | ||||
|         ":model_converter", | ||||
|         ":optimized_encoder", | ||||
|         "//mediapipe/framework/deps:file_path", | ||||
|         "//mediapipe/framework/port:gtest_main", | ||||
|         "@com_google_absl//absl/flags:flag", | ||||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/strings:str_format", | ||||
|         "@com_google_sentencepiece//src:sentencepiece_cc_proto", | ||||
|         "@com_google_sentencepiece//src:sentencepiece_processor", | ||||
|         "@org_tensorflow//tensorflow/core:lib", | ||||
|     ], | ||||
| ) | ||||
							
								
								
									
										25
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/config.fbs
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,25 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| namespace mediapipe.tflite_operations.sentencepiece; | ||||
| 
 | ||||
| table Trie { | ||||
|   nodes: [uint32]; | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| enum EncoderVersion: byte { | ||||
|   SENTENCE_PIECE = 0, | ||||
| } | ||||
|  | @ -0,0 +1,111 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/config_generated.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| // A trie node specifies a node in the tree, either an intermediate node or
 | ||||
| // a leaf node.
 | ||||
| // A leaf node contains the id as an int of the string match. This id is encoded
 | ||||
| // in the lower 31 bits, thus the number of distinct ids is 2^31.
 | ||||
| // An intermediate node has an associated label and an offset to its children.
 | ||||
| // The label is encoded in the least significant byte and must match the input
 | ||||
| // character during matching.
 | ||||
| 
 | ||||
| // A memory mappable trie, compatible with Darts::DoubleArray.
 | ||||
| class DoubleArrayTrie { | ||||
|  public: | ||||
|   struct Match { | ||||
|     Match() {} | ||||
|     Match(int id, int match_length) : id(id), match_length(match_length) {} | ||||
|     int id = -1; | ||||
|     int match_length = -1; | ||||
|     bool empty() const { return match_length == -1; } | ||||
|     bool operator==(const Match& m) const { | ||||
|       return m.id == id && m.match_length == match_length; | ||||
|     } | ||||
|   }; | ||||
| 
 | ||||
|   // nodes and nodes_length specify the array of the nodes of the trie.
 | ||||
|   explicit DoubleArrayTrie(const flatbuffers::Vector<uint32_t>* nodes) | ||||
|       : nodes_(nodes) {} | ||||
| 
 | ||||
|   // Finds matches that are prefixes of a string.
 | ||||
|   template <typename callback> | ||||
|   void IteratePrefixMatches(const utils::string_view& input, | ||||
|                             callback update_fn) const; | ||||
| 
 | ||||
|   // Finds the longest prefix match of a string.
 | ||||
|   Match LongestPrefixMatch(const utils::string_view& input) const { | ||||
|     Match match; | ||||
|     IteratePrefixMatches(input, [&match](const Match& m) { match = m; }); | ||||
|     return match; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   // Returns whether a node as a leaf as a child.
 | ||||
|   bool has_leaf(uint32_t i) const { return ((*nodes_)[i]) & 0x100; } | ||||
| 
 | ||||
|   // Returns a value associated with a node. Available when a node is a leaf.
 | ||||
|   int value(uint32_t i) const { | ||||
|     return static_cast<int>(((*nodes_)[i]) & 0x7fffffff); | ||||
|   } | ||||
| 
 | ||||
|   // Returns a label associated with a node.
 | ||||
|   // A leaf node will have the MSB set and thus return an invalid label.
 | ||||
|   int32_t label(uint32_t i) const { return ((*nodes_)[i]) & 0x800000ff; } | ||||
| 
 | ||||
|   // Returns offset to children.
 | ||||
|   int32_t offset(uint32_t i) const { | ||||
|     const uint32_t node = (*nodes_)[i]; | ||||
|     return (node >> 10) << ((node & 0x200) >> 6); | ||||
|   } | ||||
| 
 | ||||
|   const flatbuffers::Vector<uint32_t>* nodes_; | ||||
| }; | ||||
| 
 | ||||
| template <typename callback> | ||||
| void DoubleArrayTrie::IteratePrefixMatches(const utils::string_view& input, | ||||
|                                            callback update_fn) const { | ||||
|   if (nodes_->size() == 0) { | ||||
|     return; | ||||
|   } | ||||
|   uint32_t pos = offset(0); | ||||
|   for (int i = 0; i < input.length(); ++i) { | ||||
|     pos ^= static_cast<unsigned char>(input.at(i)); | ||||
|     if (pos < 0 || pos >= nodes_->size() || label(pos) != input.at(i)) { | ||||
|       // No match, exit.
 | ||||
|       return; | ||||
|     } | ||||
|     const bool node_has_leaf = has_leaf(pos); | ||||
|     pos ^= offset(pos); | ||||
|     if (pos < 0 || pos >= nodes_->size()) { | ||||
|       // We can get here only if the structure is corrupted.
 | ||||
|       return; | ||||
|     } | ||||
|     if (node_has_leaf) { | ||||
|       update_fn(Match(value(pos), i + 1)); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_H_
 | ||||
|  | @ -0,0 +1,75 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <memory> | ||||
| 
 | ||||
| #include "include/darts.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data) { | ||||
|   std::vector<int> ids; | ||||
|   ids.reserve(data.size()); | ||||
|   for (int i = 0; i < data.size(); ++i) { | ||||
|     ids.push_back(i); | ||||
|   } | ||||
|   return BuildTrie(data, ids); | ||||
| } | ||||
| 
 | ||||
| std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data, | ||||
|                                 const std::vector<int>& ids) { | ||||
|   // We make strong assumptions about binary structure of trie.
 | ||||
|   struct OneElement { | ||||
|     OneElement(const std::string* key_, int index_) | ||||
|         : key(key_), index(index_) {} | ||||
|     const std::string* key; | ||||
|     int index; | ||||
|     bool operator<(const OneElement& el) const { return *key < *el.key; } | ||||
|   }; | ||||
|   std::vector<OneElement> elements; | ||||
|   elements.reserve(data.size()); | ||||
|   auto data_iterator = std::begin(data); | ||||
|   auto ids_iterator = std::begin(ids); | ||||
|   for (; data_iterator != std::end(data) && ids_iterator != std::end(ids); | ||||
|        ++data_iterator, ++ids_iterator) { | ||||
|     elements.emplace_back(&(*data_iterator), *ids_iterator); | ||||
|   } | ||||
|   // Sort by keys.
 | ||||
|   std::sort(elements.begin(), elements.end()); | ||||
| 
 | ||||
|   // Create vectors to build the trie.
 | ||||
|   std::vector<const char*> strings; | ||||
|   std::vector<int32_t> indexes; | ||||
|   strings.reserve(data.size()); | ||||
|   indexes.reserve(data.size()); | ||||
|   for (const auto& el : elements) { | ||||
|     strings.push_back(el.key->c_str()); | ||||
|     indexes.push_back(el.index); | ||||
|   } | ||||
|   auto trie = std::make_unique<Darts::DoubleArray>(); | ||||
|   trie->build(data.size(), const_cast<char**>(&strings[0]), nullptr, | ||||
|               &indexes[0]); | ||||
|   // We make strong assumptions about internal Darts trie structure:
 | ||||
|   // - it is a vector of 32 bit signed integers
 | ||||
|   // - the "array" is the only one structure that contains all information about
 | ||||
|   // the trie.
 | ||||
|   const uint32_t* trie_data = static_cast<const uint32_t*>(trie->array()); | ||||
|   return std::vector<uint32_t>(trie_data, trie_data + trie->size()); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
|  | @ -0,0 +1,32 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_ | ||||
| 
 | ||||
| #include <string> | ||||
| #include <vector> | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data, | ||||
|                                 const std::vector<int>& ids); | ||||
| 
 | ||||
| // A variant where ids are indexes in data.
 | ||||
| std::vector<uint32_t> BuildTrie(const std::vector<std::string>& data); | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_DOUBLE_ARRAY_TRIE_BUILDER_H_
 | ||||
|  | @ -0,0 +1,73 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h" | ||||
| 
 | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| TEST(DoubleArrayTrieTest, Match) { | ||||
|   flatbuffers::FlatBufferBuilder builder(1024); | ||||
|   const std::vector<std::string> test_strings = {"A", "AAX", "AA", "B"}; | ||||
|   const auto trie_vector = builder.CreateVector(BuildTrie(test_strings)); | ||||
|   TrieBuilder trie_builder(builder); | ||||
|   trie_builder.add_nodes(trie_vector); | ||||
|   const auto pieces = trie_builder.Finish(); | ||||
|   EncoderConfigBuilder ecb(builder); | ||||
|   ecb.add_pieces(pieces); | ||||
|   FinishEncoderConfigBuffer(builder, ecb.Finish()); | ||||
|   const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); | ||||
|   DoubleArrayTrie dat(config->pieces()->nodes()); | ||||
|   EXPECT_EQ(dat.LongestPrefixMatch(utils::string_view("AAL")), | ||||
|             DoubleArrayTrie::Match(2, 2)); | ||||
| 
 | ||||
|   std::vector<DoubleArrayTrie::Match> matches; | ||||
|   dat.IteratePrefixMatches( | ||||
|       utils::string_view("AAXL"), | ||||
|       [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); | ||||
|   EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(0, 1), | ||||
|                                             DoubleArrayTrie::Match(2, 2), | ||||
|                                             DoubleArrayTrie::Match(1, 3))); | ||||
| } | ||||
| 
 | ||||
| TEST(DoubleArrayTrieTest, ComplexMatch) { | ||||
|   flatbuffers::FlatBufferBuilder builder(1024); | ||||
|   const std::vector<std::string> test_strings = {"\xe2\x96\x81the", ",", "s", | ||||
|                                                  "\xe2\x96\x81Hello"}; | ||||
|   const std::vector<int> test_ids = {0, 5, 10, 15}; | ||||
|   const auto trie_vector = | ||||
|       builder.CreateVector(BuildTrie(test_strings, test_ids)); | ||||
|   TrieBuilder trie_builder(builder); | ||||
|   trie_builder.add_nodes(trie_vector); | ||||
|   const auto pieces = trie_builder.Finish(); | ||||
|   EncoderConfigBuilder ecb(builder); | ||||
|   ecb.add_pieces(pieces); | ||||
|   FinishEncoderConfigBuffer(builder, ecb.Finish()); | ||||
|   const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); | ||||
|   DoubleArrayTrie dat(config->pieces()->nodes()); | ||||
| 
 | ||||
|   std::vector<DoubleArrayTrie::Match> matches; | ||||
|   dat.IteratePrefixMatches( | ||||
|       utils::string_view("\xe2\x96\x81Hello"), | ||||
|       [&matches](const DoubleArrayTrie::Match& m) { matches.push_back(m); }); | ||||
|   EXPECT_THAT(matches, testing::ElementsAre(DoubleArrayTrie::Match(15, 8))); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
|  | @ -0,0 +1,52 @@ | |||
| // Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| include "config.fbs"; | ||||
| 
 | ||||
| namespace mediapipe.tflite_operations.sentencepiece; | ||||
| 
 | ||||
| table EncoderConfig { | ||||
|   // Version of the encoder. | ||||
|   version: EncoderVersion = SENTENCE_PIECE; | ||||
|   start_code: int32 = 0; | ||||
|   end_code: int32 = 0; | ||||
| 
 | ||||
|   unknown_code: int32 = -1; | ||||
|   // Weight of "unknown code" when encoding. "Penalty" because it usually has a | ||||
|   // big negative weight,less than any other sentencepiece. | ||||
|   unknown_penalty: float = 0; | ||||
| 
 | ||||
|   // The offset for encoding, usually used when codes with low codes are reserved | ||||
|   // for some special needs. | ||||
|   encoding_offset: int32; | ||||
| 
 | ||||
|   // String pieces for encoding. | ||||
|   pieces: Trie; | ||||
|   pieces_scores: [float]; | ||||
| 
 | ||||
|   // Normalization related parameters. | ||||
|   remove_extra_whitespaces: bool; | ||||
| 
 | ||||
|   // Add a whitespace prefix before encoding. | ||||
|   add_dummy_prefix: bool; | ||||
| 
 | ||||
|   // Escape whitespaces during encoding so the decoder can restore them exactly as | ||||
|   // in the input. | ||||
|   escape_whitespaces: bool; | ||||
| 
 | ||||
|   // Normalization parameters. | ||||
|   normalized_prefixes: Trie; | ||||
|   normalized_replacements: [byte]; | ||||
| } | ||||
| 
 | ||||
| root_type EncoderConfig; | ||||
|  | @ -0,0 +1,131 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" | ||||
| 
 | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_constants.h" | ||||
| #include "src/sentencepiece_model.pb.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| std::tuple<std::vector<uint32_t>, std::vector<int8_t>> | ||||
| DecodePrecompiledCharsmap( | ||||
|     const ::sentencepiece::NormalizerSpec& normalizer_spec) { | ||||
|   // This function "undoes" encoding done by
 | ||||
|   // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
 | ||||
|   const char* precompiled_map = normalizer_spec.precompiled_charsmap().data(); | ||||
|   const uint32_t trie_size = | ||||
|       *reinterpret_cast<const uint32_t*>(precompiled_map); | ||||
|   const uint32_t* trie_ptr = | ||||
|       reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t)); | ||||
|   const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>( | ||||
|       precompiled_map + sizeof(uint32_t) + trie_size); | ||||
|   const int normalized_size = normalizer_spec.precompiled_charsmap().length() - | ||||
|                               sizeof(uint32_t) - trie_size; | ||||
|   return std::make_tuple( | ||||
|       std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)), | ||||
|       std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size)); | ||||
| } | ||||
| 
 | ||||
| absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( | ||||
|     const std::string& model_config_str, int encoding_offset) { | ||||
|   ::sentencepiece::ModelProto model_config; | ||||
|   if (!model_config.ParseFromString(model_config_str)) { | ||||
|     return absl::InvalidArgumentError( | ||||
|         "Invalid configuration, can't parse SentencePiece model config " + | ||||
|         model_config.InitializationErrorString()); | ||||
|   } | ||||
|   // Convert sentencepieces.
 | ||||
|   std::vector<std::string> pieces; | ||||
|   pieces.reserve(model_config.pieces_size()); | ||||
|   std::vector<float> scores; | ||||
|   scores.reserve(model_config.pieces_size()); | ||||
|   std::vector<int> ids; | ||||
|   ids.reserve(model_config.pieces_size()); | ||||
|   float min_score = 0.0; | ||||
|   int index = 0; | ||||
|   for (const auto& piece : model_config.pieces()) { | ||||
|     switch (piece.type()) { | ||||
|       case ::sentencepiece::ModelProto::SentencePiece::NORMAL: | ||||
|       case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED: | ||||
|         pieces.push_back(piece.piece()); | ||||
|         ids.push_back(index); | ||||
|         if (piece.score() < min_score) { | ||||
|           min_score = piece.score(); | ||||
|         } | ||||
|         break; | ||||
|       case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN: | ||||
|       case ::sentencepiece::ModelProto::SentencePiece::CONTROL: | ||||
|         // Ignore unknown and control codes.
 | ||||
|         break; | ||||
|       default: | ||||
|         return absl::InvalidArgumentError("Invalid SentencePiece piece type " + | ||||
|                                           piece.piece()); | ||||
|     } | ||||
|     scores.push_back(piece.score()); | ||||
|     ++index; | ||||
|   } | ||||
|   flatbuffers::FlatBufferBuilder builder(1024); | ||||
|   const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids)); | ||||
|   const auto pieces_score_vector = builder.CreateVector(scores); | ||||
|   TrieBuilder pieces_trie_builder(builder); | ||||
|   pieces_trie_builder.add_nodes(pieces_trie_vector); | ||||
|   const auto pieces_trie_fbs = pieces_trie_builder.Finish(); | ||||
| 
 | ||||
|   // Converting normalization.
 | ||||
|   const auto normalization = | ||||
|       DecodePrecompiledCharsmap(model_config.normalizer_spec()); | ||||
|   const auto normalization_trie = std::get<0>(normalization); | ||||
|   const auto normalization_strings = std::get<1>(normalization); | ||||
|   const auto normalization_trie_vector = | ||||
|       builder.CreateVector(normalization_trie); | ||||
|   TrieBuilder normalization_trie_builder(builder); | ||||
|   normalization_trie_builder.add_nodes(normalization_trie_vector); | ||||
|   const auto normalization_trie_fbs = normalization_trie_builder.Finish(); | ||||
|   const auto normalization_strings_fbs = | ||||
|       builder.CreateVector(normalization_strings); | ||||
| 
 | ||||
|   EncoderConfigBuilder ecb(builder); | ||||
|   ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE); | ||||
|   ecb.add_start_code(model_config.trainer_spec().bos_id()); | ||||
|   ecb.add_end_code(model_config.trainer_spec().eos_id()); | ||||
|   ecb.add_unknown_code(model_config.trainer_spec().unk_id()); | ||||
|   ecb.add_unknown_penalty(min_score - kUnkPenalty); | ||||
|   ecb.add_encoding_offset(encoding_offset); | ||||
|   ecb.add_pieces(pieces_trie_fbs); | ||||
|   ecb.add_pieces_scores(pieces_score_vector); | ||||
|   ecb.add_remove_extra_whitespaces( | ||||
|       model_config.normalizer_spec().remove_extra_whitespaces()); | ||||
|   ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix()); | ||||
|   ecb.add_escape_whitespaces( | ||||
|       model_config.normalizer_spec().escape_whitespaces()); | ||||
|   ecb.add_normalized_prefixes(normalization_trie_fbs); | ||||
|   ecb.add_normalized_replacements(normalization_strings_fbs); | ||||
|   FinishEncoderConfigBuffer(builder, ecb.Finish()); | ||||
|   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), | ||||
|                      builder.GetSize()); | ||||
| } | ||||
| 
 | ||||
| std::string ConvertSentencepieceModel(const std::string& model_string) { | ||||
|   const auto result = ConvertSentencepieceModelToFlatBuffer(model_string); | ||||
|   assert(result.status().ok()); | ||||
|   return result.value(); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
|  | @ -0,0 +1,33 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_ | ||||
| 
 | ||||
| #include <string> | ||||
| 
 | ||||
| #include "absl/status/statusor.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| // Converts Sentencepiece configuration to flatbuffer format.
 | ||||
| // encoding_offset is used by some encoders that combine different encodings.
 | ||||
| absl::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer( | ||||
|     const std::string& model_config_str, int encoding_offset = 0); | ||||
| std::string ConvertSentencepieceModel(const std::string& model_string); | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_MODEL_CONVERTER_H_
 | ||||
|  | @ -0,0 +1,236 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <tuple> | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| namespace { | ||||
| 
 | ||||
| const char kSpaceSymbol[] = "\xe2\x96\x81"; | ||||
| 
 | ||||
| template <typename processing_callback> | ||||
| std::tuple<std::string, std::vector<int>> process_string( | ||||
|     const std::string& input, const std::vector<int>& offsets, | ||||
|     const processing_callback& pc) { | ||||
|   std::string result_string; | ||||
|   result_string.reserve(input.size()); | ||||
|   std::vector<int> result_offsets; | ||||
|   result_offsets.reserve(offsets.size()); | ||||
|   for (int i = 0, j = 0; i < input.size();) { | ||||
|     auto result = pc(input.data() + i, input.size() - i); | ||||
|     auto consumed = std::get<0>(result); | ||||
|     auto new_string = std::get<1>(result); | ||||
|     if (consumed == 0) { | ||||
|       // Skip the current byte and move forward.
 | ||||
|       result_string.push_back(input[i]); | ||||
|       result_offsets.push_back(offsets[j]); | ||||
|       i++; | ||||
|       j++; | ||||
|       continue; | ||||
|     } | ||||
|     result_string.append(new_string.data(), new_string.length()); | ||||
|     for (int i = 0; i < new_string.length(); ++i) { | ||||
|       result_offsets.push_back(offsets[j]); | ||||
|     } | ||||
|     j += consumed; | ||||
|     i += consumed; | ||||
|   } | ||||
|   return std::make_tuple(result_string, result_offsets); | ||||
| } | ||||
| 
 | ||||
| inline char is_whitespace(char c) { | ||||
|   return c == ' ' || c == '\t' || c == '\r' || c == '\n'; | ||||
| } | ||||
| 
 | ||||
| std::tuple<int, utils::string_view> remove_extra_whitespaces(const char* data, | ||||
|                                                              int len) { | ||||
|   if (len == 0 || !is_whitespace(*data)) { | ||||
|     return std::make_tuple(0, utils::string_view(nullptr, 0)); | ||||
|   } | ||||
|   int num_consumed = 1; | ||||
|   for (; num_consumed < len && is_whitespace(data[num_consumed]); | ||||
|        ++num_consumed) { | ||||
|   } | ||||
|   return num_consumed > 1 | ||||
|              ? std::make_tuple(num_consumed, utils::string_view(" ", 1)) | ||||
|              : std::make_tuple(0, utils::string_view(nullptr, 0)); | ||||
| } | ||||
| 
 | ||||
| std::tuple<int, utils::string_view> find_replacement( | ||||
|     const char* data, int len, const DoubleArrayTrie& dat, | ||||
|     const flatbuffers::Vector<int8_t>& replacements) { | ||||
|   const auto max_match = dat.LongestPrefixMatch(utils::string_view(data, len)); | ||||
|   if (!max_match.empty()) { | ||||
|     // Because flatbuffer byte is signed char which is not the same as char,
 | ||||
|     // there is the reinterpret_cast here.
 | ||||
|     const char* replaced_string_ptr = | ||||
|         reinterpret_cast<const char*>(replacements.data() + max_match.id); | ||||
|     return std::make_tuple(max_match.match_length, | ||||
|                            utils::string_view(replaced_string_ptr)); | ||||
|   } | ||||
|   return std::make_tuple(0, utils::string_view(nullptr, 0)); | ||||
| } | ||||
| }  // namespace
 | ||||
| 
 | ||||
| std::tuple<std::string, std::vector<int>> NormalizeString( | ||||
|     const std::string& in_string, const EncoderConfig& config) { | ||||
|   std::vector<int> output_offsets; | ||||
|   std::string result = in_string; | ||||
|   output_offsets.reserve(in_string.length()); | ||||
|   for (int i = 0; i < in_string.length(); ++i) { | ||||
|     output_offsets.push_back(i); | ||||
|   } | ||||
|   if (in_string.empty()) { | ||||
|     return std::make_tuple(result, output_offsets); | ||||
|   } | ||||
|   if (config.add_dummy_prefix()) { | ||||
|     result.insert(result.begin(), ' '); | ||||
|     output_offsets.insert(output_offsets.begin(), 0); | ||||
|   } | ||||
|   // Greedely replace normalized_prefixes with normalized_replacements
 | ||||
|   if (config.normalized_prefixes() != nullptr && | ||||
|       config.normalized_replacements() != nullptr) { | ||||
|     const DoubleArrayTrie normalized_prefixes_matcher( | ||||
|         config.normalized_prefixes()->nodes()); | ||||
|     const auto norm_replace = [&config, &normalized_prefixes_matcher]( | ||||
|                                   const char* data, int len) { | ||||
|       return find_replacement(data, len, normalized_prefixes_matcher, | ||||
|                               *config.normalized_replacements()); | ||||
|     }; | ||||
|     std::tie(result, output_offsets) = | ||||
|         process_string(result, output_offsets, norm_replace); | ||||
|   } | ||||
|   if (config.remove_extra_whitespaces()) { | ||||
|     std::tie(result, output_offsets) = | ||||
|         process_string(result, output_offsets, remove_extra_whitespaces); | ||||
|     if (!result.empty() && is_whitespace(result.back())) { | ||||
|       result.pop_back(); | ||||
|       output_offsets.pop_back(); | ||||
|     } | ||||
|   } | ||||
|   if (config.escape_whitespaces()) { | ||||
|     const auto replace_whitespaces = [](const char* data, int len) { | ||||
|       if (len > 0 && is_whitespace(*data)) { | ||||
|         return std::make_tuple(1, utils::string_view(kSpaceSymbol)); | ||||
|       } | ||||
|       return std::make_tuple(0, utils::string_view(nullptr, 0)); | ||||
|     }; | ||||
|     std::tie(result, output_offsets) = | ||||
|         process_string(result, output_offsets, replace_whitespaces); | ||||
|   } | ||||
| 
 | ||||
|   return std::make_tuple(result, output_offsets); | ||||
| } | ||||
| 
 | ||||
| EncoderResult EncodeNormalizedString(const std::string& str, | ||||
|                                      const std::vector<int>& offsets, | ||||
|                                      const EncoderConfig& config, bool add_bos, | ||||
|                                      bool add_eos, bool reverse) { | ||||
|   const DoubleArrayTrie piece_matcher(config.pieces()->nodes()); | ||||
|   const flatbuffers::Vector<float>* piece_scores = config.pieces_scores(); | ||||
|   const int unknown_code = config.unknown_code(); | ||||
|   const float unknown_penalty = config.unknown_penalty(); | ||||
|   struct LatticeElement { | ||||
|     float score = 0; | ||||
|     int code = -1; | ||||
|     int prev_position = -1; | ||||
|     LatticeElement(float score_, int code_, int prev_position_) | ||||
|         : score(score_), code(code_), prev_position(prev_position_) {} | ||||
|     LatticeElement() {} | ||||
|   }; | ||||
|   const int length = str.length(); | ||||
|   std::vector<LatticeElement> lattice(length + 1); | ||||
|   for (int i = 0; i < length; ++i) { | ||||
|     if (i > 0 && lattice[i].prev_position < 0) { | ||||
|       // This state is unreachable.
 | ||||
|       continue; | ||||
|     } | ||||
|     if (unknown_code >= 0) { | ||||
|       // Put unknown code.
 | ||||
|       const float penalized_score = lattice[i].score + unknown_penalty; | ||||
|       const int pos = i + 1; | ||||
|       LatticeElement& current_element = lattice[pos]; | ||||
|       if (current_element.prev_position < 0 || | ||||
|           current_element.score < penalized_score) { | ||||
|         current_element = LatticeElement( | ||||
|             penalized_score, unknown_code, | ||||
|             // If the current state is already reached by unknown code, merge
 | ||||
|             // states.
 | ||||
|             lattice[i].code == unknown_code ? lattice[i].prev_position : i); | ||||
|       } | ||||
|     } | ||||
|     auto lattice_update = [&lattice, i, | ||||
|                            piece_scores](const DoubleArrayTrie::Match& m) { | ||||
|       LatticeElement& target_element = lattice[i + m.match_length]; | ||||
|       const float score = lattice[i].score + (*piece_scores)[m.id]; | ||||
|       if (target_element.prev_position < 0 || target_element.score < score) { | ||||
|         target_element = LatticeElement(score, m.id, i); | ||||
|       } | ||||
|     }; | ||||
|     piece_matcher.IteratePrefixMatches( | ||||
|         utils::string_view(str.data() + i, length - i), lattice_update); | ||||
|   } | ||||
| 
 | ||||
|   EncoderResult result; | ||||
|   if (add_eos) { | ||||
|     result.codes.push_back(config.end_code()); | ||||
|     result.offsets.push_back(length); | ||||
|   } | ||||
|   if (lattice[length].prev_position >= 0) { | ||||
|     for (int pos = length; pos > 0;) { | ||||
|       auto code = lattice[pos].code; | ||||
|       if (code != config.unknown_code()) { | ||||
|         code += config.encoding_offset(); | ||||
|       } | ||||
|       result.codes.push_back(code); | ||||
|       pos = lattice[pos].prev_position; | ||||
|       result.offsets.push_back(offsets[pos]); | ||||
|     } | ||||
|   } | ||||
|   if (add_bos) { | ||||
|     result.codes.push_back(config.start_code()); | ||||
|     result.offsets.push_back(0); | ||||
|   } | ||||
|   if (!reverse) { | ||||
|     std::reverse(result.codes.begin(), result.codes.end()); | ||||
|     std::reverse(result.offsets.begin(), result.offsets.end()); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| EncoderResult EncodeString(const std::string& string, const void* config_buffer, | ||||
|                            bool add_bos, bool add_eos, bool reverse) { | ||||
|   // Get the config from the buffer.
 | ||||
|   const EncoderConfig* config = GetEncoderConfig(config_buffer); | ||||
|   if (config->version() != EncoderVersion::EncoderVersion_SENTENCE_PIECE) { | ||||
|     EncoderResult result; | ||||
|     result.type = EncoderResultType::WRONG_CONFIG; | ||||
|     return result; | ||||
|   } | ||||
|   std::string normalized_string; | ||||
|   std::vector<int> offsets; | ||||
|   std::tie(normalized_string, offsets) = NormalizeString(string, *config); | ||||
|   return EncodeNormalizedString(normalized_string, offsets, *config, add_bos, | ||||
|                                 add_eos, reverse); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
|  | @ -0,0 +1,46 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_ | ||||
| 
 | ||||
| // Sentencepiece encoder optimized with memmapped model.
 | ||||
| 
 | ||||
| #include <string> | ||||
| #include <tuple> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| enum class EncoderResultType { SUCCESS = 0, WRONG_CONFIG = 1 }; | ||||
| 
 | ||||
| struct EncoderResult { | ||||
|   EncoderResultType type = EncoderResultType::SUCCESS; | ||||
|   std::vector<int> codes; | ||||
|   std::vector<int> offsets; | ||||
| }; | ||||
| std::tuple<std::string, std::vector<int>> NormalizeString( | ||||
|     const std::string& in_string, const EncoderConfig& config); | ||||
| 
 | ||||
| // Encodes one string and returns ids and offsets. Takes the configuration as a
 | ||||
| // type-erased buffer.
 | ||||
| EncoderResult EncodeString(const std::string& string, const void* config_buffer, | ||||
|                            bool add_bos, bool add_eos, bool reverse); | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_OPTIMIZED_ENCODER_H_
 | ||||
|  | @ -0,0 +1,171 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" | ||||
| 
 | ||||
| #include <fstream> | ||||
| 
 | ||||
| #include "absl/flags/flag.h" | ||||
| #include "absl/status/status.h" | ||||
| #include "absl/strings/str_format.h" | ||||
| #include "mediapipe/framework/deps/file_path.h" | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/double_array_trie_builder.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/encoder_config_generated.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/model_converter.h" | ||||
| #include "src/sentencepiece.pb.h" | ||||
| #include "src/sentencepiece_processor.h" | ||||
| #include "tensorflow/core/platform/env.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| namespace internal { | ||||
| 
 | ||||
| tensorflow::Status TFReadFileToString(const std::string& filepath, | ||||
|                                       std::string* data) { | ||||
|   return tensorflow::ReadFileToString(tensorflow::Env::Default(), filepath, | ||||
|                                       data); | ||||
| } | ||||
| 
 | ||||
| absl::Status StdReadFileToString(const std::string& filepath, | ||||
|                                  std::string* data) { | ||||
|   std::ifstream infile(filepath); | ||||
|   if (!infile.is_open()) { | ||||
|     return absl::NotFoundError( | ||||
|         absl::StrFormat("Error when opening %s", filepath)); | ||||
|   } | ||||
|   std::string contents((std::istreambuf_iterator<char>(infile)), | ||||
|                        (std::istreambuf_iterator<char>())); | ||||
|   data->append(contents); | ||||
|   infile.close(); | ||||
|   return absl::OkStatus(); | ||||
| } | ||||
| }  // namespace internal
 | ||||
| 
 | ||||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::file::JoinPath; | ||||
| 
 | ||||
| static char kConfigFilePath[] = | ||||
|     "/mediapipe/tasks/cc/text/custom_ops/" | ||||
|     "sentencepiece/testdata/sentencepiece.model"; | ||||
| 
 | ||||
| TEST(OptimizedEncoder, NormalizeStringWhitestpaces) { | ||||
|   flatbuffers::FlatBufferBuilder builder(1024); | ||||
|   EncoderConfigBuilder ecb(builder); | ||||
|   ecb.add_remove_extra_whitespaces(true); | ||||
|   ecb.add_add_dummy_prefix(true); | ||||
|   ecb.add_escape_whitespaces(true); | ||||
|   FinishEncoderConfigBuffer(builder, ecb.Finish()); | ||||
|   const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); | ||||
|   { | ||||
|     const auto result = NormalizeString("x  y", *config); | ||||
|     const auto res_string = std::get<0>(result); | ||||
|     const auto offsets = std::get<1>(result); | ||||
|     EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); | ||||
|     EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 0, 1, 1, 1, 3)); | ||||
|   } | ||||
|   { | ||||
|     const auto result = NormalizeString("\tx  y\n", *config); | ||||
|     const auto res_string = std::get<0>(result); | ||||
|     const auto offsets = std::get<1>(result); | ||||
|     EXPECT_EQ(res_string, "\xe2\x96\x81x\xe2\x96\x81y"); | ||||
|     EXPECT_THAT(offsets, ::testing::ElementsAre(0, 0, 0, 1, 2, 2, 2, 4)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| TEST(OptimizedEncoder, NormalizeStringReplacement) { | ||||
|   flatbuffers::FlatBufferBuilder builder(1024); | ||||
|   const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA"}; | ||||
|   const char norm_replacements[] = "A1\0A2\0A3\0A4"; | ||||
|   const auto trie_vector = | ||||
|       builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9})); | ||||
|   const auto norm_r = builder.CreateVector<int8_t>( | ||||
|       reinterpret_cast<const signed char*>(norm_replacements), | ||||
|       sizeof(norm_replacements)); | ||||
|   TrieBuilder trie_builder(builder); | ||||
|   trie_builder.add_nodes(trie_vector); | ||||
|   const auto norm_p = trie_builder.Finish(); | ||||
|   EncoderConfigBuilder ecb(builder); | ||||
|   ecb.add_remove_extra_whitespaces(false); | ||||
|   ecb.add_normalized_prefixes(norm_p); | ||||
|   ecb.add_normalized_replacements(norm_r); | ||||
|   FinishEncoderConfigBuffer(builder, ecb.Finish()); | ||||
|   const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); | ||||
|   { | ||||
|     const auto result = NormalizeString("ABAABAAABAAAA", *config); | ||||
|     const auto res_string = std::get<0>(result); | ||||
|     const auto offsets = std::get<1>(result); | ||||
|     EXPECT_EQ(res_string, "A1BA2BA3BA4"); | ||||
|     EXPECT_THAT(offsets, | ||||
|                 ::testing::ElementsAre(0, 0, 1, 2, 2, 4, 5, 5, 8, 9, 9)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| TEST(OptimizedEncoder, NormalizeStringWhitespacesRemove) { | ||||
|   flatbuffers::FlatBufferBuilder builder(1024); | ||||
|   const std::vector<std::string> norm_prefixes = {"A", "AA", "AAA", "AAAA", | ||||
|                                                   "X"}; | ||||
|   const char norm_replacements[] = "A1\0A2\0A3\0A4\0 "; | ||||
|   const auto trie_vector = | ||||
|       builder.CreateVector(BuildTrie(norm_prefixes, {0, 3, 6, 9, 12})); | ||||
|   const auto norm_r = builder.CreateVector<int8_t>( | ||||
|       reinterpret_cast<const signed char*>(norm_replacements), | ||||
|       sizeof(norm_replacements)); | ||||
|   TrieBuilder trie_builder(builder); | ||||
|   trie_builder.add_nodes(trie_vector); | ||||
|   const auto norm_p = trie_builder.Finish(); | ||||
|   EncoderConfigBuilder ecb(builder); | ||||
|   ecb.add_remove_extra_whitespaces(true); | ||||
|   ecb.add_normalized_prefixes(norm_p); | ||||
|   ecb.add_normalized_replacements(norm_r); | ||||
|   FinishEncoderConfigBuffer(builder, ecb.Finish()); | ||||
|   const EncoderConfig* config = GetEncoderConfig(builder.GetBufferPointer()); | ||||
|   { | ||||
|     const auto result = NormalizeString("XXABAABAAABAAAA", *config); | ||||
|     const auto res_string = std::get<0>(result); | ||||
|     const auto offsets = std::get<1>(result); | ||||
|     EXPECT_EQ(res_string, " A1BA2BA3BA4"); | ||||
|     EXPECT_THAT(offsets, | ||||
|                 ::testing::ElementsAre(0, 2, 2, 3, 4, 4, 6, 7, 7, 10, 11, 11)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| TEST(OptimizedEncoder, ConfigConverter) { | ||||
|   std::string config; | ||||
|   auto status = | ||||
|       internal::TFReadFileToString(JoinPath("./", kConfigFilePath), &config); | ||||
|   ASSERT_TRUE(status.ok()); | ||||
| 
 | ||||
|   ::sentencepiece::SentencePieceProcessor processor; | ||||
|   ASSERT_TRUE(processor.LoadFromSerializedProto(config).ok()); | ||||
|   const auto converted_model = ConvertSentencepieceModel(config); | ||||
|   const std::string test_string("Hello world!\\xF0\\x9F\\x8D\\x95"); | ||||
|   const auto encoded = | ||||
|       EncodeString(test_string, converted_model.data(), false, false, false); | ||||
|   ASSERT_EQ(encoded.codes.size(), encoded.offsets.size()); | ||||
| 
 | ||||
|   ::sentencepiece::SentencePieceText reference_encoded; | ||||
|   ASSERT_TRUE(processor.Encode(test_string, &reference_encoded).ok()); | ||||
|   EXPECT_EQ(encoded.codes.size(), reference_encoded.pieces_size()); | ||||
|   for (int i = 0; i < encoded.codes.size(); ++i) { | ||||
|     EXPECT_EQ(encoded.codes[i], reference_encoded.pieces(i).id()); | ||||
|     EXPECT_EQ(encoded.offsets[i], reference_encoded.pieces(i).begin()); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
|  | @ -0,0 +1,38 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_ | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| // The constant is copied from
 | ||||
| // https://github.com/google/sentencepiece/blob/master/src/unigram_model.cc
 | ||||
| constexpr float kUnkPenalty = 10.0; | ||||
| 
 | ||||
| // These constants are copied from
 | ||||
| // https://github.com/google/sentencepiece/blob/master/src/sentencepiece_processor.cc
 | ||||
| //
 | ||||
| // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK).
 | ||||
| constexpr char kSpaceSymbol[] = "\xe2\x96\x81"; | ||||
| 
 | ||||
| // Encodes <unk> into U+2047 (DOUBLE QUESTION MARK),
 | ||||
| // since this character can be useful both for user and
 | ||||
| // developer. We can easily figure out that <unk> is emitted.
 | ||||
| constexpr char kDefaultUnknownSymbol[] = " \xE2\x81\x87 "; | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_CONSTANTS_H_
 | ||||
|  | @ -0,0 +1,129 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/sentencepiece_tokenizer_tflite.h" | ||||
| 
 | ||||
| #include "flatbuffers/flexbuffers.h" | ||||
| #include "mediapipe/tasks/cc/text/custom_ops/sentencepiece/optimized_encoder.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/context.h" | ||||
| #include "tensorflow/lite/kernels/internal/tensor.h" | ||||
| #include "tensorflow/lite/kernels/kernel_util.h" | ||||
| #include "tensorflow/lite/model.h" | ||||
| #include "tensorflow/lite/string_util.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations { | ||||
| namespace sentencepiece::tokenizer { | ||||
| namespace { | ||||
| 
 | ||||
| using ::tflite::SetTensorToDynamic; | ||||
| 
 | ||||
| constexpr int kSPModelIndex = 0; | ||||
| constexpr int kInputIndex = 1; | ||||
| constexpr int kAddBOSInput = 4; | ||||
| constexpr int kAddEOSInput = 5; | ||||
| constexpr int kReverseInput = 6; | ||||
| 
 | ||||
| constexpr int kOutputValuesInd = 0; | ||||
| constexpr int kOutputSplitsInd = 1; | ||||
| 
 | ||||
| TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) { | ||||
|   TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size()); | ||||
|   int index = 0; | ||||
|   for (const int size : sizes) { | ||||
|     array_size->data[index++] = size; | ||||
|   } | ||||
|   return array_size; | ||||
| } | ||||
| }  // namespace
 | ||||
| 
 | ||||
| // Initializes text encoder object from serialized parameters.
 | ||||
| void* Initialize(TfLiteContext* /*context*/, const char* /*buffer*/, | ||||
|                  size_t /*length*/) { | ||||
|   return nullptr; | ||||
| } | ||||
| void Free(TfLiteContext* /*context*/, void* /*buffer*/) {} | ||||
| 
 | ||||
| TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { | ||||
|   // TODO: Add checks for input and output tensors.
 | ||||
|   TfLiteTensor& output_values = | ||||
|       context->tensors[node->outputs->data[kOutputValuesInd]]; | ||||
|   SetTensorToDynamic(&output_values); | ||||
| 
 | ||||
|   TfLiteTensor& output_splits = | ||||
|       context->tensors[node->outputs->data[kOutputSplitsInd]]; | ||||
|   SetTensorToDynamic(&output_splits); | ||||
|   return kTfLiteOk; | ||||
| } | ||||
| 
 | ||||
| TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { | ||||
|   const TfLiteTensor& model_tensor = | ||||
|       context->tensors[node->inputs->data[kSPModelIndex]]; | ||||
|   const auto model_buffer_data = model_tensor.data.data; | ||||
|   const TfLiteTensor& input_text = | ||||
|       context->tensors[node->inputs->data[kInputIndex]]; | ||||
| 
 | ||||
|   const TfLiteTensor add_bos_tensor = | ||||
|       context->tensors[node->inputs->data[kAddBOSInput]]; | ||||
|   const bool add_bos = add_bos_tensor.data.b[0]; | ||||
|   const TfLiteTensor add_eos_tensor = | ||||
|       context->tensors[node->inputs->data[kAddEOSInput]]; | ||||
|   const bool add_eos = add_eos_tensor.data.b[0]; | ||||
|   const TfLiteTensor reverse_tensor = | ||||
|       context->tensors[node->inputs->data[kReverseInput]]; | ||||
|   const bool reverse = reverse_tensor.data.b[0]; | ||||
| 
 | ||||
|   std::vector<int32> encoded; | ||||
|   std::vector<int32> splits; | ||||
|   const int num_strings = tflite::GetStringCount(&input_text); | ||||
|   for (int i = 0; i < num_strings; ++i) { | ||||
|     const auto strref = tflite::GetString(&input_text, i); | ||||
|     const auto res = EncodeString(std::string(strref.str, strref.len), | ||||
|                                   model_buffer_data, add_bos, add_eos, reverse); | ||||
|     TF_LITE_ENSURE_MSG(context, res.type == EncoderResultType::SUCCESS, | ||||
|                        "Sentencepiece conversion failed"); | ||||
|     std::copy(res.codes.begin(), res.codes.end(), std::back_inserter(encoded)); | ||||
|     splits.emplace_back(encoded.size()); | ||||
|   } | ||||
| 
 | ||||
|   TfLiteTensor& output_values = | ||||
|       context->tensors[node->outputs->data[kOutputValuesInd]]; | ||||
|   TF_LITE_ENSURE_OK(context, | ||||
|                     context->ResizeTensor( | ||||
|                         context, &output_values, | ||||
|                         CreateSizeArray({static_cast<int>(encoded.size())}))); | ||||
|   int32_t* output_values_flat = output_values.data.i32; | ||||
|   std::copy(encoded.begin(), encoded.end(), output_values_flat); | ||||
|   TfLiteTensor& output_splits = | ||||
|       context->tensors[node->outputs->data[kOutputSplitsInd]]; | ||||
|   TF_LITE_ENSURE_OK( | ||||
|       context, context->ResizeTensor( | ||||
|                    context, &output_splits, | ||||
|                    CreateSizeArray({static_cast<int>(splits.size() + 1)}))); | ||||
|   int32_t* output_splits_flat = output_splits.data.i32; | ||||
|   *output_splits_flat = 0; | ||||
|   std::copy(splits.begin(), splits.end(), output_splits_flat + 1); | ||||
|   return kTfLiteOk; | ||||
| } | ||||
| }  // namespace sentencepiece::tokenizer
 | ||||
| 
 | ||||
| TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER() { | ||||
|   static TfLiteRegistration r = { | ||||
|       sentencepiece::tokenizer::Initialize, sentencepiece::tokenizer::Free, | ||||
|       sentencepiece::tokenizer::Prepare, sentencepiece::tokenizer::Eval}; | ||||
|   return &r; | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
|  | @ -0,0 +1,27 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_ | ||||
| 
 | ||||
| #include "tensorflow/lite/kernels/register.h" | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations { | ||||
| 
 | ||||
| TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER(); | ||||
| 
 | ||||
| }  // namespace mediapipe::tflite_operations
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_SENTENCEPIECE_TOKENIZER_TFLITE_H_
 | ||||
							
								
								
									
										
											BIN
										
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/testdata/sentencepiece.model
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										60
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								mediapipe/tasks/cc/text/custom_ops/sentencepiece/utils.h
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,60 @@ | |||
| /* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ | ||||
| #define MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_ | ||||
| 
 | ||||
| #include <ostream> | ||||
| #include <string> | ||||
| 
 | ||||
| namespace mediapipe::tflite_operations::sentencepiece { | ||||
| 
 | ||||
| // AOSP and WASM doesn't support string_view,
 | ||||
| // we put here a minimal re-implementation.
 | ||||
| namespace utils { | ||||
| 
 | ||||
| class string_view { | ||||
|  public: | ||||
|   explicit string_view(const std::string& s) | ||||
|       : str_(s.data()), len_(s.length()) {} | ||||
|   string_view(const char* str, int len) : str_(str), len_(len) {} | ||||
|   // A constructor from c string.
 | ||||
|   explicit string_view(const char* s) : str_(s), len_(strlen(s)) {} | ||||
| 
 | ||||
|   int length() const { return len_; } | ||||
|   const char* data() const { return str_; } | ||||
|   bool empty() const { return len_ == 0; } | ||||
|   unsigned char at(int i) const { return str_[i]; } | ||||
| 
 | ||||
|  private: | ||||
|   const char* str_ = nullptr; | ||||
|   const int len_ = 0; | ||||
| }; | ||||
| 
 | ||||
| inline std::ostream& operator<<(std::ostream& os, const string_view& sv) { | ||||
|   os << std::string(sv.data(), sv.length()); | ||||
|   return os; | ||||
| } | ||||
| inline bool operator==(const string_view& view1, const string_view& view2) { | ||||
|   if (view1.length() != view2.length()) { | ||||
|     return false; | ||||
|   } | ||||
|   return memcmp(view1.data(), view2.data(), view1.length()) == 0; | ||||
| } | ||||
| 
 | ||||
| }  // namespace utils
 | ||||
| }  // namespace mediapipe::tflite_operations::sentencepiece
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_TASKS_CC_TEXT_CUSTOM_OPS_SENTENCEPIECE_UTILS_H_
 | ||||
|  | @ -32,7 +32,7 @@ limitations under the License. | |||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/tasks/cc/common.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::language_detector { | ||||
| namespace { | ||||
|  | @ -75,7 +75,7 @@ absl::Status MatchesLanguageDetectorResult( | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class LanguageDetectorTest : public tflite_shims::testing::Test {}; | ||||
| class LanguageDetectorTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(LanguageDetectorTest, CreateFailsWithMissingModel) { | ||||
|   auto options = std::make_unique<LanguageDetectorOptions>(); | ||||
|  |  | |||
|  | @ -89,7 +89,7 @@ cc_test( | |||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_absl//absl/strings:cord", | ||||
|         "@com_google_sentencepiece//src:sentencepiece_processor", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -36,7 +36,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/components/containers/category.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||
| #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::text_classifier { | ||||
| namespace { | ||||
|  | @ -87,7 +87,7 @@ void ExpectApproximatelyEqual(const TextClassifierResult& actual, | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class TextClassifierTest : public tflite_shims::testing::Test {}; | ||||
| class TextClassifierTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { | ||||
|   auto options = std::make_unique<TextClassifierOptions>(); | ||||
|  |  | |||
|  | @ -91,6 +91,6 @@ cc_test( | |||
|         "@com_google_absl//absl/status", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@com_google_sentencepiece//src:sentencepiece_processor", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -26,7 +26,7 @@ limitations under the License. | |||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/tasks/cc/common.h" | ||||
| #include "mediapipe/tasks/cc/components/containers/embedding_result.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::text_embedder { | ||||
| namespace { | ||||
|  | @ -39,6 +39,8 @@ constexpr char kMobileBert[] = "mobilebert_embedding_with_metadata.tflite"; | |||
| // Embedding model with regex preprocessing.
 | ||||
| constexpr char kRegexOneEmbeddingModel[] = | ||||
|     "regex_one_embedding_with_metadata.tflite"; | ||||
| constexpr char kUniversalSentenceEncoderModel[] = | ||||
|     "universal_sentence_encoder_qa_with_metadata.tflite"; | ||||
| 
 | ||||
| // Tolerance for embedding vector coordinate values.
 | ||||
| constexpr float kEpsilon = 1e-4; | ||||
|  | @ -49,7 +51,7 @@ using ::mediapipe::file::JoinPath; | |||
| using ::testing::HasSubstr; | ||||
| using ::testing::Optional; | ||||
| 
 | ||||
| class EmbedderTest : public tflite_shims::testing::Test {}; | ||||
| class EmbedderTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(EmbedderTest, FailsWithMissingModel) { | ||||
|   auto text_embedder = | ||||
|  | @ -147,6 +149,35 @@ TEST_F(EmbedderTest, SucceedsWithQuantization) { | |||
|   MP_ASSERT_OK(text_embedder->Close()); | ||||
| } | ||||
| 
 | ||||
| TEST(EmbedTest, SucceedsWithUniversalSentenceEncoderModel) { | ||||
|   auto options = std::make_unique<TextEmbedderOptions>(); | ||||
|   options->base_options.model_asset_path = | ||||
|       JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, | ||||
|                           TextEmbedder::Create(std::move(options))); | ||||
| 
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       auto result0, | ||||
|       text_embedder->Embed("it's a charming and often affecting journey")); | ||||
|   ASSERT_EQ(result0.embeddings.size(), 1); | ||||
|   ASSERT_EQ(result0.embeddings[0].float_embedding.size(), 100); | ||||
|   ASSERT_NEAR(result0.embeddings[0].float_embedding[0], 1.422951f, kEpsilon); | ||||
| 
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       auto result1, text_embedder->Embed("what a great and fantastic trip")); | ||||
|   ASSERT_EQ(result1.embeddings.size(), 1); | ||||
|   ASSERT_EQ(result1.embeddings[0].float_embedding.size(), 100); | ||||
|   ASSERT_NEAR(result1.embeddings[0].float_embedding[0], 1.404664f, kEpsilon); | ||||
| 
 | ||||
|   // Check cosine similarity.
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], | ||||
|                                                         result1.embeddings[0])); | ||||
|   ASSERT_NEAR(similarity, 0.851961, kSimilarityTolerancy); | ||||
| 
 | ||||
|   MP_ASSERT_OK(text_embedder->Close()); | ||||
| } | ||||
| 
 | ||||
| TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { | ||||
|   auto options = std::make_unique<TextEmbedderOptions>(); | ||||
|   options->base_options.model_asset_path = | ||||
|  | @ -178,5 +209,31 @@ TEST_F(EmbedderTest, SucceedsWithMobileBertAndDifferentThemes) { | |||
|   MP_ASSERT_OK(text_embedder->Close()); | ||||
| } | ||||
| 
 | ||||
| TEST_F(EmbedderTest, SucceedsWithUSEAndDifferentThemes) { | ||||
|   auto options = std::make_unique<TextEmbedderOptions>(); | ||||
|   options->base_options.model_asset_path = | ||||
|       JoinPath("./", kTestDataDirectory, kUniversalSentenceEncoderModel); | ||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextEmbedder> text_embedder, | ||||
|                           TextEmbedder::Create(std::move(options))); | ||||
| 
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       TextEmbedderResult result0, | ||||
|       text_embedder->Embed("When you go to this restaurant, they hold the " | ||||
|                            "pancake upside-down before they hand it " | ||||
|                            "to you. It's a great gimmick.")); | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       TextEmbedderResult result1, | ||||
|       text_embedder->Embed( | ||||
|           "Let's make a plan to steal the declaration of independence.")); | ||||
| 
 | ||||
|   // Check cosine similarity.
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|       double similarity, TextEmbedder::CosineSimilarity(result0.embeddings[0], | ||||
|                                                         result1.embeddings[0])); | ||||
|   EXPECT_NEAR(similarity, 0.780334, kSimilarityTolerancy); | ||||
| 
 | ||||
|   MP_ASSERT_OK(text_embedder->Close()); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace mediapipe::tasks::text::text_embedder
 | ||||
|  |  | |||
|  | @ -81,6 +81,6 @@ cc_test( | |||
|         "@com_google_absl//absl/flags:flag", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util", | ||||
|         "@org_tensorflow//tensorflow/lite:test_util", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -28,7 +28,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/components/processors/proto/text_model_type.pb.h" | ||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::text::utils { | ||||
| 
 | ||||
|  | @ -76,7 +76,7 @@ absl::StatusOr<TextModelType::ModelType> GetModelTypeFromFile( | |||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| class TextModelUtilsTest : public tflite_shims::testing::Test {}; | ||||
| class TextModelUtilsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(TextModelUtilsTest, BertClassifierModelTest) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto model_type, | ||||
|  |  | |||
|  | @ -29,7 +29,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/proto/base_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/core/task_runner.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -105,7 +105,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() { | |||
|       graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>()); | ||||
| } | ||||
| 
 | ||||
| class FaceBlendshapesTest : public tflite_shims::testing::Test {}; | ||||
| class FaceBlendshapesTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(FaceBlendshapesTest, SmokeTest) { | ||||
|   // Prepare graph inputs.
 | ||||
|  |  | |||
|  | @ -67,7 +67,7 @@ constexpr char kPortraitExpectedFaceLandmarksName[] = | |||
|     "portrait_expected_face_landmarks.pbtxt"; | ||||
| constexpr char kPortraitExpectedBlendshapesName[] = | ||||
|     "portrait_expected_blendshapes.pbtxt"; | ||||
| constexpr char kPortaitExpectedFaceGeomertyName[] = | ||||
| constexpr char kPortraitExpectedFaceGeometryName[] = | ||||
|     "portrait_expected_face_geometry.pbtxt"; | ||||
| 
 | ||||
| constexpr float kLandmarksDiffMargin = 0.03; | ||||
|  | @ -100,7 +100,7 @@ struct FaceLandmarkerTestParams { | |||
| 
 | ||||
| mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() { | ||||
|   auto face_geometry = GetExpectedProto<face_geometry::proto::FaceGeometry>( | ||||
|       kPortaitExpectedFaceGeomertyName); | ||||
|       kPortraitExpectedFaceGeometryName); | ||||
|   return face_geometry.pose_transform_matrix(); | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -23,18 +23,12 @@ cc_library( | |||
|     srcs = ["face_stylizer_graph.cc"], | ||||
|     deps = [ | ||||
|         "//mediapipe/calculators/core:split_vector_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/image:image_cropping_calculator", | ||||
|         "//mediapipe/calculators/image:image_cropping_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/image:warp_affine_calculator", | ||||
|         "//mediapipe/calculators/image:warp_affine_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/image:image_clone_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/tensor:inference_calculator", | ||||
|         "//mediapipe/calculators/util:detections_to_rects_calculator", | ||||
|         "//mediapipe/calculators/util:face_to_rect_calculator", | ||||
|         "//mediapipe/calculators/util:from_image_calculator", | ||||
|         "//mediapipe/calculators/util:inverse_matrix_calculator", | ||||
|         "//mediapipe/calculators/util:landmarks_to_detection_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/util:to_image_calculator", | ||||
|         "//mediapipe/framework/api2:builder", | ||||
|         "//mediapipe/framework/api2:port", | ||||
|         "//mediapipe/framework/formats:image", | ||||
|  | @ -53,7 +47,6 @@ cc_library( | |||
|         "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", | ||||
|         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_stylizer/calculators:strip_rotation_calculator", | ||||
|         "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator", | ||||
|         "//mediapipe/tasks/cc/vision/face_stylizer/calculators:tensors_to_image_calculator_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", | ||||
|  |  | |||
|  | @ -84,9 +84,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { | |||
|   // The input image can be of any size with format RGB or RGBA.
 | ||||
|   // When no face is detected on the input image, the method returns a
 | ||||
|   // std::nullopt. Otherwise, returns the stylized image of the most visible
 | ||||
|   // face. To ensure that the output image has reasonable quality, the stylized
 | ||||
|   // output image size is the smaller of the model output size and the size of
 | ||||
|   // the 'region_of_interest' specified in 'image_processing_options'.
 | ||||
|   // face. The stylized output image size is the same as the model output size.
 | ||||
|   absl::StatusOr<std::optional<mediapipe::Image>> Stylize( | ||||
|       mediapipe::Image image, | ||||
|       std::optional<core::ImageProcessingOptions> image_processing_options = | ||||
|  | @ -111,9 +109,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { | |||
|   // must be monotonically increasing.
 | ||||
|   // When no face is detected on the input image, the method returns a
 | ||||
|   // std::nullopt. Otherwise, returns the stylized image of the most visible
 | ||||
|   // face. To ensure that the output image has reasonable quality, the stylized
 | ||||
|   // output image size is the smaller of the model output size and the size of
 | ||||
|   // the 'region_of_interest' specified in 'image_processing_options'.
 | ||||
|   // face. The stylized output image size is the same as the model output size.
 | ||||
|   absl::StatusOr<std::optional<mediapipe::Image>> StylizeForVideo( | ||||
|       mediapipe::Image image, int64_t timestamp_ms, | ||||
|       std::optional<core::ImageProcessingOptions> image_processing_options = | ||||
|  | @ -143,10 +139,8 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { | |||
|   // The "result_callback" provides:
 | ||||
|   //   - When no face is detected on the input image, the method returns a
 | ||||
|   //     std::nullopt. Otherwise, returns the stylized image of the most visible
 | ||||
|   //     face. To ensure that the output image has reasonable quality, the
 | ||||
|   //     stylized output image size is the smaller of the model output size and
 | ||||
|   //     the size of the 'region_of_interest' specified in
 | ||||
|   //     'image_processing_options'.
 | ||||
|   //     face. The stylized output image size is the same as the model output
 | ||||
|   //     size.
 | ||||
|   //   - The input timestamp in milliseconds.
 | ||||
|   absl::Status StylizeAsync(mediapipe::Image image, int64_t timestamp_ms, | ||||
|                             std::optional<core::ImageProcessingOptions> | ||||
|  |  | |||
|  | @ -19,8 +19,7 @@ limitations under the License. | |||
| #include "absl/memory/memory.h" | ||||
| #include "absl/status/statusor.h" | ||||
| #include "mediapipe/calculators/core/split_vector_calculator.pb.h" | ||||
| #include "mediapipe/calculators/image/image_cropping_calculator.pb.h" | ||||
| #include "mediapipe/calculators/image/warp_affine_calculator.pb.h" | ||||
| #include "mediapipe/calculators/image/image_clone_calculator.pb.h" | ||||
| #include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" | ||||
| #include "mediapipe/calculators/util/landmarks_to_detection_calculator.pb.h" | ||||
| #include "mediapipe/framework/api2/builder.h" | ||||
|  | @ -326,7 +325,6 @@ class FaceStylizerGraph : public core::ModelTaskGraph { | |||
|     image_in >> preprocessing.In(kImageTag); | ||||
|     face_rect >> preprocessing.In(kNormRectTag); | ||||
|     auto preprocessed_tensors = preprocessing.Out(kTensorsTag); | ||||
|     auto transform_matrix = preprocessing.Out(kMatrixTag); | ||||
| 
 | ||||
|     // Adds inference subgraph and connects its input stream to the output
 | ||||
|     // tensors produced by the ImageToTensorCalculator.
 | ||||
|  | @ -344,53 +342,12 @@ class FaceStylizerGraph : public core::ModelTaskGraph { | |||
|     model_output_tensors >> tensors_to_image.In(kTensorsTag); | ||||
|     auto tensor_image = tensors_to_image.Out(kImageTag); | ||||
| 
 | ||||
|     auto& inverse_matrix = graph.AddNode("InverseMatrixCalculator"); | ||||
|     transform_matrix >> inverse_matrix.In(kMatrixTag); | ||||
|     auto inverse_transform_matrix = inverse_matrix.Out(kMatrixTag); | ||||
|     auto& image_converter = graph.AddNode("ImageCloneCalculator"); | ||||
|     image_converter.GetOptions<mediapipe::ImageCloneCalculatorOptions>() | ||||
|         .set_output_on_gpu(false); | ||||
|     tensor_image >> image_converter.In(""); | ||||
| 
 | ||||
|     auto& warp_affine = graph.AddNode("WarpAffineCalculator"); | ||||
|     auto& warp_affine_options = | ||||
|         warp_affine.GetOptions<WarpAffineCalculatorOptions>(); | ||||
|     warp_affine_options.set_border_mode( | ||||
|         WarpAffineCalculatorOptions::BORDER_ZERO); | ||||
|     warp_affine_options.set_gpu_origin(mediapipe::GpuOrigin_Mode_TOP_LEFT); | ||||
|     tensor_image >> warp_affine.In(kImageTag); | ||||
|     inverse_transform_matrix >> warp_affine.In(kMatrixTag); | ||||
|     image_size >> warp_affine.In(kOutputSizeTag); | ||||
|     auto image_to_crop = warp_affine.Out(kImageTag); | ||||
| 
 | ||||
|     // The following calculators are for cropping and resizing the output image
 | ||||
|     // based on the roi and the model output size. As the WarpAffineCalculator
 | ||||
|     // rotates the image based on the transform matrix, the rotation info in the
 | ||||
|     // rect proto is stripped to prevent the ImageCroppingCalculator from
 | ||||
|     // performing extra rotation.
 | ||||
|     auto& strip_rotation = | ||||
|         graph.AddNode("mediapipe.tasks.StripRotationCalculator"); | ||||
|     face_rect >> strip_rotation.In(kNormRectTag); | ||||
|     auto norm_rect_no_rotation = strip_rotation.Out(kNormRectTag); | ||||
|     auto& from_image = graph.AddNode("FromImageCalculator"); | ||||
|     image_to_crop >> from_image.In(kImageTag); | ||||
|     auto& image_cropping = graph.AddNode("ImageCroppingCalculator"); | ||||
|     auto& image_cropping_opts = | ||||
|         image_cropping.GetOptions<ImageCroppingCalculatorOptions>(); | ||||
|     image_cropping_opts.set_output_max_width( | ||||
|         image_to_tensor_options.output_tensor_width()); | ||||
|     image_cropping_opts.set_output_max_height( | ||||
|         image_to_tensor_options.output_tensor_height()); | ||||
|     norm_rect_no_rotation >> image_cropping.In(kNormRectTag); | ||||
|     auto& to_image = graph.AddNode("ToImageCalculator"); | ||||
|     // ImageCroppingCalculator currently doesn't support mediapipe::Image, the
 | ||||
|     // graph selects its cpu or gpu path based on the image preprocessing
 | ||||
|     // backend.
 | ||||
|     if (use_gpu) { | ||||
|       from_image.Out(kImageGpuTag) >> image_cropping.In(kImageGpuTag); | ||||
|       image_cropping.Out(kImageGpuTag) >> to_image.In(kImageGpuTag); | ||||
|     } else { | ||||
|       from_image.Out(kImageCpuTag) >> image_cropping.In(kImageTag); | ||||
|       image_cropping.Out(kImageTag) >> to_image.In(kImageCpuTag); | ||||
|     } | ||||
| 
 | ||||
|     return {{/*stylized_image=*/to_image.Out(kImageTag).Cast<Image>(), | ||||
|     return {{/*stylized_image=*/image_converter.Out("").Cast<Image>(), | ||||
|              /*original_image=*/preprocessing.Out(kImageTag).Cast<Image>()}}; | ||||
|   } | ||||
| }; | ||||
|  |  | |||
|  | @ -43,7 +43,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/hand_landmarker/proto/hand_landmarks_detector_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -137,7 +137,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner() { | |||
|       graph.GetConfig(), absl::make_unique<core::MediaPipeBuiltinOpResolver>()); | ||||
| } | ||||
| 
 | ||||
| class HandLandmarkerTest : public tflite_shims::testing::Test {}; | ||||
| class HandLandmarkerTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(HandLandmarkerTest, Succeeds) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  |  | |||
|  | @ -41,7 +41,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" | ||||
| #include "mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_result.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  |  | |||
|  | @ -59,7 +59,6 @@ using ::mediapipe::api2::Output; | |||
| using ::mediapipe::api2::builder::Graph; | ||||
| using ::mediapipe::api2::builder::Source; | ||||
| using ::mediapipe::tasks::components::utils::AllowIf; | ||||
| using ::mediapipe::tasks::core::ModelResources; | ||||
| using ::mediapipe::tasks::vision::hand_landmarker::proto:: | ||||
|     HandLandmarksDetectorGraphOptions; | ||||
| using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; | ||||
|  |  | |||
|  | @ -146,7 +146,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateSingleHandTaskRunner( | |||
| 
 | ||||
|   return TaskRunner::Create( | ||||
|       graph.GetConfig(), | ||||
|       absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); | ||||
|       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); | ||||
| } | ||||
| 
 | ||||
| // Helper function to create a Multi Hand Landmark TaskRunner.
 | ||||
|  | @ -188,7 +188,7 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateMultiHandTaskRunner( | |||
| 
 | ||||
|   return TaskRunner::Create( | ||||
|       graph.GetConfig(), | ||||
|       absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>()); | ||||
|       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>()); | ||||
| } | ||||
| 
 | ||||
| NormalizedLandmarkList GetExpectedLandmarkList(absl::string_view filename) { | ||||
|  |  | |||
|  | @ -39,9 +39,9 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/core/running_mode.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -148,7 +148,7 @@ class MobileNetQuantizedOpResolverMissingOps | |||
|       const MobileNetQuantizedOpResolverMissingOps& r) = delete; | ||||
| }; | ||||
| 
 | ||||
| class CreateTest : public tflite_shims::testing::Test {}; | ||||
| class CreateTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { | ||||
|   auto options = std::make_unique<ImageClassifierOptions>(); | ||||
|  | @ -265,7 +265,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) { | |||
|                   MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); | ||||
| } | ||||
| 
 | ||||
| class ImageModeTest : public tflite_shims::testing::Test {}; | ||||
| class ImageModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -605,7 +605,7 @@ TEST_F(ImageModeTest, FailsWithInvalidImageProcessingOptions) { | |||
|           MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); | ||||
| } | ||||
| 
 | ||||
| class VideoModeTest : public tflite_shims::testing::Test {}; | ||||
| class VideoModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -707,7 +707,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { | |||
|   MP_ASSERT_OK(image_classifier->Close()); | ||||
| } | ||||
| 
 | ||||
| class LiveStreamModeTest : public tflite_shims::testing::Test {}; | ||||
| class LiveStreamModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  |  | |||
|  | @ -30,9 +30,9 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/core/running_mode.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -103,7 +103,7 @@ class MobileNetV3OpResolverMissingOps : public ::tflite::MutableOpResolver { | |||
|       delete; | ||||
| }; | ||||
| 
 | ||||
| class CreateTest : public tflite_shims::testing::Test {}; | ||||
| class CreateTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(CreateTest, SucceedsWithSelectiveOpResolver) { | ||||
|   auto options = std::make_unique<ImageEmbedderOptions>(); | ||||
|  | @ -181,7 +181,7 @@ TEST_F(CreateTest, FailsWithMissingCallbackInLiveStreamMode) { | |||
|                   MediaPipeTasksStatus::kInvalidTaskGraphConfigError)))); | ||||
| } | ||||
| 
 | ||||
| class ImageModeTest : public tflite_shims::testing::Test {}; | ||||
| class ImageModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -410,7 +410,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { | |||
|   EXPECT_LE(abs(similarity - expected_similarity), kSimilarityTolerancy); | ||||
| } | ||||
| 
 | ||||
| class VideoModeTest : public tflite_shims::testing::Test {}; | ||||
| class VideoModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -494,7 +494,7 @@ TEST_F(VideoModeTest, Succeeds) { | |||
|   MP_ASSERT_OK(image_embedder->Close()); | ||||
| } | ||||
| 
 | ||||
| class LiveStreamModeTest : public tflite_shims::testing::Test {}; | ||||
| class LiveStreamModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  |  | |||
|  | @ -39,9 +39,9 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_result.h" | ||||
| #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace tasks { | ||||
|  | @ -180,7 +180,7 @@ class DeepLabOpResolver : public ::tflite::MutableOpResolver { | |||
|   DeepLabOpResolver(const DeepLabOpResolver& r) = delete; | ||||
| }; | ||||
| 
 | ||||
| class CreateFromOptionsTest : public tflite_shims::testing::Test {}; | ||||
| class CreateFromOptionsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { | ||||
|  public: | ||||
|  | @ -268,7 +268,7 @@ TEST(GetLabelsTest, SucceedsWithLabelsInModel) { | |||
|   } | ||||
| } | ||||
| 
 | ||||
| class ImageModeTest : public tflite_shims::testing::Test {}; | ||||
| class ImageModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ImageModeTest, SucceedsWithCategoryMask) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -521,7 +521,7 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { | |||
|               SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); | ||||
| } | ||||
| 
 | ||||
| class VideoModeTest : public tflite_shims::testing::Test {}; | ||||
| class VideoModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN( | ||||
|  | @ -581,7 +581,7 @@ TEST_F(VideoModeTest, Succeeds) { | |||
|   MP_ASSERT_OK(segmenter->Close()); | ||||
| } | ||||
| 
 | ||||
| class LiveStreamModeTest : public tflite_shims::testing::Test {}; | ||||
| class LiveStreamModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( | ||||
|  |  | |||
|  | @ -64,7 +64,6 @@ using ::mediapipe::CalculatorGraphConfig; | |||
| using ::mediapipe::Image; | ||||
| using ::mediapipe::NormalizedRect; | ||||
| using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterResult; | ||||
| using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; | ||||
| using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: | ||||
|     image_segmenter::proto::ImageSegmenterGraphOptions; | ||||
| 
 | ||||
|  |  | |||
|  | @ -39,9 +39,9 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| #include "testing/base/public/gmock.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
|  | @ -124,7 +124,7 @@ MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold, | |||
|          similarity_threshold; | ||||
| } | ||||
| 
 | ||||
| class CreateFromOptionsTest : public tflite_shims::testing::Test {}; | ||||
| class CreateFromOptionsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { | ||||
|  public: | ||||
|  | @ -261,7 +261,7 @@ INSTANTIATE_TEST_SUITE_P( | |||
|     [](const ::testing::TestParamInfo<SucceedSegmentationWithRoi::ParamType>& | ||||
|            info) { return info.param.test_name; }); | ||||
| 
 | ||||
| class ImageModeTest : public tflite_shims::testing::Test {}; | ||||
| class ImageModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| // TODO: fix this unit test after image segmenter handled post
 | ||||
| // processing correctly with rotated image.
 | ||||
|  |  | |||
|  | @ -43,9 +43,9 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||
| #include "tensorflow/lite/c/common.h" | ||||
| #include "tensorflow/lite/core/api/op_resolver.h" | ||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||
| #include "tensorflow/lite/kernels/builtin_op_kernels.h" | ||||
| #include "tensorflow/lite/mutable_op_resolver.h" | ||||
| #include "tensorflow/lite/test_util.h" | ||||
| 
 | ||||
| namespace tflite { | ||||
| namespace ops { | ||||
|  | @ -159,7 +159,7 @@ class MobileSsdQuantizedOpResolver : public ::tflite::MutableOpResolver { | |||
|   MobileSsdQuantizedOpResolver(const MobileSsdQuantizedOpResolver& r) = delete; | ||||
| }; | ||||
| 
 | ||||
| class CreateFromOptionsTest : public tflite_shims::testing::Test {}; | ||||
| class CreateFromOptionsTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(CreateFromOptionsTest, SucceedsWithSelectiveOpResolver) { | ||||
|   auto options = std::make_unique<ObjectDetectorOptions>(); | ||||
|  | @ -332,7 +332,7 @@ TEST_F(CreateFromOptionsTest, InputTensorSpecsForEfficientDetModel) { | |||
| // TODO: Add NumThreadsTest back after having an
 | ||||
| // "acceleration configuration" field in the ObjectDetectorOptions.
 | ||||
| 
 | ||||
| class ImageModeTest : public tflite_shims::testing::Test {}; | ||||
| class ImageModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(ImageModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( | ||||
|  | @ -618,7 +618,7 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { | |||
|           MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); | ||||
| } | ||||
| 
 | ||||
| class VideoModeTest : public tflite_shims::testing::Test {}; | ||||
| class VideoModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(VideoModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( | ||||
|  | @ -673,7 +673,7 @@ TEST_F(VideoModeTest, Succeeds) { | |||
|   MP_ASSERT_OK(object_detector->Close()); | ||||
| } | ||||
| 
 | ||||
| class LiveStreamModeTest : public tflite_shims::testing::Test {}; | ||||
| class LiveStreamModeTest : public tflite::testing::Test {}; | ||||
| 
 | ||||
| TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { | ||||
|   MP_ASSERT_OK_AND_ASSIGN(Image image, DecodeImageFromFile(JoinPath( | ||||
|  |  | |||
|  | @ -97,8 +97,10 @@ cc_library( | |||
|         "//mediapipe/tasks/cc/core:model_task_graph", | ||||
|         "//mediapipe/tasks/cc/vision/pose_landmarker/proto:pose_landmarks_detector_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", | ||||
|         "//mediapipe/util:graph_builder_utils", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|     ], | ||||
|     alwayslink = 1, | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|  |  | |||
|  | @ -73,14 +73,12 @@ constexpr int kMicroSecondsPerMilliSecond = 1000; | |||
| // limit the number of frames in flight.
 | ||||
| CalculatorGraphConfig CreateGraphConfig( | ||||
|     std::unique_ptr<PoseLandmarkerGraphOptionsProto> options, | ||||
|     bool enable_flow_limiting) { | ||||
|     bool enable_flow_limiting, bool output_segmentation_masks) { | ||||
|   api2::builder::Graph graph; | ||||
|   auto& subgraph = graph.AddNode(kPoseLandmarkerGraphTypeName); | ||||
|   subgraph.GetOptions<PoseLandmarkerGraphOptionsProto>().Swap(options.get()); | ||||
|   graph.In(kImageTag).SetName(kImageInStreamName); | ||||
|   graph.In(kNormRectTag).SetName(kNormRectStreamName); | ||||
|   subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> | ||||
|       graph.Out(kSegmentationMaskTag); | ||||
|   subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >> | ||||
|       graph.Out(kNormLandmarksTag); | ||||
|   subgraph.Out(kPoseWorldLandmarksTag).SetName(kPoseWorldLandmarksStreamName) >> | ||||
|  | @ -89,6 +87,10 @@ CalculatorGraphConfig CreateGraphConfig( | |||
|           .SetName(kPoseAuxiliaryLandmarksStreamName) >> | ||||
|       graph.Out(kPoseAuxiliaryLandmarksTag); | ||||
|   subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); | ||||
|   if (output_segmentation_masks) { | ||||
|     subgraph.Out(kSegmentationMaskTag).SetName(kSegmentationMaskStreamName) >> | ||||
|         graph.Out(kSegmentationMaskTag); | ||||
|   } | ||||
|   if (enable_flow_limiting) { | ||||
|     return tasks::core::AddFlowLimiterCalculator( | ||||
|         graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag); | ||||
|  | @ -187,7 +189,8 @@ absl::StatusOr<std::unique_ptr<PoseLandmarker>> PoseLandmarker::Create( | |||
|                                           PoseLandmarkerGraphOptionsProto>( | ||||
|           CreateGraphConfig( | ||||
|               std::move(options_proto), | ||||
|               options->running_mode == core::RunningMode::LIVE_STREAM), | ||||
|               options->running_mode == core::RunningMode::LIVE_STREAM, | ||||
|               options->output_segmentation_masks), | ||||
|           std::move(options->base_options.op_resolver), options->running_mode, | ||||
|           std::move(packets_callback)))); | ||||
| 
 | ||||
|  |  | |||
|  | @ -90,7 +90,7 @@ struct PoseLandmarkerOutputs { | |||
|   Source<std::vector<NormalizedLandmarkList>> auxiliary_landmark_lists; | ||||
|   Source<std::vector<NormalizedRect>> pose_rects_next_frame; | ||||
|   Source<std::vector<Detection>> pose_detections; | ||||
|   Source<std::vector<Image>> segmentation_masks; | ||||
|   std::optional<Source<std::vector<Image>>> segmentation_masks; | ||||
|   Source<Image> image; | ||||
| }; | ||||
| 
 | ||||
|  | @ -183,8 +183,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, | |||
| //   input_stream: "IMAGE:image_in"
 | ||||
| //   input_stream: "NORM_RECT:norm_rect"
 | ||||
| //   output_stream: "NORM_LANDMARKS:pose_landmarks"
 | ||||
| //   output_stream: "LANDMARKS:world_landmarks"
 | ||||
| //   output_stream: "NORM_LANDMAKRS:auxiliary_landmarks"
 | ||||
| //   output_stream: "WORLD_LANDMARKS:world_landmarks"
 | ||||
| //   output_stream: "AUXILIARY_LANDMARKS:auxiliary_landmarks"
 | ||||
| //   output_stream: "POSE_RECTS_NEXT_FRAME:pose_rects_next_frame"
 | ||||
| //   output_stream: "POSE_RECTS:pose_rects"
 | ||||
| //   output_stream: "SEGMENTATION_MASK:segmentation_masks"
 | ||||
|  | @ -212,6 +212,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { | |||
|   absl::StatusOr<CalculatorGraphConfig> GetConfig( | ||||
|       SubgraphContext* sc) override { | ||||
|     Graph graph; | ||||
|     bool output_segmentation_masks = | ||||
|         HasOutput(sc->OriginalNode(), kSegmentationMaskTag); | ||||
|     if (sc->Options<PoseLandmarkerGraphOptions>() | ||||
|             .base_options() | ||||
|             .has_model_asset()) { | ||||
|  | @ -226,12 +228,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { | |||
|           !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) | ||||
|                .IsAvailable())); | ||||
|     } | ||||
|     ASSIGN_OR_RETURN( | ||||
|         auto outs, | ||||
|     ASSIGN_OR_RETURN(auto outs, | ||||
|                      BuildPoseLandmarkerGraph( | ||||
|                          *sc->MutableOptions<PoseLandmarkerGraphOptions>(), | ||||
|                          graph[Input<Image>(kImageTag)], | ||||
|             graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); | ||||
|                          graph[Input<NormalizedRect>::Optional(kNormRectTag)], | ||||
|                          graph, output_segmentation_masks)); | ||||
|     outs.landmark_lists >> | ||||
|         graph[Output<std::vector<NormalizedLandmarkList>>(kNormLandmarksTag)]; | ||||
|     outs.world_landmark_lists >> | ||||
|  | @ -241,11 +243,13 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { | |||
|             kAuxiliaryLandmarksTag)]; | ||||
|     outs.pose_rects_next_frame >> | ||||
|         graph[Output<std::vector<NormalizedRect>>(kPoseRectsNextFrameTag)]; | ||||
|     outs.segmentation_masks >> | ||||
|         graph[Output<std::vector<Image>>(kSegmentationMaskTag)]; | ||||
|     outs.pose_detections >> | ||||
|         graph[Output<std::vector<Detection>>(kDetectionsTag)]; | ||||
|     outs.image >> graph[Output<Image>(kImageTag)]; | ||||
|     if (outs.segmentation_masks) { | ||||
|       *outs.segmentation_masks >> | ||||
|           graph[Output<std::vector<Image>>(kSegmentationMaskTag)]; | ||||
|     } | ||||
| 
 | ||||
|     // TODO remove when support is fixed.
 | ||||
|     // As mediapipe GraphBuilder currently doesn't support configuring
 | ||||
|  | @ -272,7 +276,8 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { | |||
|   // graph: the mediapipe graph instance to be updated.
 | ||||
|   absl::StatusOr<PoseLandmarkerOutputs> BuildPoseLandmarkerGraph( | ||||
|       PoseLandmarkerGraphOptions& tasks_options, Source<Image> image_in, | ||||
|       Source<NormalizedRect> norm_rect_in, Graph& graph) { | ||||
|       Source<NormalizedRect> norm_rect_in, Graph& graph, | ||||
|       bool output_segmentation_masks) { | ||||
|     const int max_num_poses = | ||||
|         tasks_options.pose_detector_graph_options().num_poses(); | ||||
| 
 | ||||
|  | @ -307,9 +312,12 @@ class PoseLandmarkerGraph : public core::ModelTaskGraph { | |||
|     auto pose_rects_for_next_frame = | ||||
|         pose_landmarks_detector_graph.Out(kPoseRectsNextFrameTag) | ||||
|             .Cast<std::vector<NormalizedRect>>(); | ||||
|     auto segmentation_masks = | ||||
|     std::optional<Source<std::vector<Image>>> segmentation_masks; | ||||
|     if (output_segmentation_masks) { | ||||
|       segmentation_masks = | ||||
|           pose_landmarks_detector_graph.Out(kSegmentationMaskTag) | ||||
|               .Cast<std::vector<Image>>(); | ||||
|     } | ||||
| 
 | ||||
|     if (tasks_options.base_options().use_stream_mode()) { | ||||
|       auto& previous_loopback = graph.AddNode("PreviousLoopbackCalculator"); | ||||
|  |  | |||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
		Reference in New Issue
	
	Block a user