From d61ab92b90c5345e1bdd47b9545f745e6bed5ea8 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Wed, 2 Nov 2022 05:11:16 -0700 Subject: [PATCH] Add classification result structs and use then in ImageClassifier & TextClassifier. PiperOrigin-RevId: 485567133 --- .../tasks/cc/components/containers/BUILD | 20 + .../cc/components/containers/category.cc | 38 ++ .../tasks/cc/components/containers/category.h | 52 +++ .../containers/classification_result.cc | 57 +++ .../containers/classification_result.h | 68 ++++ mediapipe/tasks/cc/text/text_classifier/BUILD | 5 +- .../text/text_classifier/text_classifier.cc | 17 +- .../cc/text/text_classifier/text_classifier.h | 9 +- .../text_classifier/text_classifier_graph.cc | 30 +- .../text_classifier/text_classifier_test.cc | 118 +++++- .../tasks/cc/vision/image_classifier/BUILD | 1 + .../image_classifier/image_classifier.cc | 37 +- .../image_classifier/image_classifier.h | 23 +- .../image_classifier_graph.cc | 16 +- .../image_classifier/image_classifier_test.cc | 344 +++++++----------- 15 files changed, 559 insertions(+), 276 deletions(-) create mode 100644 mediapipe/tasks/cc/components/containers/category.cc create mode 100644 mediapipe/tasks/cc/components/containers/category.h create mode 100644 mediapipe/tasks/cc/components/containers/classification_result.cc create mode 100644 mediapipe/tasks/cc/components/containers/classification_result.h diff --git a/mediapipe/tasks/cc/components/containers/BUILD b/mediapipe/tasks/cc/components/containers/BUILD index af51d0c37..7a52f11e0 100644 --- a/mediapipe/tasks/cc/components/containers/BUILD +++ b/mediapipe/tasks/cc/components/containers/BUILD @@ -29,3 +29,23 @@ cc_library( "//mediapipe/framework/formats:landmark_cc_proto", ], ) + +cc_library( + name = "category", + srcs = ["category.cc"], + hdrs = ["category.h"], + deps = [ + "//mediapipe/framework/formats:classification_cc_proto", + ], +) + +cc_library( + name = "classification_result", + srcs = ["classification_result.cc"], + hdrs = ["classification_result.h"], + deps = [ + ":category", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + ], +) diff --git a/mediapipe/tasks/cc/components/containers/category.cc b/mediapipe/tasks/cc/components/containers/category.cc new file mode 100644 index 000000000..e07333a7b --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/category.cc @@ -0,0 +1,38 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/category.h" + +#include +#include + +#include "mediapipe/framework/formats/classification.pb.h" + +namespace mediapipe::tasks::components::containers { + +Category ConvertToCategory(const mediapipe::Classification& proto) { + Category category; + category.index = proto.index(); + category.score = proto.score(); + if (proto.has_label()) { + category.category_name = proto.label(); + } + if (proto.has_display_name()) { + category.display_name = proto.display_name(); + } + return category; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/category.h b/mediapipe/tasks/cc/components/containers/category.h new file mode 100644 index 000000000..57b18e7ea --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/category.h @@ -0,0 +1,52 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_ + +#include +#include + +#include "mediapipe/framework/formats/classification.pb.h" + +namespace mediapipe::tasks::components::containers { + +// Defines a single classification result. +// +// The label maps packed into the TFLite Model Metadata [1] are used to populate +// the 'category_name' and 'display_name' fields. +// +// [1]: https://www.tensorflow.org/lite/convert/metadata +struct Category { + // The index of the category in the classification model output. + int index; + // The score for this category, e.g. (but not necessarily) a probability in + // [0,1]. + float score; + // The optional ID for the category, read from the label map packed in the + // TFLite Model Metadata if present. Not necessarily human-readable. + std::optional category_name = std::nullopt; + // The optional human-readable name for the category, read from the label map + // packed in the TFLite Model Metadata if present. + std::optional display_name = std::nullopt; +}; + +// Utility function to convert from mediapipe::Classification proto to Category +// struct. +Category ConvertToCategory(const mediapipe::Classification& proto); + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_ diff --git a/mediapipe/tasks/cc/components/containers/classification_result.cc b/mediapipe/tasks/cc/components/containers/classification_result.cc new file mode 100644 index 000000000..98583ff15 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/classification_result.cc @@ -0,0 +1,57 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/components/containers/classification_result.h" + +#include +#include +#include + +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace mediapipe::tasks::components::containers { + +Classifications ConvertToClassifications(const proto::Classifications& proto) { + Classifications classifications; + classifications.categories.reserve( + proto.classification_list().classification_size()); + for (const auto& classification : + proto.classification_list().classification()) { + classifications.categories.push_back(ConvertToCategory(classification)); + } + classifications.head_index = proto.head_index(); + if (proto.has_head_name()) { + classifications.head_name = proto.head_name(); + } + return classifications; +} + +ClassificationResult ConvertToClassificationResult( + const proto::ClassificationResult& proto) { + ClassificationResult classification_result; + classification_result.classifications.reserve(proto.classifications_size()); + for (const auto& classifications : proto.classifications()) { + classification_result.classifications.push_back( + ConvertToClassifications(classifications)); + } + if (proto.has_timestamp_ms()) { + classification_result.timestamp_ms = proto.timestamp_ms(); + } + return classification_result; +} + +} // namespace mediapipe::tasks::components::containers diff --git a/mediapipe/tasks/cc/components/containers/classification_result.h b/mediapipe/tasks/cc/components/containers/classification_result.h new file mode 100644 index 000000000..88273fd00 --- /dev/null +++ b/mediapipe/tasks/cc/components/containers/classification_result.h @@ -0,0 +1,68 @@ +/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ +#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ + +#include +#include +#include + +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace mediapipe::tasks::components::containers { + +// Defines classification results for a given classifier head. +struct Classifications { + // The array of predicted categories, usually sorted by descending scores, + // e.g. from high to low probability. + std::vector categories; + // The index of the classifier head (i.e. output tensor) these categories + // refer to. This is useful for multi-head models. + int head_index; + // The optional name of the classifier head, as provided in the TFLite Model + // Metadata [1] if present. This is useful for multi-head models. + // + // [1]: https://www.tensorflow.org/lite/convert/metadata + std::optional head_name = std::nullopt; +}; + +// Defines classification results of a model. +struct ClassificationResult { + // The classification results for each head of the model. + std::vector classifications; + // The optional timestamp (in milliseconds) of the start of the chunk of data + // corresponding to these results. + // + // This is only used for classification on time series (e.g. audio + // classification). In these use cases, the amount of data to process might + // exceed the maximum size that the model can process: to solve this, the + // input data is split into multiple chunks starting at different timestamps. + std::optional timestamp_ms = std::nullopt; +}; + +// Utility function to convert from Classifications proto to +// Classifications struct. +Classifications ConvertToClassifications(const proto::Classifications& proto); + +// Utility function to convert from ClassificationResult proto to +// ClassificationResult struct. +ClassificationResult ConvertToClassificationResult( + const proto::ClassificationResult& proto); + +} // namespace mediapipe::tasks::components::containers + +#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ diff --git a/mediapipe/tasks/cc/text/text_classifier/BUILD b/mediapipe/tasks/cc/text/text_classifier/BUILD index 336b1bb45..b2e1bed2a 100644 --- a/mediapipe/tasks/cc/text/text_classifier/BUILD +++ b/mediapipe/tasks/cc/text/text_classifier/BUILD @@ -49,6 +49,8 @@ cc_library( ":text_classifier_graph", "//mediapipe/framework:packet", "//mediapipe/framework/api2:builder", + "//mediapipe/tasks/cc/components/containers:category", + "//mediapipe/tasks/cc/components/containers:classification_result", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", @@ -76,7 +78,8 @@ cc_test( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:gtest_main", "//mediapipe/tasks/cc:common", - "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/components/containers:category", + "//mediapipe/tasks/cc/components/containers:classification_result", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc index 699f15bc0..d174fac47 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "mediapipe/framework/api2/builder.h" #include "mediapipe/framework/packet.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/core/task_api_factory.h" @@ -37,12 +38,13 @@ namespace text_classifier { namespace { +using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr char kTextStreamName[] = "text_in"; constexpr char kTextTag[] = "TEXT"; -constexpr char kClassificationResultStreamName[] = "classification_result_out"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kClassificationsStreamName[] = "classifications_out"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kSubgraphTypeName[] = "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; @@ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig( auto& subgraph = graph.AddNode(kSubgraphTypeName); subgraph.GetOptions().Swap(options.get()); graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); - subgraph.Out(kClassificationResultTag) - .SetName(kClassificationResultStreamName) >> - graph.Out(kClassificationResultTag); + subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> + graph.Out(kClassificationsTag); return graph.GetConfig(); } @@ -88,14 +89,14 @@ absl::StatusOr> TextClassifier::Create( std::move(options->base_options.op_resolver)); } -absl::StatusOr TextClassifier::Classify( +absl::StatusOr TextClassifier::Classify( absl::string_view text) { ASSIGN_OR_RETURN( auto output_packets, runner_->Process( {{kTextStreamName, MakePacket(std::string(text))}})); - return output_packets[kClassificationResultStreamName] - .Get(); + return ConvertToClassificationResult( + output_packets[kClassificationsStreamName].Get()); } } // namespace text_classifier diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h index b027a9787..03569c5a6 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier.h +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_task_api.h" @@ -31,6 +31,10 @@ namespace tasks { namespace text { namespace text_classifier { +// Alias the shared ClassificationResult struct as result type. +using TextClassifierResult = + ::mediapipe::tasks::components::containers::ClassificationResult; + // The options for configuring a MediaPipe text classifier task. struct TextClassifierOptions { // Base options for configuring MediaPipe Tasks, such as specifying the model @@ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi { std::unique_ptr options); // Performs classification on the input `text`. - absl::StatusOr Classify( - absl::string_view text); + absl::StatusOr Classify(absl::string_view text); // Shuts down the TextClassifier when all the work is done. absl::Status Close() { return runner_->Close(); } diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 9706db4d8..36ff68a07 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; +// TODO: remove once Java API migration is over. +// Struct holding the different output streams produced by the text classifier. +struct TextClassifierOutputStreams { + Source classification_result; + Source classifications; +}; + } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS"; // Input text to perform classification on. // // Outputs: -// CLASSIFICATION_RESULT - ClassificationResult +// CLASSIFICATIONS - ClassificationResult @Optional +// The classification results aggregated by classifier head. +// TODO: remove once Java API migration is over. +// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional // The aggregated classification result object that has 3 dimensions: // (classification head, classification timestamp, classification category). // @@ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS"; // node { // calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph" // input_stream: "TEXT:text_in" -// output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// output_stream: "CLASSIFICATIONS:classifications_out" // options { // [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext] // { @@ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - Source classification_result_out, + auto output_streams, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - classification_result_out >> + output_streams.classification_result >> graph[Output(kClassificationResultTag)]; + output_streams.classifications >> + graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr> BuildTextClassifierTask( + absl::StatusOr BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { @@ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. - return postprocessing[Output( - kClassificationResultTag)]; + return TextClassifierOutputStreams{ + /*classification_result=*/postprocessing[Output( + kClassificationResultTag)], + /*classifications=*/postprocessing[Output( + kClassificationsTag)]}; } }; diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc index 62837be8c..8f73914fc 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_test.cc @@ -33,7 +33,8 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h" @@ -43,14 +44,13 @@ namespace text { namespace text_classifier { namespace { -using ::mediapipe::EqualsProto; using ::mediapipe::file::JoinPath; using ::mediapipe::tasks::kMediaPipeTasksPayload; -using ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::tasks::components::containers::Category; +using ::mediapipe::tasks::components::containers::Classifications; using ::testing::HasSubstr; using ::testing::Optional; -constexpr float kEpsilon = 0.001; constexpr int kMaxSeqLen = 128; constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; @@ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) { return JoinPath("./", kTestDataDirectory, file_name); } +// Checks that the two provided `TextClassifierResult` are equal, with a +// tolerancy on floating-point score to account for numerical instabilities. +// TODO: create shared matcher for ClassificationResult. +void ExpectApproximatelyEqual(const TextClassifierResult& actual, + const TextClassifierResult& expected) { + const float kPrecision = 1e-6; + ASSERT_EQ(actual.classifications.size(), expected.classifications.size()); + for (int i = 0; i < actual.classifications.size(); ++i) { + const Classifications& a = actual.classifications[i]; + const Classifications& b = expected.classifications[i]; + EXPECT_EQ(a.head_index, b.head_index); + EXPECT_EQ(a.head_name, b.head_name); + EXPECT_EQ(a.categories.size(), b.categories.size()); + for (int j = 0; j < a.categories.size(); ++j) { + const Category& x = a.categories[j]; + const Category& y = b.categories[j]; + EXPECT_EQ(x.index, y.index); + EXPECT_NEAR(x.score, y.score, kPrecision); + EXPECT_EQ(x.category_name, y.category_name); + EXPECT_EQ(x.display_name, y.display_name); + } + } +} + class TextClassifierTest : public tflite_shims::testing::Test {}; TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { @@ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { MP_ASSERT_OK(TextClassifier::Create(std::move(options))); } +TEST_F(TextClassifierTest, TextClassifierWithBert) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN( + TextClassifierResult negative_result, + classifier->Classify("unflinchingly bleak and desperate")); + TextClassifierResult negative_expected; + negative_expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"}, + {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(negative_result, negative_expected); + + MP_ASSERT_OK_AND_ASSIGN( + TextClassifierResult positive_result, + classifier->Classify("it's a charming and often affecting journey")); + TextClassifierResult positive_expected; + positive_expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"}, + {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(positive_result, positive_expected); + + MP_ASSERT_OK(classifier->Close()); +} + +TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult negative_result, + classifier->Classify("What a waste of my time.")); + TextClassifierResult negative_expected; + negative_expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"}, + {/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(negative_result, negative_expected); + + MP_ASSERT_OK_AND_ASSIGN( + TextClassifierResult positive_result, + classifier->Classify("This is the best movie I’ve seen in recent years." + "Strongly recommend it!")); + TextClassifierResult positive_expected; + positive_expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"}, + {/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(positive_result, positive_expected); + + MP_ASSERT_OK(classifier->Close()); +} + +TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { + auto options = std::make_unique(); + options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); + options->base_options.op_resolver = CreateCustomResolver(); + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr classifier, + TextClassifier::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, + classifier->Classify("hello")); + + // Binary outputs causes flaky ordering, so we compare manually. + ASSERT_EQ(result.classifications.size(), 1); + ASSERT_EQ(result.classifications[0].head_index, 0); + ASSERT_EQ(result.classifications[0].categories.size(), 3); + ASSERT_EQ(result.classifications[0].categories[0].score, 1); + ASSERT_LT(result.classifications[0].categories[0].index, 2); // i.e O or 1. + ASSERT_EQ(result.classifications[0].categories[1].score, 1); + ASSERT_LT(result.classifications[0].categories[1].index, 2); // i.e 0 or 1. + ASSERT_EQ(result.classifications[0].categories[2].score, 0); + ASSERT_EQ(result.classifications[0].categories[2].index, 2); + MP_ASSERT_OK(classifier->Close()); +} + } // namespace } // namespace text_classifier } // namespace text diff --git a/mediapipe/tasks/cc/vision/image_classifier/BUILD b/mediapipe/tasks/cc/vision/image_classifier/BUILD index 3d655cd50..b59d8d682 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/BUILD +++ b/mediapipe/tasks/cc/vision/image_classifier/BUILD @@ -50,6 +50,7 @@ cc_library( "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:classification_result", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc index 8a32758f4..60f8f7ed4 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/timestamp.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" @@ -46,8 +47,8 @@ namespace image_classifier { namespace { -constexpr char kClassificationResultStreamName[] = "classification_result_out"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kClassificationsStreamName[] = "classifications_out"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageInStreamName[] = "image_in"; constexpr char kImageOutStreamName[] = "image_out"; constexpr char kImageTag[] = "IMAGE"; @@ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] = "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; constexpr int kMicroSecondsPerMilliSecond = 1000; +using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::PacketMap; @@ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig( auto& task_subgraph = graph.AddNode(kSubgraphTypeName); task_subgraph.GetOptions().Swap( options_proto.get()); - task_subgraph.Out(kClassificationResultTag) - .SetName(kClassificationResultStreamName) >> - graph.Out(kClassificationResultTag); + task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> + graph.Out(kClassificationsTag); task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); if (enable_flow_limiting) { - return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, - {kImageTag, kNormRectTag}, - kClassificationResultTag); + return tasks::core::AddFlowLimiterCalculator( + graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag); } graph.In(kImageTag) >> task_subgraph.In(kImageTag); graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); @@ -125,13 +125,14 @@ absl::StatusOr> ImageClassifier::Create( if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { return; } - Packet classification_result_packet = - status_or_packets.value()[kClassificationResultStreamName]; + Packet classifications_packet = + status_or_packets.value()[kClassificationsStreamName]; Packet image_packet = status_or_packets.value()[kImageOutStreamName]; result_callback( - classification_result_packet.Get(), + ConvertToClassificationResult( + classifications_packet.Get()), image_packet.Get(), - classification_result_packet.Timestamp().Value() / + classifications_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); }; } @@ -144,7 +145,7 @@ absl::StatusOr> ImageClassifier::Create( std::move(packets_callback)); } -absl::StatusOr ImageClassifier::Classify( +absl::StatusOr ImageClassifier::Classify( Image image, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -160,11 +161,11 @@ absl::StatusOr ImageClassifier::Classify( ProcessImageData( {{kImageInStreamName, MakePacket(std::move(image))}, {kNormRectName, MakePacket(std::move(norm_rect))}})); - return output_packets[kClassificationResultStreamName] - .Get(); + return ConvertToClassificationResult( + output_packets[kClassificationsStreamName].Get()); } -absl::StatusOr ImageClassifier::ClassifyForVideo( +absl::StatusOr ImageClassifier::ClassifyForVideo( Image image, int64 timestamp_ms, std::optional image_processing_options) { if (image.UsesGpu()) { @@ -184,8 +185,8 @@ absl::StatusOr ImageClassifier::ClassifyForVideo( {kNormRectName, MakePacket(std::move(norm_rect)) .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); - return output_packets[kClassificationResultStreamName] - .Get(); + return ConvertToClassificationResult( + output_packets[kClassificationsStreamName].Get()); } absl::Status ImageClassifier::ClassifyAsync( diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h index de69b7994..9b0c376ae 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mediapipe/framework/formats/image.h" -#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h" #include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" @@ -34,6 +34,10 @@ namespace tasks { namespace vision { namespace image_classifier { +// Alias the shared ClassificationResult struct as result type. +using ImageClassifierResult = + ::mediapipe::tasks::components::containers::ClassificationResult; + // The options for configuring a Mediapipe image classifier task. struct ImageClassifierOptions { // Base options for configuring MediaPipe Tasks, such as specifying the model @@ -56,9 +60,8 @@ struct ImageClassifierOptions { // The user-defined result callback for processing live stream data. // The result callback should only be specified when the running mode is set // to RunningMode::LIVE_STREAM. - std::function, - const Image&, int64)> + std::function, const Image&, + int64)> result_callback = nullptr; }; @@ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. // TODO: describe exact preprocessing steps once // YUVToImageCalculator is integrated. - absl::StatusOr Classify( + absl::StatusOr Classify( mediapipe::Image image, std::optional image_processing_options = std::nullopt); @@ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - absl::StatusOr - ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, - std::optional - image_processing_options = std::nullopt); + absl::StatusOr ClassifyForVideo( + mediapipe::Image image, int64 timestamp_ms, + std::optional image_processing_options = + std::nullopt); // Sends live image data to image classification, and the results will be // available via the "result_callback" provided in the ImageClassifierOptions. @@ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { // increasing. // // The "result_callback" provides: - // - The classification results as a ClassificationResult object. + // - The classification results as an ImageClassifierResult object. // - The const reference to the corresponding input image that the image // classifier runs on. Note that the const reference to the image will no // longer be valid when the callback returns. To access the image data diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8a1b17ce9..8fa1a0d2a 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; +constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kTensorsTag[] = "TENSORS"; @@ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS"; // subgraph. struct ImageClassifierOutputStreams { Source classification_result; + Source classifications; Source image; }; @@ -71,17 +73,19 @@ struct ImageClassifierOutputStreams { // Describes region of image to perform classification on. // @Optional: rect covering the whole image is used if not specified. // Outputs: -// CLASSIFICATION_RESULT - ClassificationResult -// The aggregated classification result object has two dimensions: -// (classification head, classification category) +// CLASSIFICATIONS - ClassificationResult @Optional +// The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. +// TODO: remove this output once Java API migration is over. +// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional +// The aggregated classification result. // // Example: // node { // calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph" // input_stream: "IMAGE:image_in" -// output_stream: "CLASSIFICATION_RESULT:classification_result_out" +// output_stream: "CLASSIFICATIONS:classifications_out" // output_stream: "IMAGE:image_out" // options { // [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext] @@ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph { graph[Input::Optional(kNormRectTag)], graph)); output_streams.classification_result >> graph[Output(kClassificationResultTag)]; + output_streams.classifications >> + graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; return graph.GetConfig(); } @@ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph { return ImageClassifierOutputStreams{ /*classification_result=*/postprocessing[Output( kClassificationResultTag)], + /*classifications=*/ + postprocessing[Output(kClassificationsTag)], /*image=*/preprocessing[Output(kImageTag)]}; } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc index 0c45122c0..1144e9032 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_test.cc @@ -32,8 +32,8 @@ limitations under the License. #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/tasks/cc/common.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" #include "mediapipe/tasks/cc/components/containers/rect.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" @@ -50,10 +50,9 @@ namespace image_classifier { namespace { using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::Category; +using ::mediapipe::tasks::components::containers::Classifications; using ::mediapipe::tasks::components::containers::Rect; -using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; -using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; using ::mediapipe::tasks::vision::core::ImageProcessingOptions; using ::testing::HasSubstr; using ::testing::Optional; @@ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] = constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] = "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; -// Checks that the two provided `ClassificationResult` are equal, with a +// Checks that the two provided `ImageClassifierResult` are equal, with a // tolerancy on floating-point score to account for numerical instabilities. -void ExpectApproximatelyEqual(const ClassificationResult& actual, - const ClassificationResult& expected) { +void ExpectApproximatelyEqual(const ImageClassifierResult& actual, + const ImageClassifierResult& expected) { const float kPrecision = 1e-6; - ASSERT_EQ(actual.classifications_size(), expected.classifications_size()); - for (int i = 0; i < actual.classifications_size(); ++i) { - const Classifications& a = actual.classifications(i); - const Classifications& b = expected.classifications(i); - EXPECT_EQ(a.head_index(), b.head_index()); - EXPECT_EQ(a.head_name(), b.head_name()); - EXPECT_EQ(a.entries_size(), b.entries_size()); - for (int j = 0; j < a.entries_size(); ++j) { - const ClassificationEntry& x = a.entries(j); - const ClassificationEntry& y = b.entries(j); - EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms()); - EXPECT_EQ(x.categories_size(), y.categories_size()); - for (int k = 0; k < x.categories_size(); ++k) { - EXPECT_EQ(x.categories(k).index(), y.categories(k).index()); - EXPECT_EQ(x.categories(k).category_name(), - y.categories(k).category_name()); - EXPECT_EQ(x.categories(k).display_name(), - y.categories(k).display_name()); - EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(), - kPrecision); - } + ASSERT_EQ(actual.classifications.size(), expected.classifications.size()); + for (int i = 0; i < actual.classifications.size(); ++i) { + const Classifications& a = actual.classifications[i]; + const Classifications& b = expected.classifications[i]; + EXPECT_EQ(a.head_index, b.head_index); + EXPECT_EQ(a.head_name, b.head_name); + EXPECT_EQ(a.categories.size(), b.categories.size()); + for (int j = 0; j < a.categories.size(); ++j) { + const Category& x = a.categories[j]; + const Category& y = b.categories[j]; + EXPECT_EQ(x.index, y.index); + EXPECT_NEAR(x.score, y.score, kPrecision); + EXPECT_EQ(x.category_name, y.category_name); + EXPECT_EQ(x.display_name, y.display_name); } } } // Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata // with max_results set to 3. -ClassificationResult GenerateBurgerResults(int64 timestamp) { - return ParseTextProtoOrDie( - absl::StrFormat(R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 932 - score: 0.027392805 - category_name: "bagel" - } - categories { - index: 925 - score: 0.019340655 - category_name: "guacamole" - } - timestamp_ms: %d - } - head_index: 0 - head_name: "probability" - })pb", - timestamp)); +ImageClassifierResult GenerateBurgerResults() { + ImageClassifierResult result; + result.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/934, /*score=*/0.793959200, + /*category_name=*/"cheeseburger"}, + {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}, + {/*index=*/925, /*score=*/0.019340655, + /*category_name=*/"guacamole"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + return result; } // Generates expected results for "multi_objects.jpg" using // kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding // box set around the soccer ball. -ClassificationResult GenerateSoccerBallResults(int64 timestamp) { - return ParseTextProtoOrDie( - absl::StrFormat(R"pb(classifications { - entries { - categories { - index: 806 - score: 0.996527493 - category_name: "soccer ball" - } - timestamp_ms: %d - } - head_index: 0 - head_name: "probability" - })pb", - timestamp)); +ImageClassifierResult GenerateSoccerBallResults() { + ImageClassifierResult result; + result.classifications.emplace_back( + Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493, + /*category_name=*/"soccer ball"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + return result; } // A custom OpResolver only containing the Ops required by the test model. @@ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); options->running_mode = running_mode; - options->result_callback = [](absl::StatusOr, + options->result_callback = [](absl::StatusOr, const Image& image, int64 timestamp_ms) {}; auto image_classifier = ImageClassifier::Create(std::move(options)); @@ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, GenerateBurgerResults(0)); + ExpectApproximatelyEqual(results, GenerateBurgerResults()); } TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { @@ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.97265625 - category_name: "cheeseburger" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back( + Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625, + /*category_name=*/"cheeseburger"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back( + Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592, + /*category_name=*/"cheeseburger"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { @@ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 932 - score: 0.027392805 - category_name: "bagel" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/934, /*score=*/0.7939592, + /*category_name=*/"cheeseburger"}, + {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { @@ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 925 - score: 0.019340655 - category_name: "guacamole" - } - categories { - index: 963 - score: 0.0063278517 - category_name: "meat loaf" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/934, /*score=*/0.7939592, + /*category_name=*/"cheeseburger"}, + {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"}, + {/*index=*/963, /*score=*/0.0063278517, + /*category_name=*/"meat loaf"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithDenylistOption) { @@ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.7939592 - category_name: "cheeseburger" - } - categories { - index: 925 - score: 0.019340655 - category_name: "guacamole" - } - categories { - index: 963 - score: 0.0063278517 - category_name: "meat loaf" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/934, /*score=*/0.7939592, + /*category_name=*/"cheeseburger"}, + {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"}, + {/*index=*/963, /*score=*/0.0063278517, + /*category_name=*/"meat loaf"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { @@ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.725648628 - category_name: "cheeseburger" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back( + Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628, + /*category_name=*/"cheeseburger"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { @@ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); - ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); + ExpectApproximatelyEqual(results, GenerateSoccerBallResults()); } TEST_F(ImageModeTest, SucceedsWithRotation) { @@ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { // Results differ slightly from the non-rotated image, but that's expected // as models are very sensitive to the slightest numerical differences // introduced by the rotation and JPG encoding. - ExpectApproximatelyEqual(results, ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 934 - score: 0.6371766 - category_name: "cheeseburger" - } - categories { - index: 963 - score: 0.049443405 - category_name: "meat loaf" - } - categories { - index: 925 - score: 0.047918003 - category_name: "guacamole" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back(Classifications{ + /*categories=*/{ + {/*index=*/934, /*score=*/0.6371766, + /*category_name=*/"cheeseburger"}, + {/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"}, + {/*index=*/925, /*score=*/0.047918003, + /*category_name=*/"guacamole"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { @@ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( image, image_processing_options)); - ExpectApproximatelyEqual(results, - ParseTextProtoOrDie( - R"pb(classifications { - entries { - categories { - index: 560 - score: 0.6522213 - category_name: "folding chair" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); + ImageClassifierResult expected; + expected.classifications.emplace_back( + Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213, + /*category_name=*/"folding chair"}}, + /*head_index=*/0, + /*head_name=*/"probability"}); + ExpectApproximatelyEqual(results, expected); } // Testing all these once with ImageClassifier. @@ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) { for (int i = 0; i < iterations; ++i) { MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->ClassifyForVideo(image, i)); - ExpectApproximatelyEqual(results, GenerateBurgerResults(i)); + ExpectApproximatelyEqual(results, GenerateBurgerResults()); } MP_ASSERT_OK(image_classifier->Close()); } @@ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { MP_ASSERT_OK_AND_ASSIGN( auto results, image_classifier->ClassifyForVideo(image, i, image_processing_options)); - ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); + ExpectApproximatelyEqual(results, GenerateSoccerBallResults()); } MP_ASSERT_OK(image_classifier->Close()); } @@ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = [](absl::StatusOr, + options->result_callback = [](absl::StatusOr, const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); @@ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; - options->result_callback = [](absl::StatusOr, + options->result_callback = [](absl::StatusOr, const Image& image, int64 timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr image_classifier, ImageClassifier::Create(std::move(options))); @@ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { } struct LiveStreamModeResults { - ClassificationResult classification_result; + ImageClassifierResult classification_result; std::pair image_size; int64 timestamp_ms; }; @@ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { options->running_mode = core::RunningMode::LIVE_STREAM; options->classifier_options.max_results = 3; options->result_callback = - [&results](absl::StatusOr classification_result, + [&results](absl::StatusOr classification_result, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(classification_result.status()); results.push_back( @@ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { EXPECT_EQ(result.image_size.first, image.width()); EXPECT_EQ(result.image_size.second, image.height()); ExpectApproximatelyEqual(result.classification_result, - GenerateBurgerResults(timestamp_ms)); + GenerateBurgerResults()); } } @@ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { options->running_mode = core::RunningMode::LIVE_STREAM; options->classifier_options.max_results = 1; options->result_callback = - [&results](absl::StatusOr classification_result, + [&results](absl::StatusOr classification_result, const Image& image, int64 timestamp_ms) { MP_ASSERT_OK(classification_result.status()); results.push_back( @@ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { EXPECT_EQ(result.image_size.first, image.width()); EXPECT_EQ(result.image_size.second, image.height()); ExpectApproximatelyEqual(result.classification_result, - GenerateSoccerBallResults(timestamp_ms)); + GenerateSoccerBallResults()); } }