diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 2d6948671..374478457 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -366,15 +366,15 @@ cc_library( name = "pack_media_sequence_calculator", srcs = ["pack_media_sequence_calculator.cc"], deps = [ + ":pack_media_sequence_calculator_cc_proto", "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", - "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:status", "//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence_util", "@com_google_absl//absl/container:flat_hash_map", @@ -925,21 +925,21 @@ cc_test( srcs = ["pack_media_sequence_calculator_test.cc"], deps = [ ":pack_media_sequence_calculator", + ":pack_media_sequence_calculator_cc_proto", "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", - "//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", - "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/util/sequence:media_sequence", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@org_tensorflow//tensorflow/core:protos_all_cc", ], ) diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 4bb2093da..196b3d8b7 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -17,16 +17,16 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" +#include "absl/strings/strip.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" -#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/framework/port/status.h" #include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence_util.h" #include "tensorflow/core/example/example.pb.h" @@ -36,6 +36,7 @@ namespace mediapipe { const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; const char kImageTag[] = "IMAGE"; +const char kImageLabelPrefixTag[] = "IMAGE_LABEL_"; const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_"; const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_"; const char kIntFeaturePrefixTag[] = "INT_FEATURE_"; @@ -56,7 +57,8 @@ namespace mpms = mediapipe::mediasequence; // SequenceExample will conform to the description in media_sequence.h. // // The supported input stream tags are "IMAGE", which stores the encoded -// images from the OpenCVImageEncoderCalculator, "FORWARD_FLOW_ENCODED", which +// images from the OpenCVImageEncoderCalculator, "IMAGE_LABEL", which stores +// image labels from vector, "FORWARD_FLOW_ENCODED", which // stores the encoded optical flow from the same calculator, "BBOX" which stores // bounding boxes from vector, and streams with the // "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector's @@ -112,6 +114,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { + if (absl::StartsWith(tag, kImageLabelPrefixTag)) { + cc->Inputs().Tag(tag).Set>(); + continue; + } std::string key = ""; if (tag != kImageTag) { int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; @@ -199,6 +205,16 @@ class PackMediaSequenceCalculator : public CalculatorBase { .replace_data_instead_of_append()) { for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { + if (absl::StartsWith(tag, kImageLabelPrefixTag)) { + std::string key = + std::string(absl::StripPrefix(tag, kImageLabelPrefixTag)); + mpms::ClearImageLabelString(key, sequence_.get()); + mpms::ClearImageLabelConfidence(key, sequence_.get()); + if (!key.empty() || mpms::HasImageEncoded(*sequence_)) { + mpms::ClearImageTimestamp(key, sequence_.get()); + } + continue; + } std::string key = ""; if (tag != kImageTag) { int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; @@ -343,6 +359,24 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (absl::StartsWith(tag, kImageTag) && !cc->Inputs().Tag(tag).IsEmpty()) { std::string key = ""; + if (absl::StartsWith(tag, kImageLabelPrefixTag)) { + std::string key = + std::string(absl::StripPrefix(tag, kImageLabelPrefixTag)); + std::vector labels; + std::vector confidences; + for (const auto& classification : + cc->Inputs().Tag(tag).Get>()) { + labels.push_back(classification.label()); + confidences.push_back(classification.score()); + } + if (!key.empty() || mpms::HasImageEncoded(*sequence_)) { + mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(), + sequence_.get()); + } + mpms::AddImageLabelString(key, labels, sequence_.get()); + mpms::AddImageLabelConfidence(key, confidences, sequence_.get()); + continue; + } if (tag != kImageTag) { int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1; if (tag[tag_length] == '_') { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 9d45e38e2..166e19062 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -12,27 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include +#include -#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" -#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/detection.pb.h" -#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" -#include "mediapipe/framework/port/gmock.h" -#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_imgcodecs_inc.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/timestamp.h" #include "mediapipe/util/sequence/media_sequence.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" namespace mediapipe { namespace { @@ -58,6 +58,8 @@ constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER"; constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST"; constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER"; constexpr char kIntFeatureTestTag[] = "INT_FEATURE_TEST"; +constexpr char kImageLabelTestTag[] = "IMAGE_LABEL_TEST"; +constexpr char kImageLabelOtherTag[] = "IMAGE_LABEL_OTHER"; constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; constexpr char kImageTag[] = "IMAGE"; @@ -313,6 +315,68 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBytesLists) { } } +TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) { + SetUpCalculator( + {"IMAGE_LABEL_TEST:test_labels", "IMAGE_LABEL_OTHER:test_labels2"}, {}, + false, true); + auto input_sequence = ::absl::make_unique(); + + int num_timesteps = 2; + for (int i = 0; i < num_timesteps; ++i) { + Classification cls; + cls.set_label(absl::StrCat("foo", 2 << i)); + cls.set_score(0.1 * i); + auto label_ptr = ::absl::make_unique>(2, cls); + runner_->MutableInputs() + ->Tag(kImageLabelTestTag) + .packets.push_back(Adopt(label_ptr.release()).At(Timestamp(i))); + cls.set_label(absl::StrCat("bar", 2 << i)); + cls.set_score(0.2 * i); + label_ptr = ::absl::make_unique>(2, cls); + runner_->MutableInputs() + ->Tag(kImageLabelOtherTag) + .packets.push_back(Adopt(label_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + MP_ASSERT_OK(runner_->Run()); + + const std::vector& output_packets = + runner_->Outputs().Tag(kSequenceExampleTag).packets; + ASSERT_EQ(1, output_packets.size()); + const tf::SequenceExample& output_sequence = + output_packets[0].Get(); + + ASSERT_EQ(num_timesteps, + mpms::GetImageTimestampSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelStringSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelConfidenceSize("TEST", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageTimestampSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelStringSize("OTHER", output_sequence)); + ASSERT_EQ(num_timesteps, + mpms::GetImageLabelConfidenceSize("OTHER", output_sequence)); + for (int i = 0; i < num_timesteps; ++i) { + ASSERT_EQ(i, mpms::GetImageTimestampAt("TEST", output_sequence, i)); + ASSERT_THAT(mpms::GetImageLabelStringAt("TEST", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("foo", 2 << i)))); + ASSERT_THAT(mpms::GetImageLabelConfidenceAt("TEST", output_sequence, i), + ::testing::ElementsAreArray(std::vector(2, 0.1 * i))); + ASSERT_EQ(i, mpms::GetImageTimestampAt("OTHER", output_sequence, i)); + ASSERT_THAT(mpms::GetImageLabelStringAt("OTHER", output_sequence, i), + ::testing::ElementsAreArray( + std::vector(2, absl::StrCat("bar", 2 << i)))); + ASSERT_THAT(mpms::GetImageLabelConfidenceAt("OTHER", output_sequence, i), + ::testing::ElementsAreArray(std::vector(2, 0.2 * i))); + } +} + TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) { SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true); auto input_sequence = ::absl::make_unique();