Change the image label input from Classification to Detection.
PiperOrigin-RevId: 559828139
This commit is contained in:
parent
f2e9a553d6
commit
c56f45bce5
|
@ -370,7 +370,6 @@ cc_library(
|
||||||
":pack_media_sequence_calculator_cc_proto",
|
":pack_media_sequence_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
|
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:location",
|
"//mediapipe/framework/formats:location",
|
||||||
"//mediapipe/framework/formats:location_opencv",
|
"//mediapipe/framework/formats:location_opencv",
|
||||||
|
@ -932,7 +931,6 @@ cc_test(
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/framework:timestamp",
|
"//mediapipe/framework:timestamp",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:location",
|
"//mediapipe/framework/formats:location",
|
||||||
"//mediapipe/framework/formats:location_opencv",
|
"//mediapipe/framework/formats:location_opencv",
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -22,7 +23,6 @@
|
||||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/location.h"
|
#include "mediapipe/framework/formats/location.h"
|
||||||
#include "mediapipe/framework/formats/location_opencv.h"
|
#include "mediapipe/framework/formats/location_opencv.h"
|
||||||
|
@ -61,7 +61,7 @@ namespace mpms = mediapipe::mediasequence;
|
||||||
// The supported input stream tags are:
|
// The supported input stream tags are:
|
||||||
// * "IMAGE", which stores the encoded images from the
|
// * "IMAGE", which stores the encoded images from the
|
||||||
// OpenCVImageEncoderCalculator,
|
// OpenCVImageEncoderCalculator,
|
||||||
// * "IMAGE_LABEL", which stores image labels from vector<Classification>,
|
// * "IMAGE_LABEL", which stores whole image labels from Detection,
|
||||||
// * "FORWARD_FLOW_ENCODED", which stores the encoded optical flow from the same
|
// * "FORWARD_FLOW_ENCODED", which stores the encoded optical flow from the same
|
||||||
// calculator,
|
// calculator,
|
||||||
// * "BBOX" which stores bounding boxes from vector<Detections>,
|
// * "BBOX" which stores bounding boxes from vector<Detections>,
|
||||||
|
@ -124,7 +124,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||||
if (absl::StartsWith(tag, kImageTag)) {
|
if (absl::StartsWith(tag, kImageTag)) {
|
||||||
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
||||||
cc->Inputs().Tag(tag).Set<std::vector<Classification>>();
|
cc->Inputs().Tag(tag).Set<Detection>();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
std::string key = "";
|
std::string key = "";
|
||||||
|
@ -377,19 +377,29 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
||||||
std::string key =
|
std::string key =
|
||||||
std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
|
std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
|
||||||
std::vector<std::string> labels;
|
const auto& detection = cc->Inputs().Tag(tag).Get<Detection>();
|
||||||
std::vector<float> confidences;
|
if (detection.label().empty()) continue;
|
||||||
for (const auto& classification :
|
RET_CHECK(detection.label_size() == detection.score_size())
|
||||||
cc->Inputs().Tag(tag).Get<std::vector<Classification>>()) {
|
<< "Wrong image label data format: " << detection.label_size()
|
||||||
labels.push_back(classification.label());
|
<< " vs " << detection.score_size();
|
||||||
confidences.push_back(classification.score());
|
if (!detection.label_id().empty()) {
|
||||||
|
RET_CHECK(detection.label_id_size() == detection.label_size())
|
||||||
|
<< "Wrong image label ID format: " << detection.label_id_size()
|
||||||
|
<< " vs " << detection.label_size();
|
||||||
}
|
}
|
||||||
|
std::vector<std::string> labels(detection.label().begin(),
|
||||||
|
detection.label().end());
|
||||||
|
std::vector<float> confidences(detection.score().begin(),
|
||||||
|
detection.score().end());
|
||||||
|
std::vector<int32_t> ids(detection.label_id().begin(),
|
||||||
|
detection.label_id().end());
|
||||||
if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
|
if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
|
||||||
mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
|
mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
|
||||||
sequence_.get());
|
sequence_.get());
|
||||||
}
|
}
|
||||||
mpms::AddImageLabelString(key, labels, sequence_.get());
|
mpms::AddImageLabelString(key, labels, sequence_.get());
|
||||||
mpms::AddImageLabelConfidence(key, confidences, sequence_.get());
|
mpms::AddImageLabelConfidence(key, confidences, sequence_.get());
|
||||||
|
if (!ids.empty()) mpms::AddImageLabelIndex(key, ids, sequence_.get());
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (tag != kImageTag) {
|
if (tag != kImageTag) {
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
@ -21,7 +22,6 @@
|
||||||
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/location.h"
|
#include "mediapipe/framework/formats/location.h"
|
||||||
#include "mediapipe/framework/formats/location_opencv.h"
|
#include "mediapipe/framework/formats/location_opencv.h"
|
||||||
|
@ -329,21 +329,27 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) {
|
||||||
|
|
||||||
int num_timesteps = 2;
|
int num_timesteps = 2;
|
||||||
for (int i = 0; i < num_timesteps; ++i) {
|
for (int i = 0; i < num_timesteps; ++i) {
|
||||||
Classification cls;
|
Detection detection1;
|
||||||
cls.set_label(absl::StrCat("foo", 2 << i));
|
detection1.add_label(absl::StrCat("foo", 2 << i));
|
||||||
cls.set_score(0.1 * i);
|
detection1.add_label_id(i);
|
||||||
auto label_ptr = ::absl::make_unique<std::vector<Classification>>(2, cls);
|
detection1.add_score(0.1 * i);
|
||||||
|
detection1.add_label(absl::StrCat("foo", 2 << i));
|
||||||
|
detection1.add_label_id(i);
|
||||||
|
detection1.add_score(0.1 * i);
|
||||||
|
auto label_ptr1 = ::absl::make_unique<Detection>(detection1);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag(kImageLabelTestTag)
|
->Tag(kImageLabelTestTag)
|
||||||
.packets.push_back(Adopt(label_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(label_ptr1.release()).At(Timestamp(i)));
|
||||||
cls.set_label(absl::StrCat("bar", 2 << i));
|
Detection detection2;
|
||||||
cls.set_score(0.2 * i);
|
detection2.add_label(absl::StrCat("bar", 2 << i));
|
||||||
label_ptr = ::absl::make_unique<std::vector<Classification>>(2, cls);
|
detection2.add_score(0.2 * i);
|
||||||
|
detection2.add_label(absl::StrCat("bar", 2 << i));
|
||||||
|
detection2.add_score(0.2 * i);
|
||||||
|
auto label_ptr2 = ::absl::make_unique<Detection>(detection2);
|
||||||
runner_->MutableInputs()
|
runner_->MutableInputs()
|
||||||
->Tag(kImageLabelOtherTag)
|
->Tag(kImageLabelOtherTag)
|
||||||
.packets.push_back(Adopt(label_ptr.release()).At(Timestamp(i)));
|
.packets.push_back(Adopt(label_ptr2.release()).At(Timestamp(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
|
||||||
Adopt(input_sequence.release());
|
Adopt(input_sequence.release());
|
||||||
|
|
||||||
|
@ -372,6 +378,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) {
|
||||||
ASSERT_THAT(mpms::GetImageLabelStringAt("TEST", output_sequence, i),
|
ASSERT_THAT(mpms::GetImageLabelStringAt("TEST", output_sequence, i),
|
||||||
::testing::ElementsAreArray(
|
::testing::ElementsAreArray(
|
||||||
std::vector<std::string>(2, absl::StrCat("foo", 2 << i))));
|
std::vector<std::string>(2, absl::StrCat("foo", 2 << i))));
|
||||||
|
ASSERT_THAT(mpms::GetImageLabelIndexAt("TEST", output_sequence, i),
|
||||||
|
::testing::ElementsAreArray(std::vector<int32_t>(2, i)));
|
||||||
ASSERT_THAT(mpms::GetImageLabelConfidenceAt("TEST", output_sequence, i),
|
ASSERT_THAT(mpms::GetImageLabelConfidenceAt("TEST", output_sequence, i),
|
||||||
::testing::ElementsAreArray(std::vector<float>(2, 0.1 * i)));
|
::testing::ElementsAreArray(std::vector<float>(2, 0.1 * i)));
|
||||||
ASSERT_EQ(i, mpms::GetImageTimestampAt("OTHER", output_sequence, i));
|
ASSERT_EQ(i, mpms::GetImageTimestampAt("OTHER", output_sequence, i));
|
||||||
|
|
Loading…
Reference in New Issue
Block a user