diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index dd2870e09..4af094f13 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -370,7 +370,6 @@ cc_library( ":pack_media_sequence_calculator_cc_proto", "//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", @@ -932,7 +931,6 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework:timestamp", - "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 9185e22a5..7a1f24722 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -22,7 +23,6 @@ #include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h" #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" @@ -61,7 +61,7 @@ namespace mpms = mediapipe::mediasequence; // The supported input stream tags are: // * "IMAGE", which stores the encoded images from the // OpenCVImageEncoderCalculator, -// * "IMAGE_LABEL", which stores image labels from vector, +// * "IMAGE_LABEL", which stores whole image labels from Detection, // * "FORWARD_FLOW_ENCODED", which stores the encoded optical flow from the same // calculator, // * "BBOX" which stores bounding boxes from vector, @@ -124,7 +124,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { for (const auto& tag : cc->Inputs().GetTags()) { if (absl::StartsWith(tag, kImageTag)) { if (absl::StartsWith(tag, kImageLabelPrefixTag)) { - cc->Inputs().Tag(tag).Set>(); + cc->Inputs().Tag(tag).Set(); continue; } std::string key = ""; @@ -377,19 +377,29 @@ class PackMediaSequenceCalculator : public CalculatorBase { if (absl::StartsWith(tag, kImageLabelPrefixTag)) { std::string key = std::string(absl::StripPrefix(tag, kImageLabelPrefixTag)); - std::vector labels; - std::vector confidences; - for (const auto& classification : - cc->Inputs().Tag(tag).Get>()) { - labels.push_back(classification.label()); - confidences.push_back(classification.score()); + const auto& detection = cc->Inputs().Tag(tag).Get(); + if (detection.label().empty()) continue; + RET_CHECK(detection.label_size() == detection.score_size()) + << "Wrong image label data format: " << detection.label_size() + << " vs " << detection.score_size(); + if (!detection.label_id().empty()) { + RET_CHECK(detection.label_id_size() == detection.label_size()) + << "Wrong image label ID format: " << detection.label_id_size() + << " vs " << detection.label_size(); } + std::vector labels(detection.label().begin(), + detection.label().end()); + std::vector confidences(detection.score().begin(), + detection.score().end()); + std::vector ids(detection.label_id().begin(), + detection.label_id().end()); if (!key.empty() || mpms::HasImageEncoded(*sequence_)) { mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(), sequence_.get()); } mpms::AddImageLabelString(key, labels, sequence_.get()); mpms::AddImageLabelConfidence(key, confidences, sequence_.get()); + if (!ids.empty()) mpms::AddImageLabelIndex(key, ids, sequence_.get()); continue; } if (tag != kImageTag) { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index fa3e0bdea..a91074f07 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -21,7 +22,6 @@ #include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" -#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location_opencv.h" @@ -329,21 +329,27 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) { int num_timesteps = 2; for (int i = 0; i < num_timesteps; ++i) { - Classification cls; - cls.set_label(absl::StrCat("foo", 2 << i)); - cls.set_score(0.1 * i); - auto label_ptr = ::absl::make_unique>(2, cls); + Detection detection1; + detection1.add_label(absl::StrCat("foo", 2 << i)); + detection1.add_label_id(i); + detection1.add_score(0.1 * i); + detection1.add_label(absl::StrCat("foo", 2 << i)); + detection1.add_label_id(i); + detection1.add_score(0.1 * i); + auto label_ptr1 = ::absl::make_unique(detection1); runner_->MutableInputs() ->Tag(kImageLabelTestTag) - .packets.push_back(Adopt(label_ptr.release()).At(Timestamp(i))); - cls.set_label(absl::StrCat("bar", 2 << i)); - cls.set_score(0.2 * i); - label_ptr = ::absl::make_unique>(2, cls); + .packets.push_back(Adopt(label_ptr1.release()).At(Timestamp(i))); + Detection detection2; + detection2.add_label(absl::StrCat("bar", 2 << i)); + detection2.add_score(0.2 * i); + detection2.add_label(absl::StrCat("bar", 2 << i)); + detection2.add_score(0.2 * i); + auto label_ptr2 = ::absl::make_unique(detection2); runner_->MutableInputs() ->Tag(kImageLabelOtherTag) - .packets.push_back(Adopt(label_ptr.release()).At(Timestamp(i))); + .packets.push_back(Adopt(label_ptr2.release()).At(Timestamp(i))); } - runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = Adopt(input_sequence.release()); @@ -372,6 +378,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) { ASSERT_THAT(mpms::GetImageLabelStringAt("TEST", output_sequence, i), ::testing::ElementsAreArray( std::vector(2, absl::StrCat("foo", 2 << i)))); + ASSERT_THAT(mpms::GetImageLabelIndexAt("TEST", output_sequence, i), + ::testing::ElementsAreArray(std::vector(2, i))); ASSERT_THAT(mpms::GetImageLabelConfidenceAt("TEST", output_sequence, i), ::testing::ElementsAreArray(std::vector(2, 0.1 * i))); ASSERT_EQ(i, mpms::GetImageTimestampAt("OTHER", output_sequence, i));