Add support for label annotations (image/label/string and image/label/confidence). Also fix some clang-tidy issues.
PiperOrigin-RevId: 553900667
This commit is contained in:
parent
11508f2291
commit
460346ed13
|
@ -366,15 +366,15 @@ cc_library(
|
||||||
name = "pack_media_sequence_calculator",
|
name = "pack_media_sequence_calculator",
|
||||||
srcs = ["pack_media_sequence_calculator.cc"],
|
srcs = ["pack_media_sequence_calculator.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":pack_media_sequence_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
|
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:location",
|
"//mediapipe/framework/formats:location",
|
||||||
"//mediapipe/framework/formats:location_opencv",
|
"//mediapipe/framework/formats:location_opencv",
|
||||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
|
||||||
"//mediapipe/util/sequence:media_sequence",
|
"//mediapipe/util/sequence:media_sequence",
|
||||||
"//mediapipe/util/sequence:media_sequence_util",
|
"//mediapipe/util/sequence:media_sequence_util",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
@ -925,21 +925,21 @@ cc_test(
|
||||||
srcs = ["pack_media_sequence_calculator_test.cc"],
|
srcs = ["pack_media_sequence_calculator_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":pack_media_sequence_calculator",
|
":pack_media_sequence_calculator",
|
||||||
|
":pack_media_sequence_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
|
"//mediapipe/calculators/image:opencv_image_encoder_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/tensorflow:pack_media_sequence_calculator_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework:timestamp",
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/framework/formats:detection_cc_proto",
|
"//mediapipe/framework/formats:detection_cc_proto",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
|
||||||
"//mediapipe/framework/formats:location",
|
"//mediapipe/framework/formats:location",
|
||||||
"//mediapipe/framework/formats:location_opencv",
|
"//mediapipe/framework/formats:location_opencv",
|
||||||
"//mediapipe/framework/port:gtest_main",
|
"//mediapipe/framework/port:gtest_main",
|
||||||
"//mediapipe/framework/port:opencv_imgcodecs",
|
"//mediapipe/framework/port:opencv_imgcodecs",
|
||||||
"//mediapipe/util/sequence:media_sequence",
|
"//mediapipe/util/sequence:media_sequence",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
"@org_tensorflow//tensorflow/core:protos_all_cc",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,16 +17,16 @@
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "absl/strings/match.h"
|
#include "absl/strings/match.h"
|
||||||
|
#include "absl/strings/strip.h"
|
||||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/location.h"
|
#include "mediapipe/framework/formats/location.h"
|
||||||
#include "mediapipe/framework/formats/location_opencv.h"
|
#include "mediapipe/framework/formats/location_opencv.h"
|
||||||
#include "mediapipe/framework/port/canonical_errors.h"
|
|
||||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
|
||||||
#include "mediapipe/util/sequence/media_sequence.h"
|
#include "mediapipe/util/sequence/media_sequence.h"
|
||||||
#include "mediapipe/util/sequence/media_sequence_util.h"
|
#include "mediapipe/util/sequence/media_sequence_util.h"
|
||||||
#include "tensorflow/core/example/example.pb.h"
|
#include "tensorflow/core/example/example.pb.h"
|
||||||
|
@ -36,6 +36,7 @@ namespace mediapipe {
|
||||||
|
|
||||||
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||||
const char kImageTag[] = "IMAGE";
|
const char kImageTag[] = "IMAGE";
|
||||||
|
const char kImageLabelPrefixTag[] = "IMAGE_LABEL_";
|
||||||
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
const char kFloatContextFeaturePrefixTag[] = "FLOAT_CONTEXT_FEATURE_";
|
||||||
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
const char kFloatFeaturePrefixTag[] = "FLOAT_FEATURE_";
|
||||||
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
const char kIntFeaturePrefixTag[] = "INT_FEATURE_";
|
||||||
|
@ -56,7 +57,8 @@ namespace mpms = mediapipe::mediasequence;
|
||||||
// SequenceExample will conform to the description in media_sequence.h.
|
// SequenceExample will conform to the description in media_sequence.h.
|
||||||
//
|
//
|
||||||
// The supported input stream tags are "IMAGE", which stores the encoded
|
// The supported input stream tags are "IMAGE", which stores the encoded
|
||||||
// images from the OpenCVImageEncoderCalculator, "FORWARD_FLOW_ENCODED", which
|
// images from the OpenCVImageEncoderCalculator, "IMAGE_LABEL_${NAME}", which stores
|
||||||
|
// image labels from vector<Classification>, "FORWARD_FLOW_ENCODED", which
|
||||||
// stores the encoded optical flow from the same calculator, "BBOX" which stores
|
// stores the encoded optical flow from the same calculator, "BBOX" which stores
|
||||||
// bounding boxes from vector<Detections>, and streams with the
|
// bounding boxes from vector<Detections>, and streams with the
|
||||||
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
|
// "FLOAT_FEATURE_${NAME}" pattern, which stores the values from vector<float>'s
|
||||||
|
@ -112,6 +114,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
|
|
||||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||||
if (absl::StartsWith(tag, kImageTag)) {
|
if (absl::StartsWith(tag, kImageTag)) {
|
||||||
|
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
||||||
|
cc->Inputs().Tag(tag).Set<std::vector<Classification>>();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
std::string key = "";
|
std::string key = "";
|
||||||
if (tag != kImageTag) {
|
if (tag != kImageTag) {
|
||||||
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
|
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
|
||||||
|
@ -199,6 +205,16 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
.replace_data_instead_of_append()) {
|
.replace_data_instead_of_append()) {
|
||||||
for (const auto& tag : cc->Inputs().GetTags()) {
|
for (const auto& tag : cc->Inputs().GetTags()) {
|
||||||
if (absl::StartsWith(tag, kImageTag)) {
|
if (absl::StartsWith(tag, kImageTag)) {
|
||||||
|
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
||||||
|
std::string key =
|
||||||
|
std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
|
||||||
|
mpms::ClearImageLabelString(key, sequence_.get());
|
||||||
|
mpms::ClearImageLabelConfidence(key, sequence_.get());
|
||||||
|
if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
|
||||||
|
mpms::ClearImageTimestamp(key, sequence_.get());
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
std::string key = "";
|
std::string key = "";
|
||||||
if (tag != kImageTag) {
|
if (tag != kImageTag) {
|
||||||
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
|
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
|
||||||
|
@ -343,6 +359,24 @@ class PackMediaSequenceCalculator : public CalculatorBase {
|
||||||
if (absl::StartsWith(tag, kImageTag) &&
|
if (absl::StartsWith(tag, kImageTag) &&
|
||||||
!cc->Inputs().Tag(tag).IsEmpty()) {
|
!cc->Inputs().Tag(tag).IsEmpty()) {
|
||||||
std::string key = "";
|
std::string key = "";
|
||||||
|
if (absl::StartsWith(tag, kImageLabelPrefixTag)) {
|
||||||
|
std::string key =
|
||||||
|
std::string(absl::StripPrefix(tag, kImageLabelPrefixTag));
|
||||||
|
std::vector<std::string> labels;
|
||||||
|
std::vector<float> confidences;
|
||||||
|
for (const auto& classification :
|
||||||
|
cc->Inputs().Tag(tag).Get<std::vector<Classification>>()) {
|
||||||
|
labels.push_back(classification.label());
|
||||||
|
confidences.push_back(classification.score());
|
||||||
|
}
|
||||||
|
if (!key.empty() || mpms::HasImageEncoded(*sequence_)) {
|
||||||
|
mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
|
||||||
|
sequence_.get());
|
||||||
|
}
|
||||||
|
mpms::AddImageLabelString(key, labels, sequence_.get());
|
||||||
|
mpms::AddImageLabelConfidence(key, confidences, sequence_.get());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (tag != kImageTag) {
|
if (tag != kImageTag) {
|
||||||
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
|
int tag_length = sizeof(kImageTag) / sizeof(*kImageTag) - 1;
|
||||||
if (tag[tag_length] == '_') {
|
if (tag[tag_length] == '_') {
|
||||||
|
|
|
@ -12,27 +12,27 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <algorithm>
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_map.h"
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/numbers.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
#include "mediapipe/calculators/image/opencv_image_encoder_calculator.pb.h"
|
||||||
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
#include "mediapipe/calculators/tensorflow/pack_media_sequence_calculator.pb.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/framework/formats/detection.pb.h"
|
#include "mediapipe/framework/formats/detection.pb.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
|
||||||
#include "mediapipe/framework/formats/location.h"
|
#include "mediapipe/framework/formats/location.h"
|
||||||
#include "mediapipe/framework/formats/location_opencv.h"
|
#include "mediapipe/framework/formats/location_opencv.h"
|
||||||
#include "mediapipe/framework/port/gmock.h"
|
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
|
||||||
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
#include "mediapipe/framework/port/opencv_imgcodecs_inc.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/util/sequence/media_sequence.h"
|
#include "mediapipe/util/sequence/media_sequence.h"
|
||||||
#include "tensorflow/core/example/example.pb.h"
|
#include "tensorflow/core/example/example.pb.h"
|
||||||
#include "tensorflow/core/example/feature.pb.h"
|
#include "tensorflow/core/example/feature.pb.h"
|
||||||
|
#include "testing/base/public/gmock.h"
|
||||||
|
#include "testing/base/public/gunit.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -58,6 +58,8 @@ constexpr char kFloatFeatureOtherTag[] = "FLOAT_FEATURE_OTHER";
|
||||||
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
constexpr char kFloatFeatureTestTag[] = "FLOAT_FEATURE_TEST";
|
||||||
constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
|
constexpr char kIntFeatureOtherTag[] = "INT_FEATURE_OTHER";
|
||||||
constexpr char kIntFeatureTestTag[] = "INT_FEATURE_TEST";
|
constexpr char kIntFeatureTestTag[] = "INT_FEATURE_TEST";
|
||||||
|
constexpr char kImageLabelTestTag[] = "IMAGE_LABEL_TEST";
|
||||||
|
constexpr char kImageLabelOtherTag[] = "IMAGE_LABEL_OTHER";
|
||||||
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
constexpr char kImagePrefixTag[] = "IMAGE_PREFIX";
|
||||||
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE";
|
||||||
constexpr char kImageTag[] = "IMAGE";
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
@ -313,6 +315,68 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBytesLists) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies that two IMAGE_LABEL_<KEY> input streams are packed side by side:
// for every timestep, each key must end up with one image timestamp, one list
// of label strings and one list of confidences in the output SequenceExample.
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImageLabels) {
  SetUpCalculator(
      {"IMAGE_LABEL_TEST:test_labels", "IMAGE_LABEL_OTHER:test_labels2"}, {},
      false, true);
  auto input_sequence = ::absl::make_unique<tf::SequenceExample>();

  // One row per labelled stream: the input tag it is fed through, the key the
  // test reads it back under, and the label/score pattern used to fill it.
  struct LabelStream {
    const char* tag;
    const char* key;
    const char* label_stem;
    double score_step;
  };
  const LabelStream kStreams[] = {
      {kImageLabelTestTag, "TEST", "foo", 0.1},
      {kImageLabelOtherTag, "OTHER", "bar", 0.2},
  };

  int num_timesteps = 2;
  // Feed every stream one packet per timestep; each packet carries two
  // identical classifications whose label and score depend on the timestep.
  for (int step = 0; step < num_timesteps; ++step) {
    for (const LabelStream& stream : kStreams) {
      Classification cls;
      cls.set_label(absl::StrCat(stream.label_stem, 2 << step));
      cls.set_score(stream.score_step * step);
      auto label_ptr =
          ::absl::make_unique<std::vector<Classification>>(2, cls);
      runner_->MutableInputs()
          ->Tag(stream.tag)
          .packets.push_back(Adopt(label_ptr.release()).At(Timestamp(step)));
    }
  }

  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
      Adopt(input_sequence.release());

  MP_ASSERT_OK(runner_->Run());

  const std::vector<Packet>& output_packets =
      runner_->Outputs().Tag(kSequenceExampleTag).packets;
  ASSERT_EQ(1, output_packets.size());
  const tf::SequenceExample& output_sequence =
      output_packets[0].Get<tf::SequenceExample>();

  // Every key must hold exactly one entry per timestep for all three features.
  for (const LabelStream& stream : kStreams) {
    ASSERT_EQ(num_timesteps,
              mpms::GetImageTimestampSize(stream.key, output_sequence));
    ASSERT_EQ(num_timesteps,
              mpms::GetImageLabelStringSize(stream.key, output_sequence));
    ASSERT_EQ(num_timesteps,
              mpms::GetImageLabelConfidenceSize(stream.key, output_sequence));
  }

  // The stored values must mirror what was fed in, timestep by timestep.
  for (int step = 0; step < num_timesteps; ++step) {
    for (const LabelStream& stream : kStreams) {
      ASSERT_EQ(step,
                mpms::GetImageTimestampAt(stream.key, output_sequence, step));
      ASSERT_THAT(
          mpms::GetImageLabelStringAt(stream.key, output_sequence, step),
          ::testing::ElementsAreArray(std::vector<std::string>(
              2, absl::StrCat(stream.label_stem, 2 << step))));
      ASSERT_THAT(
          mpms::GetImageLabelConfidenceAt(stream.key, output_sequence, step),
          ::testing::ElementsAreArray(
              std::vector<float>(2, stream.score_step * step)));
    }
  }
}
|
||||||
|
|
||||||
TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) {
|
TEST_F(PackMediaSequenceCalculatorTest, OutputAsZeroTimestamp) {
|
||||||
SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true);
|
SetUpCalculator({"FLOAT_FEATURE_TEST:test"}, {}, false, true, true);
|
||||||
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
|
||||||
|
|
Loading…
Reference in New Issue
Block a user