From 7acbf557a1294e3809e8671ac769c855dd3336c4 Mon Sep 17 00:00:00 2001 From: MediaPipe Team Date: Mon, 21 Nov 2022 01:55:49 -0800 Subject: [PATCH] Cleanup after migration to new classification output format. PiperOrigin-RevId: 489921603 --- .../tasks/cc/components/calculators/BUILD | 1 - .../classification_aggregation_calculator.cc | 68 +--- .../cc/components/containers/proto/BUILD | 6 - .../containers/proto/category.proto | 41 --- .../containers/proto/classifications.proto | 17 +- .../classification_postprocessing_graph.cc | 9 - .../classification_postprocessing_graph.h | 3 - ...lassification_postprocessing_graph_test.cc | 322 ------------------ .../text_classifier/text_classifier_graph.cc | 27 +- .../image_classifier_graph.cc | 9 - .../com/google/mediapipe/tasks/text/BUILD | 1 - .../com/google/mediapipe/tasks/vision/BUILD | 1 - .../tasks/python/components/containers/BUILD | 2 +- .../python/components/containers/category.py | 16 +- .../containers/classification_result.py | 15 +- 15 files changed, 23 insertions(+), 515 deletions(-) delete mode 100644 mediapipe/tasks/cc/components/containers/proto/category.proto diff --git a/mediapipe/tasks/cc/components/calculators/BUILD b/mediapipe/tasks/cc/components/calculators/BUILD index 1f726a018..16931811c 100644 --- a/mediapipe/tasks/cc/components/calculators/BUILD +++ b/mediapipe/tasks/cc/components/calculators/BUILD @@ -37,7 +37,6 @@ cc_library( "//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", - "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", "@com_google_absl//absl/status", ], diff --git a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc index 1a83fdad2..ad2c668c3 100644 --- a/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc +++ b/mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.cc @@ -25,14 +25,12 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" -#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace mediapipe { namespace api2 { using ::mediapipe::tasks::components::containers::proto::ClassificationResult; -using ::mediapipe::tasks::components::containers::proto::Classifications; // Aggregates ClassificationLists into either a ClassificationResult object // representing the classification results aggregated by classifier head, or @@ -57,9 +55,6 @@ using ::mediapipe::tasks::components::containers::proto::Classifications; // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example without timestamp aggregation: // node { @@ -122,9 +117,6 @@ class ClassificationAggregationCalculator : public Node { ClassificationResult ConvertToClassificationResult(CalculatorContext* cc); std::vector ConvertToTimestampedClassificationResults( CalculatorContext* cc); - // TODO: deprecate this function once migration is over. - ClassificationResult LegacyConvertToClassificationResult( - CalculatorContext* cc); }; absl::Status ClassificationAggregationCalculator::UpdateContract( @@ -137,10 +129,11 @@ absl::Status ClassificationAggregationCalculator::UpdateContract( << "The size of classifications input streams should match the " "size of head names specified in the calculator options"; } - // TODO: enforce connecting TIMESTAMPED_CLASSIFICATIONS if - // TIMESTAMPS is connected, and connecting CLASSIFICATIONS if TIMESTAMPS is - // not connected. All dependent tasks must be updated to use these outputs - // first. + if (kTimestampsIn(cc).IsConnected()) { + RET_CHECK(kTimestampedClassificationsOut(cc).IsConnected()); + } else { + RET_CHECK(kClassificationsOut(cc).IsConnected()); + } return absl::OkStatus(); } @@ -170,11 +163,9 @@ absl::Status ClassificationAggregationCalculator::Process( if (kTimestampsIn(cc).IsEmpty()) { return absl::OkStatus(); } - classification_result = LegacyConvertToClassificationResult(cc); kTimestampedClassificationsOut(cc).Send( ConvertToTimestampedClassificationResults(cc)); } else { - classification_result = LegacyConvertToClassificationResult(cc); kClassificationsOut(cc).Send(ConvertToClassificationResult(cc)); } kClassificationResultOut(cc).Send(classification_result); @@ -226,55 +217,6 @@ ClassificationAggregationCalculator::ConvertToTimestampedClassificationResults( return results; } -ClassificationResult -ClassificationAggregationCalculator::LegacyConvertToClassificationResult( - CalculatorContext* cc) { - ClassificationResult result; - Timestamp first_timestamp(0); - std::vector timestamps; - if (time_aggregation_enabled_) { - timestamps = kTimestampsIn(cc).Get(); - first_timestamp = timestamps[0]; - } else { - timestamps = {cc->InputTimestamp()}; - } - for (Timestamp timestamp : timestamps) { - int count = cached_classifications_[timestamp.Value()].size(); - for (int i = 0; i < count; ++i) { - Classifications* c; - if (result.classifications_size() <= i) { - c = result.add_classifications(); - if (!head_names_.empty()) { - c->set_head_index(i); - c->set_head_name(head_names_[i]); - } - } else { - c = result.mutable_classifications(i); - } - auto* entry = c->add_entries(); - for (const auto& elem : - cached_classifications_[timestamp.Value()][i].classification()) { - auto* category = entry->add_categories(); - if (elem.has_index()) { - category->set_index(elem.index()); - } - if (elem.has_score()) { - category->set_score(elem.score()); - } - if (elem.has_label()) { - category->set_category_name(elem.label()); - } - if (elem.has_display_name()) { - category->set_display_name(elem.display_name()); - } - } - entry->set_timestamp_ms((timestamp.Value() - first_timestamp.Value()) / - 1000); - } - } - return result; -} - MEDIAPIPE_REGISTER_NODE(ClassificationAggregationCalculator); } // namespace api2 diff --git a/mediapipe/tasks/cc/components/containers/proto/BUILD b/mediapipe/tasks/cc/components/containers/proto/BUILD index 7b455c0c4..27d2357b5 100644 --- a/mediapipe/tasks/cc/components/containers/proto/BUILD +++ b/mediapipe/tasks/cc/components/containers/proto/BUILD @@ -18,16 +18,10 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) licenses(["notice"]) -mediapipe_proto_library( - name = "category_proto", - srcs = ["category.proto"], -) - mediapipe_proto_library( name = "classifications_proto", srcs = ["classifications.proto"], deps = [ - ":category_proto", "//mediapipe/framework/formats:classification_proto", ], ) diff --git a/mediapipe/tasks/cc/components/containers/proto/category.proto b/mediapipe/tasks/cc/components/containers/proto/category.proto deleted file mode 100644 index 412e71428..000000000 --- a/mediapipe/tasks/cc/components/containers/proto/category.proto +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2022 The MediaPipe Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -syntax = "proto2"; - -package mediapipe.tasks.components.containers.proto; - -option java_package = "com.google.mediapipe.tasks.components.containers.proto"; -option java_outer_classname = "CategoryProto"; - -// TODO: deprecate this message once migration is over. -// A single classification result. -message Category { - // The index of the category in the corresponding label map, usually packed in - // the TFLite Model Metadata [1]. - // - // [1]: https://www.tensorflow.org/lite/convert/metadata - optional int32 index = 1; - // The score for this category, e.g. (but not necessarily) a probability in - // [0,1]. - optional float score = 2; - // A human readable name of the category filled from the label map. - optional string display_name = 3; - // An ID for the category, not necessarily human-readable, e.g. a Google - // Knowledge Graph ID [1], filled from the label map. - // - // [1]: https://developers.google.com/knowledge-graph - optional string category_name = 4; -} diff --git a/mediapipe/tasks/cc/components/containers/proto/classifications.proto b/mediapipe/tasks/cc/components/containers/proto/classifications.proto index f098ed0e4..2b2306829 100644 --- a/mediapipe/tasks/cc/components/containers/proto/classifications.proto +++ b/mediapipe/tasks/cc/components/containers/proto/classifications.proto @@ -18,27 +18,12 @@ syntax = "proto2"; package mediapipe.tasks.components.containers.proto; import "mediapipe/framework/formats/classification.proto"; -import "mediapipe/tasks/cc/components/containers/proto/category.proto"; option java_package = "com.google.mediapipe.tasks.components.containers.proto"; option java_outer_classname = "ClassificationsProto"; -// TODO: deprecate this message once migration is over. -// List of predicted categories with an optional timestamp. -message ClassificationEntry { - // The array of predicted categories, usually sorted by descending scores, - // e.g., from high to low probability. - repeated Category categories = 1; - // The optional timestamp (in milliseconds) associated to the classifcation - // entry. This is useful for time series use cases, e.g., audio - // classification. - optional int64 timestamp_ms = 2; -} - // Classifications for a given classifier head, i.e. for a given output tensor. message Classifications { - // TODO: deprecate this field once migration is over. - repeated ClassificationEntry entries = 1; // The classification results for this head. optional mediapipe.ClassificationList classification_list = 4; // The index of the classifier head these categories refer to. This is useful @@ -48,6 +33,8 @@ message Classifications { // name. // TODO: Add github link to metadata_schema.fbs. optional string head_name = 3; + // Reserved fields. + reserved 1; } // Classifications for a given classifier model. diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc index 0fb62afaf..5a0472f5c 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.cc @@ -73,7 +73,6 @@ using TensorsSource = mediapipe::tasks::SourceOrNodeOutput>; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kScoresTag[] = "SCORES"; constexpr char kTensorsTag[] = "TENSORS"; @@ -82,7 +81,6 @@ constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; // Struct holding the different output streams produced by the graph. struct ClassificationPostprocessingOutputStreams { - Source classification_result; Source classifications; Source> timestamped_classifications; }; @@ -400,9 +398,6 @@ absl::Status ConfigureClassificationPostprocessingGraph( // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // The recommended way of using this graph is through the GraphBuilder API // using the 'ConfigureClassificationPostprocessingGraph()' function. See header @@ -418,8 +413,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { sc->Options(), graph[Input>(kTensorsTag)], graph[Input>(kTimestampsTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.timestamped_classifications >> @@ -536,8 +529,6 @@ class ClassificationPostprocessingGraph : public mediapipe::Subgraph { // Connects output. ClassificationPostprocessingOutputStreams output_streams{ - /*classification_result=*/result_aggregation - [Output(kClassificationResultTag)], /*classifications=*/ result_aggregation[Output(kClassificationsTag)], /*timestamped_classifications=*/ diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h index 48575ceb0..03ae91130 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h @@ -58,9 +58,6 @@ namespace processors { // The classification result aggregated by timestamp, then by head. Must be // connected if the TIMESTAMPS input is connected, as it signals that // timestamp aggregation is required. -// // TODO: remove output once migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. absl::Status ConfigureClassificationPostprocessingGraph( const tasks::core::ModelResources& model_resources, const proto::ClassifierOptions& classifier_options, diff --git a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc index d4728e725..8eb6f3c3b 100644 --- a/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc +++ b/mediapipe/tasks/cc/components/processors/classification_postprocessing_graph_test.cc @@ -86,8 +86,6 @@ constexpr char kTensorsTag[] = "TENSORS"; constexpr char kTensorsName[] = "tensors"; constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsName[] = "timestamps"; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; -constexpr char kClassificationResultName[] = "classification_result"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kClassificationsName[] = "classifications"; constexpr char kTimestampedClassificationsTag[] = "TIMESTAMPED_CLASSIFICATIONS"; @@ -728,326 +726,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { })pb")})); } -// TODO: remove these tests once migration is over. -class LegacyPostprocessingTest : public tflite_shims::testing::Test { - protected: - absl::StatusOr BuildGraph( - absl::string_view model_name, const proto::ClassifierOptions& options, - bool connect_timestamps = false) { - ASSIGN_OR_RETURN(auto model_resources, - CreateModelResourcesForModel(model_name)); - - Graph graph; - auto& postprocessing = graph.AddNode( - "mediapipe.tasks.components.processors." - "ClassificationPostprocessingGraph"); - MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph( - *model_resources, options, - &postprocessing - .GetOptions())); - graph[Input>(kTensorsTag)].SetName(kTensorsName) >> - postprocessing.In(kTensorsTag); - if (connect_timestamps) { - graph[Input>(kTimestampsTag)].SetName( - kTimestampsName) >> - postprocessing.In(kTimestampsTag); - } - postprocessing.Out(kClassificationResultTag) - .SetName(kClassificationResultName) >> - graph[Output(kClassificationResultTag)]; - - MP_RETURN_IF_ERROR(calculator_graph_.Initialize(graph.GetConfig())); - ASSIGN_OR_RETURN(auto poller, calculator_graph_.AddOutputStreamPoller( - kClassificationResultName)); - MP_RETURN_IF_ERROR(calculator_graph_.StartRun(/*extra_side_packets=*/{})); - return poller; - } - - template - void AddTensor( - const std::vector& tensor, const Tensor::ElementType& element_type, - const Tensor::QuantizationParameters& quantization_parameters = {}) { - tensors_->emplace_back(element_type, - Tensor::Shape{1, static_cast(tensor.size())}, - quantization_parameters); - auto view = tensors_->back().GetCpuWriteView(); - T* buffer = view.buffer(); - std::copy(tensor.begin(), tensor.end(), buffer); - } - - absl::Status Run( - std::optional> aggregation_timestamps = std::nullopt, - int timestamp = 0) { - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTensorsName, Adopt(tensors_.release()).At(Timestamp(timestamp)))); - // Reset tensors for future calls. - tensors_ = absl::make_unique>(); - if (aggregation_timestamps.has_value()) { - auto packet = absl::make_unique>(); - for (const auto& timestamp : *aggregation_timestamps) { - packet->emplace_back(Timestamp(timestamp)); - } - MP_RETURN_IF_ERROR(calculator_graph_.AddPacketToInputStream( - kTimestampsName, Adopt(packet.release()).At(Timestamp(timestamp)))); - } - return absl::OkStatus(); - } - - absl::StatusOr GetClassificationResult( - OutputStreamPoller& poller) { - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilIdle()); - MP_RETURN_IF_ERROR(calculator_graph_.CloseAllInputStreams()); - - Packet packet; - if (!poller.Next(&packet)) { - return absl::InternalError("Unable to get output packet"); - } - auto result = packet.Get(); - MP_RETURN_IF_ERROR(calculator_graph_.WaitUntilDone()); - return result; - } - - private: - CalculatorGraph calculator_graph_; - std::unique_ptr> tensors_ = - absl::make_unique>(); -}; - -TEST_F(LegacyPostprocessingTest, SucceedsWithoutMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - options.set_score_threshold(0.5); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithoutMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 18; - tensor[2] = 16; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto(R"pb(classifications { - entries { - categories { index: 1 score: 0.8 } - categories { index: 2 score: 0.6 } - timestamp_ms: 0 - } - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMetadata) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.8 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithScoreCalibration) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(3); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); - // Build input tensors. - std::vector tensor(kMobileNetNumClasses, 0); - tensor[1] = 12; - tensor[2] = 14; - tensor[3] = 16; - tensor[4] = 18; - - // Send tensors and get results. - AddTensor(tensor, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 4 - score: 0.6899744811 - category_name: "tiger shark" - } - categories { - index: 3 - score: 0.6456563062 - category_name: "great white shark" - } - categories { - index: 2 - score: 0.5986876601 - category_name: "goldfish" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "probability" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithMultipleHeads) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, - BuildGraph(kFloatTwoHeadsAudioClassifierWithMetadata, options)); - // Build input tensors. - std::vector tensor_0(kTwoHeadsNumClasses[0], 0); - tensor_0[1] = 0.2; - tensor_0[2] = 0.4; - tensor_0[3] = 0.6; - std::vector tensor_1(kTwoHeadsNumClasses[1], 0); - tensor_1[1] = 0.2; - tensor_1[2] = 0.4; - tensor_1[3] = 0.6; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kFloat32); - AddTensor(tensor_1, Tensor::ElementType::kFloat32); - MP_ASSERT_OK(Run()); - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - EXPECT_THAT(results, EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Narration, monologue" - } - categories { - index: 2 - score: 0.4 - category_name: "Conversation" - } - timestamp_ms: 0 - } - head_index: 0 - head_name: "yamnet_classification" - } - classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "Azara\'s Spinetail" - } - categories { - index: 2 - score: 0.4 - category_name: "House Sparrow" - } - timestamp_ms: 0 - } - head_index: 1 - head_name: "bird_classification" - })pb")); -} - -TEST_F(LegacyPostprocessingTest, SucceedsWithTimestamps) { - // Build graph. - proto::ClassifierOptions options; - options.set_max_results(2); - MP_ASSERT_OK_AND_ASSIGN( - auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, - /*connect_timestamps=*/true)); - // Build input tensors. - std::vector tensor_0(kMobileNetNumClasses, 0); - tensor_0[1] = 12; - tensor_0[2] = 14; - tensor_0[3] = 16; - std::vector tensor_1(kMobileNetNumClasses, 0); - tensor_1[5] = 12; - tensor_1[6] = 14; - tensor_1[7] = 16; - - // Send tensors and get results. - AddTensor(tensor_0, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run()); - AddTensor(tensor_1, Tensor::ElementType::kUInt8, - /*quantization_parameters=*/{0.1, 10}); - MP_ASSERT_OK(Run( - /*aggregation_timestamps=*/std::optional>({0, 1000}), - /*timestamp=*/1000)); - - MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); - - // Validate results. - EXPECT_THAT( - results, - EqualsProto( - R"pb(classifications { - entries { - categories { - index: 3 - score: 0.6 - category_name: "great white shark" - } - categories { index: 2 score: 0.4 category_name: "goldfish" } - timestamp_ms: 0 - } - entries { - categories { index: 7 score: 0.6 category_name: "stingray" } - categories { - index: 6 - score: 0.4 - category_name: "electric ray" - } - timestamp_ms: 1 - } - head_index: 0 - head_name: "probability" - })pb")); -} - } // namespace } // namespace processors } // namespace components diff --git a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc index 36ff68a07..9a7dce1aa 100644 --- a/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc +++ b/mediapipe/tasks/cc/text/text_classifier/text_classifier_graph.cc @@ -46,19 +46,11 @@ using ::mediapipe::api2::builder::Source; using ::mediapipe::tasks::components::containers::proto::ClassificationResult; using ::mediapipe::tasks::core::ModelResources; -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kTextTag[] = "TEXT"; constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; constexpr char kTensorsTag[] = "TENSORS"; -// TODO: remove once Java API migration is over. -// Struct holding the different output streams produced by the text classifier. -struct TextClassifierOutputStreams { - Source classification_result; - Source classifications; -}; - } // namespace // A "TextClassifierGraph" performs Natural Language classification (including @@ -72,10 +64,6 @@ struct TextClassifierOutputStreams { // Outputs: // CLASSIFICATIONS - ClassificationResult @Optional // The classification results aggregated by classifier head. -// TODO: remove once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result object that has 3 dimensions: -// (classification head, classification timestamp, classification category). // // Example: // node { @@ -102,14 +90,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { CreateModelResources(sc)); Graph graph; ASSIGN_OR_RETURN( - auto output_streams, + auto classifications, BuildTextClassifierTask( sc->Options(), *model_resources, graph[Input(kTextTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; - output_streams.classifications >> - graph[Output(kClassificationsTag)]; + classifications >> graph[Output(kClassificationsTag)]; return graph.GetConfig(); } @@ -124,7 +109,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // TextClassifier model file with model metadata. // text_in: (std::string) stream to run text classification on. // graph: the mediapipe builder::Graph instance to be updated. - absl::StatusOr BuildTextClassifierTask( + absl::StatusOr> BuildTextClassifierTask( const proto::TextClassifierGraphOptions& task_options, const ModelResources& model_resources, Source text_in, Graph& graph) { @@ -161,11 +146,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. - return TextClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], - /*classifications=*/postprocessing[Output( - kClassificationsTag)]}; + return postprocessing[Output(kClassificationsTag)]; } }; diff --git a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc index 8fa1a0d2a..2fc88bcb6 100644 --- a/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc +++ b/mediapipe/tasks/cc/vision/image_classifier/image_classifier_graph.cc @@ -47,7 +47,6 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; constexpr float kDefaultScoreThreshold = std::numeric_limits::lowest(); -constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; constexpr char kImageTag[] = "IMAGE"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -56,7 +55,6 @@ constexpr char kTensorsTag[] = "TENSORS"; // Struct holding the different output streams produced by the image classifier // subgraph. struct ImageClassifierOutputStreams { - Source classification_result; Source classifications; Source image; }; @@ -77,9 +75,6 @@ struct ImageClassifierOutputStreams { // The classification results aggregated by classifier head. // IMAGE - Image // The image that object detection runs on. -// TODO: remove this output once Java API migration is over. -// CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional -// The aggregated classification result. // // Example: // node { @@ -117,8 +112,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { sc->Options(), *model_resources, graph[Input(kImageTag)], graph[Input::Optional(kNormRectTag)], graph)); - output_streams.classification_result >> - graph[Output(kClassificationResultTag)]; output_streams.classifications >> graph[Output(kClassificationsTag)]; output_streams.image >> graph[Output(kImageTag)]; @@ -174,8 +167,6 @@ class ImageClassifierGraph : public core::ModelTaskGraph { // Outputs the aggregated classification result as the subgraph output // stream. return ImageClassifierOutputStreams{ - /*classification_result=*/postprocessing[Output( - kClassificationResultTag)], /*classifications=*/ postprocessing[Output(kClassificationsTag)], /*image=*/preprocessing[Output(kImageTag)]}; diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD index 0e72878ab..023a1f286 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/text/BUILD @@ -48,7 +48,6 @@ android_library( deps = [ "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/text/text_classifier/proto:text_classifier_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 289e3000d..72cee133f 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -97,7 +97,6 @@ android_library( "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", - "//mediapipe/tasks/cc/components/containers/proto:category_java_proto_lite", "//mediapipe/tasks/cc/components/containers/proto:classifications_java_proto_lite", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index d931c26c7..9d275e167 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -68,7 +68,7 @@ py_library( name = "category", srcs = ["category.py"], deps = [ - "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2", + "//mediapipe/framework/formats:classification_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", ], ) diff --git a/mediapipe/tasks/python/components/containers/category.py b/mediapipe/tasks/python/components/containers/category.py index cfdb83740..9b5419883 100644 --- a/mediapipe/tasks/python/components/containers/category.py +++ b/mediapipe/tasks/python/components/containers/category.py @@ -16,10 +16,10 @@ import dataclasses from typing import Any, Optional -from mediapipe.tasks.cc.components.containers.proto import category_pb2 +from mediapipe.framework.formats import classification_pb2 from mediapipe.tasks.python.core.optional_dependencies import doc_controls -_CategoryProto = category_pb2.Category +_ClassificationProto = classification_pb2.Classification @dataclasses.dataclass @@ -45,23 +45,23 @@ class Category: category_name: Optional[str] = None @doc_controls.do_not_generate_docs - def to_pb2(self) -> _CategoryProto: + def to_pb2(self) -> _ClassificationProto: """Generates a Category protobuf object.""" - return _CategoryProto( + return _ClassificationProto( index=self.index, score=self.score, - display_name=self.display_name, - category_name=self.category_name) + label=self.category_name, + display_name=self.display_name) @classmethod @doc_controls.do_not_generate_docs - def create_from_pb2(cls, pb2_obj: _CategoryProto) -> 'Category': + def create_from_pb2(cls, pb2_obj: _ClassificationProto) -> 'Category': """Creates a `Category` object from the given protobuf object.""" return Category( index=pb2_obj.index, score=pb2_obj.score, display_name=pb2_obj.display_name, - category_name=pb2_obj.category_name) + category_name=pb2_obj.label) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/classification_result.py b/mediapipe/tasks/python/components/containers/classification_result.py index 6ffdabe51..000468041 100644 --- a/mediapipe/tasks/python/components/containers/classification_result.py +++ b/mediapipe/tasks/python/components/containers/classification_result.py @@ -49,11 +49,7 @@ class Classifications: """Generates a Classifications protobuf object.""" classification_list_proto = _ClassificationListProto() for category in self.categories: - classification_proto = _ClassificationProto( - index=category.index, - score=category.score, - label=category.category_name, - display_name=category.display_name) + classification_proto = category.to_pb2() classification_list_proto.classification.append(classification_proto) return _ClassificationsProto( classification_list=classification_list_proto, @@ -65,14 +61,9 @@ class Classifications: def create_from_pb2(cls, pb2_obj: _ClassificationsProto) -> 'Classifications': """Creates a `Classifications` object from the given protobuf object.""" categories = [] - for entry in pb2_obj.classification_list.classification: + for classification in pb2_obj.classification_list.classification: categories.append( - category_module.Category( - index=entry.index, - score=entry.score, - display_name=entry.display_name, - category_name=entry.label)) - + category_module.Category.create_from_pb2(classification)) return Classifications( categories=categories, head_index=pb2_obj.head_index,