Add classification result structs and use them in ImageClassifier & TextClassifier.
PiperOrigin-RevId: 485567133
This commit is contained in:
		
							parent
							
								
									aab5f84aae
								
							
						
					
					
						commit
						d61ab92b90
					
				|  | @ -29,3 +29,23 @@ cc_library( | ||||||
|         "//mediapipe/framework/formats:landmark_cc_proto", |         "//mediapipe/framework/formats:landmark_cc_proto", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "category", | ||||||
|  |     srcs = ["category.cc"], | ||||||
|  |     hdrs = ["category.h"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/framework/formats:classification_cc_proto", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "classification_result", | ||||||
|  |     srcs = ["classification_result.cc"], | ||||||
|  |     hdrs = ["classification_result.h"], | ||||||
|  |     deps = [ | ||||||
|  |         ":category", | ||||||
|  |         "//mediapipe/framework/formats:classification_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  |  | ||||||
							
								
								
									
										38
									
								
								mediapipe/tasks/cc/components/containers/category.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mediapipe/tasks/cc/components/containers/category.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,38 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/category.h" | ||||||
|  | 
 | ||||||
|  | #include <optional> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/formats/classification.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe::tasks::components::containers { | ||||||
|  | 
 | ||||||
|  | Category ConvertToCategory(const mediapipe::Classification& proto) { | ||||||
|  |   Category category; | ||||||
|  |   category.index = proto.index(); | ||||||
|  |   category.score = proto.score(); | ||||||
|  |   if (proto.has_label()) { | ||||||
|  |     category.category_name = proto.label(); | ||||||
|  |   } | ||||||
|  |   if (proto.has_display_name()) { | ||||||
|  |     category.display_name = proto.display_name(); | ||||||
|  |   } | ||||||
|  |   return category; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace mediapipe::tasks::components::containers
 | ||||||
							
								
								
									
										52
									
								
								mediapipe/tasks/cc/components/containers/category.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								mediapipe/tasks/cc/components/containers/category.h
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,52 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_ | ||||||
|  | #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_ | ||||||
|  | 
 | ||||||
|  | #include <optional> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/formats/classification.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe::tasks::components::containers { | ||||||
|  | 
 | ||||||
// Defines a single classification result.
//
// The label maps packed into the TFLite Model Metadata [1] are used to populate
// the 'category_name' and 'display_name' fields.
//
// All fields carry default member initializers so that a default-constructed
// Category never exposes indeterminate values. The struct remains an
// aggregate, so brace initialization in field order keeps working.
//
// [1]: https://www.tensorflow.org/lite/convert/metadata
struct Category {
  // The index of the category in the classification model output.
  int index = 0;
  // The score for this category, e.g. (but not necessarily) a probability in
  // [0,1].
  float score = 0.0f;
  // The optional ID for the category, read from the label map packed in the
  // TFLite Model Metadata if present. Not necessarily human-readable.
  std::optional<std::string> category_name = std::nullopt;
  // The optional human-readable name for the category, read from the label map
  // packed in the TFLite Model Metadata if present.
  std::optional<std::string> display_name = std::nullopt;
};
|  | 
 | ||||||
|  | // Utility function to convert from mediapipe::Classification proto to Category
 | ||||||
|  | // struct.
 | ||||||
|  | Category ConvertToCategory(const mediapipe::Classification& proto); | ||||||
|  | 
 | ||||||
|  | }  // namespace mediapipe::tasks::components::containers
 | ||||||
|  | 
 | ||||||
|  | #endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
 | ||||||
|  | @ -0,0 +1,57 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
|  | 
 | ||||||
|  | #include <optional> | ||||||
|  | #include <string> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/formats/classification.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/category.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe::tasks::components::containers { | ||||||
|  | 
 | ||||||
|  | Classifications ConvertToClassifications(const proto::Classifications& proto) { | ||||||
|  |   Classifications classifications; | ||||||
|  |   classifications.categories.reserve( | ||||||
|  |       proto.classification_list().classification_size()); | ||||||
|  |   for (const auto& classification : | ||||||
|  |        proto.classification_list().classification()) { | ||||||
|  |     classifications.categories.push_back(ConvertToCategory(classification)); | ||||||
|  |   } | ||||||
|  |   classifications.head_index = proto.head_index(); | ||||||
|  |   if (proto.has_head_name()) { | ||||||
|  |     classifications.head_name = proto.head_name(); | ||||||
|  |   } | ||||||
|  |   return classifications; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ClassificationResult ConvertToClassificationResult( | ||||||
|  |     const proto::ClassificationResult& proto) { | ||||||
|  |   ClassificationResult classification_result; | ||||||
|  |   classification_result.classifications.reserve(proto.classifications_size()); | ||||||
|  |   for (const auto& classifications : proto.classifications()) { | ||||||
|  |     classification_result.classifications.push_back( | ||||||
|  |         ConvertToClassifications(classifications)); | ||||||
|  |   } | ||||||
|  |   if (proto.has_timestamp_ms()) { | ||||||
|  |     classification_result.timestamp_ms = proto.timestamp_ms(); | ||||||
|  |   } | ||||||
|  |   return classification_result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace mediapipe::tasks::components::containers
 | ||||||
|  | @ -0,0 +1,68 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ | ||||||
|  | #define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_ | ||||||
|  | 
 | ||||||
|  | #include <optional> | ||||||
|  | #include <string> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/category.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe::tasks::components::containers { | ||||||
|  | 
 | ||||||
|  | // Defines classification results for a given classifier head.
 | ||||||
|  | struct Classifications { | ||||||
|  |   // The array of predicted categories, usually sorted by descending scores,
 | ||||||
|  |   // e.g. from high to low probability.
 | ||||||
|  |   std::vector<Category> categories; | ||||||
|  |   // The index of the classifier head (i.e. output tensor) these categories
 | ||||||
|  |   // refer to. This is useful for multi-head models.
 | ||||||
|  |   int head_index; | ||||||
|  |   // The optional name of the classifier head, as provided in the TFLite Model
 | ||||||
|  |   // Metadata [1] if present. This is useful for multi-head models.
 | ||||||
|  |   //
 | ||||||
|  |   // [1]: https://www.tensorflow.org/lite/convert/metadata
 | ||||||
|  |   std::optional<std::string> head_name = std::nullopt; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // Defines classification results of a model.
 | ||||||
|  | struct ClassificationResult { | ||||||
|  |   // The classification results for each head of the model.
 | ||||||
|  |   std::vector<Classifications> classifications; | ||||||
|  |   // The optional timestamp (in milliseconds) of the start of the chunk of data
 | ||||||
|  |   // corresponding to these results.
 | ||||||
|  |   //
 | ||||||
|  |   // This is only used for classification on time series (e.g. audio
 | ||||||
|  |   // classification). In these use cases, the amount of data to process might
 | ||||||
|  |   // exceed the maximum size that the model can process: to solve this, the
 | ||||||
|  |   // input data is split into multiple chunks starting at different timestamps.
 | ||||||
|  |   std::optional<int64_t> timestamp_ms = std::nullopt; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // Utility function to convert from Classifications proto to
 | ||||||
|  | // Classifications struct.
 | ||||||
|  | Classifications ConvertToClassifications(const proto::Classifications& proto); | ||||||
|  | 
 | ||||||
|  | // Utility function to convert from ClassificationResult proto to
 | ||||||
|  | // ClassificationResult struct.
 | ||||||
|  | ClassificationResult ConvertToClassificationResult( | ||||||
|  |     const proto::ClassificationResult& proto); | ||||||
|  | 
 | ||||||
|  | }  // namespace mediapipe::tasks::components::containers
 | ||||||
|  | 
 | ||||||
|  | #endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
 | ||||||
|  | @ -49,6 +49,8 @@ cc_library( | ||||||
|         ":text_classifier_graph", |         ":text_classifier_graph", | ||||||
|         "//mediapipe/framework:packet", |         "//mediapipe/framework:packet", | ||||||
|         "//mediapipe/framework/api2:builder", |         "//mediapipe/framework/api2:builder", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers:category", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers:classification_result", | ||||||
|         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", |         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", | ||||||
|         "//mediapipe/tasks/cc/components/processors:classifier_options", |         "//mediapipe/tasks/cc/components/processors:classifier_options", | ||||||
|         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", |         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", | ||||||
|  | @ -76,7 +78,8 @@ cc_test( | ||||||
|         "//mediapipe/framework/deps:file_path", |         "//mediapipe/framework/deps:file_path", | ||||||
|         "//mediapipe/framework/port:gtest_main", |         "//mediapipe/framework/port:gtest_main", | ||||||
|         "//mediapipe/tasks/cc:common", |         "//mediapipe/tasks/cc:common", | ||||||
|         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", |         "//mediapipe/tasks/cc/components/containers:category", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers:classification_result", | ||||||
|         "@com_google_absl//absl/flags:flag", |         "@com_google_absl//absl/flags:flag", | ||||||
|         "@com_google_absl//absl/status", |         "@com_google_absl//absl/status", | ||||||
|         "@com_google_absl//absl/status:statusor", |         "@com_google_absl//absl/status:statusor", | ||||||
|  |  | ||||||
|  | @ -24,6 +24,7 @@ limitations under the License. | ||||||
| #include "absl/strings/string_view.h" | #include "absl/strings/string_view.h" | ||||||
| #include "mediapipe/framework/api2/builder.h" | #include "mediapipe/framework/api2/builder.h" | ||||||
| #include "mediapipe/framework/packet.h" | #include "mediapipe/framework/packet.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" | #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/core/task_api_factory.h" | #include "mediapipe/tasks/cc/core/task_api_factory.h" | ||||||
|  | @ -37,12 +38,13 @@ namespace text_classifier { | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; | ||||||
| using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | ||||||
| 
 | 
 | ||||||
| constexpr char kTextStreamName[] = "text_in"; | constexpr char kTextStreamName[] = "text_in"; | ||||||
| constexpr char kTextTag[] = "TEXT"; | constexpr char kTextTag[] = "TEXT"; | ||||||
| constexpr char kClassificationResultStreamName[] = "classification_result_out"; | constexpr char kClassificationsStreamName[] = "classifications_out"; | ||||||
| constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | ||||||
| constexpr char kSubgraphTypeName[] = | constexpr char kSubgraphTypeName[] = | ||||||
|     "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; |     "mediapipe.tasks.text.text_classifier.TextClassifierGraph"; | ||||||
| 
 | 
 | ||||||
|  | @ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig( | ||||||
|   auto& subgraph = graph.AddNode(kSubgraphTypeName); |   auto& subgraph = graph.AddNode(kSubgraphTypeName); | ||||||
|   subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get()); |   subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get()); | ||||||
|   graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); |   graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag); | ||||||
|   subgraph.Out(kClassificationResultTag) |   subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> | ||||||
|           .SetName(kClassificationResultStreamName) >> |       graph.Out(kClassificationsTag); | ||||||
|       graph.Out(kClassificationResultTag); |  | ||||||
|   return graph.GetConfig(); |   return graph.GetConfig(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -88,14 +89,14 @@ absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create( | ||||||
|       std::move(options->base_options.op_resolver)); |       std::move(options->base_options.op_resolver)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::StatusOr<ClassificationResult> TextClassifier::Classify( | absl::StatusOr<TextClassifierResult> TextClassifier::Classify( | ||||||
|     absl::string_view text) { |     absl::string_view text) { | ||||||
|   ASSIGN_OR_RETURN( |   ASSIGN_OR_RETURN( | ||||||
|       auto output_packets, |       auto output_packets, | ||||||
|       runner_->Process( |       runner_->Process( | ||||||
|           {{kTextStreamName, MakePacket<std::string>(std::string(text))}})); |           {{kTextStreamName, MakePacket<std::string>(std::string(text))}})); | ||||||
|   return output_packets[kClassificationResultStreamName] |   return ConvertToClassificationResult( | ||||||
|       .Get<ClassificationResult>(); |       output_packets[kClassificationsStreamName].Get<ClassificationResult>()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace text_classifier
 | }  // namespace text_classifier
 | ||||||
|  |  | ||||||
|  | @ -21,7 +21,7 @@ limitations under the License. | ||||||
| #include "absl/status/status.h" | #include "absl/status/status.h" | ||||||
| #include "absl/status/statusor.h" | #include "absl/status/statusor.h" | ||||||
| #include "absl/strings/string_view.h" | #include "absl/strings/string_view.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
| #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | ||||||
| #include "mediapipe/tasks/cc/core/base_options.h" | #include "mediapipe/tasks/cc/core/base_options.h" | ||||||
| #include "mediapipe/tasks/cc/core/base_task_api.h" | #include "mediapipe/tasks/cc/core/base_task_api.h" | ||||||
|  | @ -31,6 +31,10 @@ namespace tasks { | ||||||
| namespace text { | namespace text { | ||||||
| namespace text_classifier { | namespace text_classifier { | ||||||
| 
 | 
 | ||||||
|  | // Alias the shared ClassificationResult struct as result type.
 | ||||||
|  | using TextClassifierResult = | ||||||
|  |     ::mediapipe::tasks::components::containers::ClassificationResult; | ||||||
|  | 
 | ||||||
| // The options for configuring a MediaPipe text classifier task.
 | // The options for configuring a MediaPipe text classifier task.
 | ||||||
| struct TextClassifierOptions { | struct TextClassifierOptions { | ||||||
|   // Base options for configuring MediaPipe Tasks, such as specifying the model
 |   // Base options for configuring MediaPipe Tasks, such as specifying the model
 | ||||||
|  | @ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi { | ||||||
|       std::unique_ptr<TextClassifierOptions> options); |       std::unique_ptr<TextClassifierOptions> options); | ||||||
| 
 | 
 | ||||||
|   // Performs classification on the input `text`.
 |   // Performs classification on the input `text`.
 | ||||||
|   absl::StatusOr<components::containers::proto::ClassificationResult> Classify( |   absl::StatusOr<TextClassifierResult> Classify(absl::string_view text); | ||||||
|       absl::string_view text); |  | ||||||
| 
 | 
 | ||||||
|   // Shuts down the TextClassifier when all the work is done.
 |   // Shuts down the TextClassifier when all the work is done.
 | ||||||
|   absl::Status Close() { return runner_->Close(); } |   absl::Status Close() { return runner_->Close(); } | ||||||
|  |  | ||||||
|  | @ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | ||||||
| using ::mediapipe::tasks::core::ModelResources; | using ::mediapipe::tasks::core::ModelResources; | ||||||
| 
 | 
 | ||||||
| constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | ||||||
|  | constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | ||||||
| constexpr char kTextTag[] = "TEXT"; | constexpr char kTextTag[] = "TEXT"; | ||||||
| constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; | constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR"; | ||||||
| constexpr char kTensorsTag[] = "TENSORS"; | constexpr char kTensorsTag[] = "TENSORS"; | ||||||
| 
 | 
 | ||||||
|  | // TODO: remove once Java API migration is over.
 | ||||||
|  | // Struct holding the two output streams produced by the text classifier graph: `classification_result` feeds the deprecated CLASSIFICATION_RESULT output and `classifications` feeds the CLASSIFICATIONS output.
 | ||||||
|  | struct TextClassifierOutputStreams { | ||||||
|  |   Source<ClassificationResult> classification_result; | ||||||
|  |   Source<ClassificationResult> classifications; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| // A "TextClassifierGraph" performs Natural Language classification (including
 | // A "TextClassifierGraph" performs Natural Language classification (including
 | ||||||
|  | @ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS"; | ||||||
| //     Input text to perform classification on.
 | //     Input text to perform classification on.
 | ||||||
| //
 | //
 | ||||||
| // Outputs:
 | // Outputs:
 | ||||||
| //   CLASSIFICATION_RESULT - ClassificationResult
 | //   CLASSIFICATIONS - ClassificationResult @Optional
 | ||||||
|  | //     The classification results aggregated by classifier head.
 | ||||||
|  | // TODO: remove once Java API migration is over.
 | ||||||
|  | //   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
 | ||||||
| //     The aggregated classification result object that has 3 dimensions:
 | //     The aggregated classification result object that has 3 dimensions:
 | ||||||
| //     (classification head, classification timestamp, classification category).
 | //     (classification head, classification timestamp, classification category).
 | ||||||
| //
 | //
 | ||||||
|  | @ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS"; | ||||||
| // node {
 | // node {
 | ||||||
| //   calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
 | //   calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
 | ||||||
| //   input_stream: "TEXT:text_in"
 | //   input_stream: "TEXT:text_in"
 | ||||||
| //   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
 | //   output_stream: "CLASSIFICATIONS:classifications_out"
 | ||||||
| //   options {
 | //   options {
 | ||||||
| //     [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
 | //     [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
 | ||||||
| //     {
 | //     {
 | ||||||
|  | @ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph { | ||||||
|         CreateModelResources<proto::TextClassifierGraphOptions>(sc)); |         CreateModelResources<proto::TextClassifierGraphOptions>(sc)); | ||||||
|     Graph graph; |     Graph graph; | ||||||
|     ASSIGN_OR_RETURN( |     ASSIGN_OR_RETURN( | ||||||
|         Source<ClassificationResult> classification_result_out, |         auto output_streams, | ||||||
|         BuildTextClassifierTask( |         BuildTextClassifierTask( | ||||||
|             sc->Options<proto::TextClassifierGraphOptions>(), *model_resources, |             sc->Options<proto::TextClassifierGraphOptions>(), *model_resources, | ||||||
|             graph[Input<std::string>(kTextTag)], graph)); |             graph[Input<std::string>(kTextTag)], graph)); | ||||||
|     classification_result_out >> |     output_streams.classification_result >> | ||||||
|         graph[Output<ClassificationResult>(kClassificationResultTag)]; |         graph[Output<ClassificationResult>(kClassificationResultTag)]; | ||||||
|  |     output_streams.classifications >> | ||||||
|  |         graph[Output<ClassificationResult>(kClassificationsTag)]; | ||||||
|     return graph.GetConfig(); |     return graph.GetConfig(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph { | ||||||
|   //   TextClassifier model file with model metadata.
 |   //   TextClassifier model file with model metadata.
 | ||||||
|   // text_in: (std::string) stream to run text classification on.
 |   // text_in: (std::string) stream to run text classification on.
 | ||||||
|   // graph: the mediapipe builder::Graph instance to be updated.
 |   // graph: the mediapipe builder::Graph instance to be updated.
 | ||||||
|   absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask( |   absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask( | ||||||
|       const proto::TextClassifierGraphOptions& task_options, |       const proto::TextClassifierGraphOptions& task_options, | ||||||
|       const ModelResources& model_resources, Source<std::string> text_in, |       const ModelResources& model_resources, Source<std::string> text_in, | ||||||
|       Graph& graph) { |       Graph& graph) { | ||||||
|  | @ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph { | ||||||
| 
 | 
 | ||||||
|     // Outputs the aggregated classification result as the subgraph output
 |     // Outputs the aggregated classification result as the subgraph output
 | ||||||
|     // stream.
 |     // stream.
 | ||||||
|     return postprocessing[Output<ClassificationResult>( |     return TextClassifierOutputStreams{ | ||||||
|         kClassificationResultTag)]; |         /*classification_result=*/postprocessing[Output<ClassificationResult>( | ||||||
|  |             kClassificationResultTag)], | ||||||
|  |         /*classifications=*/postprocessing[Output<ClassificationResult>( | ||||||
|  |             kClassificationsTag)]}; | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -33,7 +33,8 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/port/gtest.h" | #include "mediapipe/framework/port/gtest.h" | ||||||
| #include "mediapipe/framework/port/status_matchers.h" | #include "mediapipe/framework/port/status_matchers.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/category.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
| #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" | #include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h" | ||||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||||
| 
 | 
 | ||||||
|  | @ -43,14 +44,13 @@ namespace text { | ||||||
| namespace text_classifier { | namespace text_classifier { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| using ::mediapipe::EqualsProto; |  | ||||||
| using ::mediapipe::file::JoinPath; | using ::mediapipe::file::JoinPath; | ||||||
| using ::mediapipe::tasks::kMediaPipeTasksPayload; | using ::mediapipe::tasks::kMediaPipeTasksPayload; | ||||||
| using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | using ::mediapipe::tasks::components::containers::Category; | ||||||
|  | using ::mediapipe::tasks::components::containers::Classifications; | ||||||
| using ::testing::HasSubstr; | using ::testing::HasSubstr; | ||||||
| using ::testing::Optional; | using ::testing::Optional; | ||||||
| 
 | 
 | ||||||
| constexpr float kEpsilon = 0.001; |  | ||||||
| constexpr int kMaxSeqLen = 128; | constexpr int kMaxSeqLen = 128; | ||||||
| constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; | constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/"; | ||||||
| constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; | constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite"; | ||||||
|  | @ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) { | ||||||
|   return JoinPath("./", kTestDataDirectory, file_name); |   return JoinPath("./", kTestDataDirectory, file_name); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Checks that the two provided `TextClassifierResult` are equal, with a
 | ||||||
|  | // tolerancy on floating-point score to account for numerical instabilities.
 | ||||||
|  | // TODO: create shared matcher for ClassificationResult.
 | ||||||
|  | void ExpectApproximatelyEqual(const TextClassifierResult& actual, | ||||||
|  |                               const TextClassifierResult& expected) { | ||||||
|  |   const float kPrecision = 1e-6; | ||||||
|  |   ASSERT_EQ(actual.classifications.size(), expected.classifications.size()); | ||||||
|  |   for (int i = 0; i < actual.classifications.size(); ++i) { | ||||||
|  |     const Classifications& a = actual.classifications[i]; | ||||||
|  |     const Classifications& b = expected.classifications[i]; | ||||||
|  |     EXPECT_EQ(a.head_index, b.head_index); | ||||||
|  |     EXPECT_EQ(a.head_name, b.head_name); | ||||||
|  |     EXPECT_EQ(a.categories.size(), b.categories.size()); | ||||||
|  |     for (int j = 0; j < a.categories.size(); ++j) { | ||||||
|  |       const Category& x = a.categories[j]; | ||||||
|  |       const Category& y = b.categories[j]; | ||||||
|  |       EXPECT_EQ(x.index, y.index); | ||||||
|  |       EXPECT_NEAR(x.score, y.score, kPrecision); | ||||||
|  |       EXPECT_EQ(x.category_name, y.category_name); | ||||||
|  |       EXPECT_EQ(x.display_name, y.display_name); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| class TextClassifierTest : public tflite_shims::testing::Test {}; | class TextClassifierTest : public tflite_shims::testing::Test {}; | ||||||
| 
 | 
 | ||||||
| TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { | TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) { | ||||||
|  | @ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) { | ||||||
|   MP_ASSERT_OK(TextClassifier::Create(std::move(options))); |   MP_ASSERT_OK(TextClassifier::Create(std::move(options))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(TextClassifierTest, TextClassifierWithBert) { | ||||||
|  |   auto options = std::make_unique<TextClassifierOptions>(); | ||||||
|  |   options->base_options.model_asset_path = GetFullPath(kTestBertModelPath); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier, | ||||||
|  |                           TextClassifier::Create(std::move(options))); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       TextClassifierResult negative_result, | ||||||
|  |       classifier->Classify("unflinchingly bleak and desperate")); | ||||||
|  |   TextClassifierResult negative_expected; | ||||||
|  |   negative_expected.classifications.emplace_back(Classifications{ | ||||||
|  |       /*categories=*/{ | ||||||
|  |           {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"}, | ||||||
|  |           {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}}, | ||||||
|  |       /*head_index=*/0, | ||||||
|  |       /*head_name=*/"probability"}); | ||||||
|  |   ExpectApproximatelyEqual(negative_result, negative_expected); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       TextClassifierResult positive_result, | ||||||
|  |       classifier->Classify("it's a charming and often affecting journey")); | ||||||
|  |   TextClassifierResult positive_expected; | ||||||
|  |   positive_expected.classifications.emplace_back(Classifications{ | ||||||
|  |       /*categories=*/{ | ||||||
|  |           {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"}, | ||||||
|  |           {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}}, | ||||||
|  |       /*head_index=*/0, | ||||||
|  |       /*head_name=*/"probability"}); | ||||||
|  |   ExpectApproximatelyEqual(positive_result, positive_expected); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(classifier->Close()); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(TextClassifierTest, TextClassifierWithIntInputs) { | ||||||
|  |   auto options = std::make_unique<TextClassifierOptions>(); | ||||||
|  |   options->base_options.model_asset_path = GetFullPath(kTestRegexModelPath); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier, | ||||||
|  |                           TextClassifier::Create(std::move(options))); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult negative_result, | ||||||
|  |                           classifier->Classify("What a waste of my time.")); | ||||||
|  |   TextClassifierResult negative_expected; | ||||||
|  |   negative_expected.classifications.emplace_back(Classifications{ | ||||||
|  |       /*categories=*/{ | ||||||
|  |           {/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"}, | ||||||
|  |           {/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}}, | ||||||
|  |       /*head_index=*/0, | ||||||
|  |       /*head_name=*/"probability"}); | ||||||
|  |   ExpectApproximatelyEqual(negative_result, negative_expected); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       TextClassifierResult positive_result, | ||||||
|  |       classifier->Classify("This is the best movie I’ve seen in recent years." | ||||||
|  |                            "Strongly recommend it!")); | ||||||
|  |   TextClassifierResult positive_expected; | ||||||
|  |   positive_expected.classifications.emplace_back(Classifications{ | ||||||
|  |       /*categories=*/{ | ||||||
|  |           {/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"}, | ||||||
|  |           {/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}}, | ||||||
|  |       /*head_index=*/0, | ||||||
|  |       /*head_name=*/"probability"}); | ||||||
|  |   ExpectApproximatelyEqual(positive_result, positive_expected); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(classifier->Close()); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(TextClassifierTest, TextClassifierWithStringToBool) { | ||||||
|  |   auto options = std::make_unique<TextClassifierOptions>(); | ||||||
|  |   options->base_options.model_asset_path = GetFullPath(kStringToBoolModelPath); | ||||||
|  |   options->base_options.op_resolver = CreateCustomResolver(); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<TextClassifier> classifier, | ||||||
|  |                           TextClassifier::Create(std::move(options))); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result, | ||||||
|  |                           classifier->Classify("hello")); | ||||||
|  | 
 | ||||||
|  |   // Binary outputs cause flaky ordering, so we compare manually.
 | ||||||
|  |   ASSERT_EQ(result.classifications.size(), 1); | ||||||
|  |   ASSERT_EQ(result.classifications[0].head_index, 0); | ||||||
|  |   ASSERT_EQ(result.classifications[0].categories.size(), 3); | ||||||
|  |   ASSERT_EQ(result.classifications[0].categories[0].score, 1); | ||||||
|  |   ASSERT_LT(result.classifications[0].categories[0].index, 2);  // i.e 0 or 1.
 | ||||||
|  |   ASSERT_EQ(result.classifications[0].categories[1].score, 1); | ||||||
|  |   ASSERT_LT(result.classifications[0].categories[1].index, 2);  // i.e 0 or 1.
 | ||||||
|  |   ASSERT_EQ(result.classifications[0].categories[2].score, 0); | ||||||
|  |   ASSERT_EQ(result.classifications[0].categories[2].index, 2); | ||||||
|  |   MP_ASSERT_OK(classifier->Close()); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace text_classifier
 | }  // namespace text_classifier
 | ||||||
| }  // namespace text
 | }  // namespace text
 | ||||||
|  |  | ||||||
|  | @ -50,6 +50,7 @@ cc_library( | ||||||
|         "//mediapipe/framework/api2:builder", |         "//mediapipe/framework/api2:builder", | ||||||
|         "//mediapipe/framework/formats:image", |         "//mediapipe/framework/formats:image", | ||||||
|         "//mediapipe/framework/formats:rect_cc_proto", |         "//mediapipe/framework/formats:rect_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers:classification_result", | ||||||
|         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", |         "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", | ||||||
|         "//mediapipe/tasks/cc/components/processors:classifier_options", |         "//mediapipe/tasks/cc/components/processors:classifier_options", | ||||||
|         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", |         "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", | ||||||
|  |  | ||||||
|  | @ -26,6 +26,7 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/formats/rect.pb.h" | #include "mediapipe/framework/formats/rect.pb.h" | ||||||
| #include "mediapipe/framework/packet.h" | #include "mediapipe/framework/packet.h" | ||||||
| #include "mediapipe/framework/timestamp.h" | #include "mediapipe/framework/timestamp.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | ||||||
| #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" | #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" | ||||||
|  | @ -46,8 +47,8 @@ namespace image_classifier { | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| constexpr char kClassificationResultStreamName[] = "classification_result_out"; | constexpr char kClassificationsStreamName[] = "classifications_out"; | ||||||
| constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | ||||||
| constexpr char kImageInStreamName[] = "image_in"; | constexpr char kImageInStreamName[] = "image_in"; | ||||||
| constexpr char kImageOutStreamName[] = "image_out"; | constexpr char kImageOutStreamName[] = "image_out"; | ||||||
| constexpr char kImageTag[] = "IMAGE"; | constexpr char kImageTag[] = "IMAGE"; | ||||||
|  | @ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] = | ||||||
|     "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; |     "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; | ||||||
| constexpr int kMicroSecondsPerMilliSecond = 1000; | constexpr int kMicroSecondsPerMilliSecond = 1000; | ||||||
| 
 | 
 | ||||||
|  | using ::mediapipe::tasks::components::containers::ConvertToClassificationResult; | ||||||
| using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | ||||||
| using ::mediapipe::tasks::core::PacketMap; | using ::mediapipe::tasks::core::PacketMap; | ||||||
| 
 | 
 | ||||||
|  | @ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig( | ||||||
|   auto& task_subgraph = graph.AddNode(kSubgraphTypeName); |   auto& task_subgraph = graph.AddNode(kSubgraphTypeName); | ||||||
|   task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap( |   task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap( | ||||||
|       options_proto.get()); |       options_proto.get()); | ||||||
|   task_subgraph.Out(kClassificationResultTag) |   task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >> | ||||||
|           .SetName(kClassificationResultStreamName) >> |       graph.Out(kClassificationsTag); | ||||||
|       graph.Out(kClassificationResultTag); |  | ||||||
|   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> |   task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> | ||||||
|       graph.Out(kImageTag); |       graph.Out(kImageTag); | ||||||
|   if (enable_flow_limiting) { |   if (enable_flow_limiting) { | ||||||
|     return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph, |     return tasks::core::AddFlowLimiterCalculator( | ||||||
|                                                  {kImageTag, kNormRectTag}, |         graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag); | ||||||
|                                                  kClassificationResultTag); |  | ||||||
|   } |   } | ||||||
|   graph.In(kImageTag) >> task_subgraph.In(kImageTag); |   graph.In(kImageTag) >> task_subgraph.In(kImageTag); | ||||||
|   graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); |   graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); | ||||||
|  | @ -125,13 +125,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create( | ||||||
|           if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { |           if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) { | ||||||
|             return; |             return; | ||||||
|           } |           } | ||||||
|           Packet classification_result_packet = |           Packet classifications_packet = | ||||||
|               status_or_packets.value()[kClassificationResultStreamName]; |               status_or_packets.value()[kClassificationsStreamName]; | ||||||
|           Packet image_packet = status_or_packets.value()[kImageOutStreamName]; |           Packet image_packet = status_or_packets.value()[kImageOutStreamName]; | ||||||
|           result_callback( |           result_callback( | ||||||
|               classification_result_packet.Get<ClassificationResult>(), |               ConvertToClassificationResult( | ||||||
|  |                   classifications_packet.Get<ClassificationResult>()), | ||||||
|               image_packet.Get<Image>(), |               image_packet.Get<Image>(), | ||||||
|               classification_result_packet.Timestamp().Value() / |               classifications_packet.Timestamp().Value() / | ||||||
|                   kMicroSecondsPerMilliSecond); |                   kMicroSecondsPerMilliSecond); | ||||||
|         }; |         }; | ||||||
|   } |   } | ||||||
|  | @ -144,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create( | ||||||
|       std::move(packets_callback)); |       std::move(packets_callback)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::StatusOr<ClassificationResult> ImageClassifier::Classify( | absl::StatusOr<ImageClassifierResult> ImageClassifier::Classify( | ||||||
|     Image image, |     Image image, | ||||||
|     std::optional<core::ImageProcessingOptions> image_processing_options) { |     std::optional<core::ImageProcessingOptions> image_processing_options) { | ||||||
|   if (image.UsesGpu()) { |   if (image.UsesGpu()) { | ||||||
|  | @ -160,11 +161,11 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify( | ||||||
|       ProcessImageData( |       ProcessImageData( | ||||||
|           {{kImageInStreamName, MakePacket<Image>(std::move(image))}, |           {{kImageInStreamName, MakePacket<Image>(std::move(image))}, | ||||||
|            {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); |            {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}})); | ||||||
|   return output_packets[kClassificationResultStreamName] |   return ConvertToClassificationResult( | ||||||
|       .Get<ClassificationResult>(); |       output_packets[kClassificationsStreamName].Get<ClassificationResult>()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo( | absl::StatusOr<ImageClassifierResult> ImageClassifier::ClassifyForVideo( | ||||||
|     Image image, int64 timestamp_ms, |     Image image, int64 timestamp_ms, | ||||||
|     std::optional<core::ImageProcessingOptions> image_processing_options) { |     std::optional<core::ImageProcessingOptions> image_processing_options) { | ||||||
|   if (image.UsesGpu()) { |   if (image.UsesGpu()) { | ||||||
|  | @ -184,8 +185,8 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo( | ||||||
|            {kNormRectName, |            {kNormRectName, | ||||||
|             MakePacket<NormalizedRect>(std::move(norm_rect)) |             MakePacket<NormalizedRect>(std::move(norm_rect)) | ||||||
|                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); |                 .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); | ||||||
|   return output_packets[kClassificationResultStreamName] |   return ConvertToClassificationResult( | ||||||
|       .Get<ClassificationResult>(); |       output_packets[kClassificationsStreamName].Get<ClassificationResult>()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::Status ImageClassifier::ClassifyAsync( | absl::Status ImageClassifier::ClassifyAsync( | ||||||
|  |  | ||||||
|  | @ -22,7 +22,7 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| #include "absl/status/statusor.h" | #include "absl/status/statusor.h" | ||||||
| #include "mediapipe/framework/formats/image.h" | #include "mediapipe/framework/formats/image.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
| #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | #include "mediapipe/tasks/cc/components/processors/classifier_options.h" | ||||||
| #include "mediapipe/tasks/cc/core/base_options.h" | #include "mediapipe/tasks/cc/core/base_options.h" | ||||||
| #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" | #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" | ||||||
|  | @ -34,6 +34,10 @@ namespace tasks { | ||||||
| namespace vision { | namespace vision { | ||||||
| namespace image_classifier { | namespace image_classifier { | ||||||
| 
 | 
 | ||||||
|  | // Alias the shared ClassificationResult struct as result type.
 | ||||||
|  | using ImageClassifierResult = | ||||||
|  |     ::mediapipe::tasks::components::containers::ClassificationResult; | ||||||
|  | 
 | ||||||
| // The options for configuring a Mediapipe image classifier task.
 | // The options for configuring a Mediapipe image classifier task.
 | ||||||
| struct ImageClassifierOptions { | struct ImageClassifierOptions { | ||||||
|   // Base options for configuring MediaPipe Tasks, such as specifying the model
 |   // Base options for configuring MediaPipe Tasks, such as specifying the model
 | ||||||
|  | @ -56,9 +60,8 @@ struct ImageClassifierOptions { | ||||||
|   // The user-defined result callback for processing live stream data.
 |   // The user-defined result callback for processing live stream data.
 | ||||||
|   // The result callback should only be specified when the running mode is set
 |   // The result callback should only be specified when the running mode is set
 | ||||||
|   // to RunningMode::LIVE_STREAM.
 |   // to RunningMode::LIVE_STREAM.
 | ||||||
|   std::function<void( |   std::function<void(absl::StatusOr<ImageClassifierResult>, const Image&, | ||||||
|       absl::StatusOr<components::containers::proto::ClassificationResult>, |                      int64)> | ||||||
|       const Image&, int64)> |  | ||||||
|       result_callback = nullptr; |       result_callback = nullptr; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { | ||||||
|   // The image can be of any size with format RGB or RGBA.
 |   // The image can be of any size with format RGB or RGBA.
 | ||||||
|   // TODO: describe exact preprocessing steps once
 |   // TODO: describe exact preprocessing steps once
 | ||||||
|   // YUVToImageCalculator is integrated.
 |   // YUVToImageCalculator is integrated.
 | ||||||
|   absl::StatusOr<components::containers::proto::ClassificationResult> Classify( |   absl::StatusOr<ImageClassifierResult> Classify( | ||||||
|       mediapipe::Image image, |       mediapipe::Image image, | ||||||
|       std::optional<core::ImageProcessingOptions> image_processing_options = |       std::optional<core::ImageProcessingOptions> image_processing_options = | ||||||
|           std::nullopt); |           std::nullopt); | ||||||
|  | @ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { | ||||||
|   // The image can be of any size with format RGB or RGBA. It's required to
 |   // The image can be of any size with format RGB or RGBA. It's required to
 | ||||||
|   // provide the video frame's timestamp (in milliseconds). The input timestamps
 |   // provide the video frame's timestamp (in milliseconds). The input timestamps
 | ||||||
|   // must be monotonically increasing.
 |   // must be monotonically increasing.
 | ||||||
|   absl::StatusOr<components::containers::proto::ClassificationResult> |   absl::StatusOr<ImageClassifierResult> ClassifyForVideo( | ||||||
|   ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms, |       mediapipe::Image image, int64 timestamp_ms, | ||||||
|                    std::optional<core::ImageProcessingOptions> |       std::optional<core::ImageProcessingOptions> image_processing_options = | ||||||
|                        image_processing_options = std::nullopt); |           std::nullopt); | ||||||
| 
 | 
 | ||||||
|   // Sends live image data to image classification, and the results will be
 |   // Sends live image data to image classification, and the results will be
 | ||||||
|   // available via the "result_callback" provided in the ImageClassifierOptions.
 |   // available via the "result_callback" provided in the ImageClassifierOptions.
 | ||||||
|  | @ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi { | ||||||
|   // increasing.
 |   // increasing.
 | ||||||
|   //
 |   //
 | ||||||
|   // The "result_callback" provides:
 |   // The "result_callback" provides:
 | ||||||
|   //   - The classification results as a ClassificationResult object.
 |   //   - The classification results as an ImageClassifierResult object.
 | ||||||
|   //   - The const reference to the corresponding input image that the image
 |   //   - The const reference to the corresponding input image that the image
 | ||||||
|   //     classifier runs on. Note that the const reference to the image will no
 |   //     classifier runs on. Note that the const reference to the image will no
 | ||||||
|   //     longer be valid when the callback returns. To access the image data
 |   //     longer be valid when the callback returns. To access the image data
 | ||||||
|  |  | ||||||
|  | @ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult; | ||||||
| constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); | constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); | ||||||
| 
 | 
 | ||||||
| constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | ||||||
|  | constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | ||||||
| constexpr char kImageTag[] = "IMAGE"; | constexpr char kImageTag[] = "IMAGE"; | ||||||
| constexpr char kNormRectTag[] = "NORM_RECT"; | constexpr char kNormRectTag[] = "NORM_RECT"; | ||||||
| constexpr char kTensorsTag[] = "TENSORS"; | constexpr char kTensorsTag[] = "TENSORS"; | ||||||
|  | @ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS"; | ||||||
| // subgraph.
 | // subgraph.
 | ||||||
| struct ImageClassifierOutputStreams { | struct ImageClassifierOutputStreams { | ||||||
|   Source<ClassificationResult> classification_result; |   Source<ClassificationResult> classification_result; | ||||||
|  |   Source<ClassificationResult> classifications; | ||||||
|   Source<Image> image; |   Source<Image> image; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -71,17 +73,19 @@ struct ImageClassifierOutputStreams { | ||||||
| //     Describes region of image to perform classification on.
 | //     Describes region of image to perform classification on.
 | ||||||
| //     @Optional: rect covering the whole image is used if not specified.
 | //     @Optional: rect covering the whole image is used if not specified.
 | ||||||
| // Outputs:
 | // Outputs:
 | ||||||
| //   CLASSIFICATION_RESULT - ClassificationResult
 | //   CLASSIFICATIONS - ClassificationResult @Optional
 | ||||||
| //     The aggregated classification result object has two dimensions:
 | //     The classification results aggregated by classifier head.
 | ||||||
| //     (classification head, classification category)
 |  | ||||||
| //   IMAGE - Image
 | //   IMAGE - Image
 | ||||||
| //     The image that image classification runs on.
 | //     The image that image classification runs on.
 | ||||||
|  | // TODO: remove this output once Java API migration is over.
 | ||||||
|  | //   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
 | ||||||
|  | //     The aggregated classification result.
 | ||||||
| //
 | //
 | ||||||
| // Example:
 | // Example:
 | ||||||
| // node {
 | // node {
 | ||||||
| //   calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
 | //   calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
 | ||||||
| //   input_stream: "IMAGE:image_in"
 | //   input_stream: "IMAGE:image_in"
 | ||||||
| //   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
 | //   output_stream: "CLASSIFICATIONS:classifications_out"
 | ||||||
| //   output_stream: "IMAGE:image_out"
 | //   output_stream: "IMAGE:image_out"
 | ||||||
| //   options {
 | //   options {
 | ||||||
| //     [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
 | //     [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
 | ||||||
|  | @ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph { | ||||||
|             graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); |             graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); | ||||||
|     output_streams.classification_result >> |     output_streams.classification_result >> | ||||||
|         graph[Output<ClassificationResult>(kClassificationResultTag)]; |         graph[Output<ClassificationResult>(kClassificationResultTag)]; | ||||||
|  |     output_streams.classifications >> | ||||||
|  |         graph[Output<ClassificationResult>(kClassificationsTag)]; | ||||||
|     output_streams.image >> graph[Output<Image>(kImageTag)]; |     output_streams.image >> graph[Output<Image>(kImageTag)]; | ||||||
|     return graph.GetConfig(); |     return graph.GetConfig(); | ||||||
|   } |   } | ||||||
|  | @ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph { | ||||||
|     return ImageClassifierOutputStreams{ |     return ImageClassifierOutputStreams{ | ||||||
|         /*classification_result=*/postprocessing[Output<ClassificationResult>( |         /*classification_result=*/postprocessing[Output<ClassificationResult>( | ||||||
|             kClassificationResultTag)], |             kClassificationResultTag)], | ||||||
|  |         /*classifications=*/ | ||||||
|  |         postprocessing[Output<ClassificationResult>(kClassificationsTag)], | ||||||
|         /*image=*/preprocessing[Output<Image>(kImageTag)]}; |         /*image=*/preprocessing[Output<Image>(kImageTag)]}; | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | @ -32,8 +32,8 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/port/parse_text_proto.h" | #include "mediapipe/framework/port/parse_text_proto.h" | ||||||
| #include "mediapipe/framework/port/status_matchers.h" | #include "mediapipe/framework/port/status_matchers.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h" | #include "mediapipe/tasks/cc/components/containers/category.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/classification_result.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/rect.h" | #include "mediapipe/tasks/cc/components/containers/rect.h" | ||||||
| #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" | #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" | ||||||
| #include "mediapipe/tasks/cc/vision/core/running_mode.h" | #include "mediapipe/tasks/cc/vision/core/running_mode.h" | ||||||
|  | @ -50,10 +50,9 @@ namespace image_classifier { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| using ::mediapipe::file::JoinPath; | using ::mediapipe::file::JoinPath; | ||||||
|  | using ::mediapipe::tasks::components::containers::Category; | ||||||
|  | using ::mediapipe::tasks::components::containers::Classifications; | ||||||
| using ::mediapipe::tasks::components::containers::Rect; | using ::mediapipe::tasks::components::containers::Rect; | ||||||
| using ::mediapipe::tasks::components::containers::proto::ClassificationEntry; |  | ||||||
| using ::mediapipe::tasks::components::containers::proto::ClassificationResult; |  | ||||||
| using ::mediapipe::tasks::components::containers::proto::Classifications; |  | ||||||
| using ::mediapipe::tasks::vision::core::ImageProcessingOptions; | using ::mediapipe::tasks::vision::core::ImageProcessingOptions; | ||||||
| using ::testing::HasSubstr; | using ::testing::HasSubstr; | ||||||
| using ::testing::Optional; | using ::testing::Optional; | ||||||
|  | @ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] = | ||||||
| constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] = | constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] = | ||||||
|     "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; |     "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; | ||||||
| 
 | 
 | ||||||
| // Checks that the two provided `ClassificationResult` are equal, with a
 | // Checks that the two provided `ImageClassifierResult` are equal, with a
 | ||||||
| // tolerance on floating-point score to account for numerical instabilities.
 | // tolerance on floating-point score to account for numerical instabilities.
 | ||||||
| void ExpectApproximatelyEqual(const ClassificationResult& actual, | void ExpectApproximatelyEqual(const ImageClassifierResult& actual, | ||||||
|                               const ClassificationResult& expected) { |                               const ImageClassifierResult& expected) { | ||||||
|   const float kPrecision = 1e-6; |   const float kPrecision = 1e-6; | ||||||
|   ASSERT_EQ(actual.classifications_size(), expected.classifications_size()); |   ASSERT_EQ(actual.classifications.size(), expected.classifications.size()); | ||||||
|   for (int i = 0; i < actual.classifications_size(); ++i) { |   for (int i = 0; i < actual.classifications.size(); ++i) { | ||||||
|     const Classifications& a = actual.classifications(i); |     const Classifications& a = actual.classifications[i]; | ||||||
|     const Classifications& b = expected.classifications(i); |     const Classifications& b = expected.classifications[i]; | ||||||
|     EXPECT_EQ(a.head_index(), b.head_index()); |     EXPECT_EQ(a.head_index, b.head_index); | ||||||
|     EXPECT_EQ(a.head_name(), b.head_name()); |     EXPECT_EQ(a.head_name, b.head_name); | ||||||
|     EXPECT_EQ(a.entries_size(), b.entries_size()); |     EXPECT_EQ(a.categories.size(), b.categories.size()); | ||||||
|     for (int j = 0; j < a.entries_size(); ++j) { |     for (int j = 0; j < a.categories.size(); ++j) { | ||||||
|       const ClassificationEntry& x = a.entries(j); |       const Category& x = a.categories[j]; | ||||||
|       const ClassificationEntry& y = b.entries(j); |       const Category& y = b.categories[j]; | ||||||
|       EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms()); |       EXPECT_EQ(x.index, y.index); | ||||||
|       EXPECT_EQ(x.categories_size(), y.categories_size()); |       EXPECT_NEAR(x.score, y.score, kPrecision); | ||||||
|       for (int k = 0; k < x.categories_size(); ++k) { |       EXPECT_EQ(x.category_name, y.category_name); | ||||||
|         EXPECT_EQ(x.categories(k).index(), y.categories(k).index()); |       EXPECT_EQ(x.display_name, y.display_name); | ||||||
|         EXPECT_EQ(x.categories(k).category_name(), |  | ||||||
|                   y.categories(k).category_name()); |  | ||||||
|         EXPECT_EQ(x.categories(k).display_name(), |  | ||||||
|                   y.categories(k).display_name()); |  | ||||||
|         EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(), |  | ||||||
|                     kPrecision); |  | ||||||
|       } |  | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
 | // Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
 | ||||||
| // with max_results set to 3.
 | // with max_results set to 3.
 | ||||||
| ClassificationResult GenerateBurgerResults(int64 timestamp) { | ImageClassifierResult GenerateBurgerResults() { | ||||||
|   return ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult result; | ||||||
|       absl::StrFormat(R"pb(classifications { |   result.classifications.emplace_back(Classifications{ | ||||||
|                              entries { |       /*categories=*/{ | ||||||
|                                categories { |           {/*index=*/934, /*score=*/0.793959200, | ||||||
|                                  index: 934 |            /*category_name=*/"cheeseburger"}, | ||||||
|                                  score: 0.7939592 |           {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}, | ||||||
|                                  category_name: "cheeseburger" |           {/*index=*/925, /*score=*/0.019340655, | ||||||
|                                } |            /*category_name=*/"guacamole"}}, | ||||||
|                                categories { |       /*head_index=*/0, | ||||||
|                                  index: 932 |       /*head_name=*/"probability"}); | ||||||
|                                  score: 0.027392805 |   return result; | ||||||
|                                  category_name: "bagel" |  | ||||||
|                                } |  | ||||||
|                                categories { |  | ||||||
|                                  index: 925 |  | ||||||
|                                  score: 0.019340655 |  | ||||||
|                                  category_name: "guacamole" |  | ||||||
|                                } |  | ||||||
|                                timestamp_ms: %d |  | ||||||
|                              } |  | ||||||
|                              head_index: 0 |  | ||||||
|                              head_name: "probability" |  | ||||||
|                            })pb", |  | ||||||
|                       timestamp)); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Generates expected results for "multi_objects.jpg" using
 | // Generates expected results for "multi_objects.jpg" using
 | ||||||
| // kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
 | // kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
 | ||||||
| // box set around the soccer ball.
 | // box set around the soccer ball.
 | ||||||
| ClassificationResult GenerateSoccerBallResults(int64 timestamp) { | ImageClassifierResult GenerateSoccerBallResults() { | ||||||
|   return ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult result; | ||||||
|       absl::StrFormat(R"pb(classifications { |   result.classifications.emplace_back( | ||||||
|                              entries { |       Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493, | ||||||
|                                categories { |                                        /*category_name=*/"soccer ball"}}, | ||||||
|                                  index: 806 |                       /*head_index=*/0, | ||||||
|                                  score: 0.996527493 |                       /*head_name=*/"probability"}); | ||||||
|                                  category_name: "soccer ball" |   return result; | ||||||
|                                } |  | ||||||
|                                timestamp_ms: %d |  | ||||||
|                              } |  | ||||||
|                              head_index: 0 |  | ||||||
|                              head_name: "probability" |  | ||||||
|                            })pb", |  | ||||||
|                       timestamp)); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // A custom OpResolver only containing the Ops required by the test model.
 | // A custom OpResolver only containing the Ops required by the test model.
 | ||||||
|  | @ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) { | ||||||
|     options->base_options.model_asset_path = |     options->base_options.model_asset_path = | ||||||
|         JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); |         JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata); | ||||||
|     options->running_mode = running_mode; |     options->running_mode = running_mode; | ||||||
|     options->result_callback = [](absl::StatusOr<ClassificationResult>, |     options->result_callback = [](absl::StatusOr<ImageClassifierResult>, | ||||||
|                                   const Image& image, int64 timestamp_ms) {}; |                                   const Image& image, int64 timestamp_ms) {}; | ||||||
| 
 | 
 | ||||||
|     auto image_classifier = ImageClassifier::Create(std::move(options)); |     auto image_classifier = ImageClassifier::Create(std::move(options)); | ||||||
|  | @ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, GenerateBurgerResults(0)); |   ExpectApproximatelyEqual(results, GenerateBurgerResults()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { | TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { | ||||||
|  | @ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back( | ||||||
|                                                entries { |       Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625, | ||||||
|                                                  categories { |                                        /*category_name=*/"cheeseburger"}}, | ||||||
|                                                    index: 934 |                       /*head_index=*/0, | ||||||
|                                                    score: 0.97265625 |                       /*head_name=*/"probability"}); | ||||||
|                                                    category_name: "cheeseburger" |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { | TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { | ||||||
|  | @ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back( | ||||||
|                                                entries { |       Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592, | ||||||
|                                                  categories { |                                        /*category_name=*/"cheeseburger"}}, | ||||||
|                                                    index: 934 |                       /*head_index=*/0, | ||||||
|                                                    score: 0.7939592 |                       /*head_name=*/"probability"}); | ||||||
|                                                    category_name: "cheeseburger" |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { | TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { | ||||||
|  | @ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back(Classifications{ | ||||||
|                                                entries { |       /*categories=*/{ | ||||||
|                                                  categories { |           {/*index=*/934, /*score=*/0.7939592, | ||||||
|                                                    index: 934 |            /*category_name=*/"cheeseburger"}, | ||||||
|                                                    score: 0.7939592 |           {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}}, | ||||||
|                                                    category_name: "cheeseburger" |       /*head_index=*/0, | ||||||
|                                                  } |       /*head_name=*/"probability"}); | ||||||
|                                                  categories { |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                    index: 932 |  | ||||||
|                                                    score: 0.027392805 |  | ||||||
|                                                    category_name: "bagel" |  | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { | TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { | ||||||
|  | @ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back(Classifications{ | ||||||
|                                                entries { |       /*categories=*/{ | ||||||
|                                                  categories { |           {/*index=*/934, /*score=*/0.7939592, | ||||||
|                                                    index: 934 |            /*category_name=*/"cheeseburger"}, | ||||||
|                                                    score: 0.7939592 |           {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"}, | ||||||
|                                                    category_name: "cheeseburger" |           {/*index=*/963, /*score=*/0.0063278517, | ||||||
|                                                  } |            /*category_name=*/"meat loaf"}}, | ||||||
|                                                  categories { |       /*head_index=*/0, | ||||||
|                                                    index: 925 |       /*head_name=*/"probability"}); | ||||||
|                                                    score: 0.019340655 |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                    category_name: "guacamole" |  | ||||||
|                                                  } |  | ||||||
|                                                  categories { |  | ||||||
|                                                    index: 963 |  | ||||||
|                                                    score: 0.0063278517 |  | ||||||
|                                                    category_name: "meat loaf" |  | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithDenylistOption) { | TEST_F(ImageModeTest, SucceedsWithDenylistOption) { | ||||||
|  | @ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back(Classifications{ | ||||||
|                                                entries { |       /*categories=*/{ | ||||||
|                                                  categories { |           {/*index=*/934, /*score=*/0.7939592, | ||||||
|                                                    index: 934 |            /*category_name=*/"cheeseburger"}, | ||||||
|                                                    score: 0.7939592 |           {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"}, | ||||||
|                                                    category_name: "cheeseburger" |           {/*index=*/963, /*score=*/0.0063278517, | ||||||
|                                                  } |            /*category_name=*/"meat loaf"}}, | ||||||
|                                                  categories { |       /*head_index=*/0, | ||||||
|                                                    index: 925 |       /*head_name=*/"probability"}); | ||||||
|                                                    score: 0.019340655 |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                    category_name: "guacamole" |  | ||||||
|                                                  } |  | ||||||
|                                                  categories { |  | ||||||
|                                                    index: 963 |  | ||||||
|                                                    score: 0.0063278517 |  | ||||||
|                                                    category_name: "meat loaf" |  | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { | TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { | ||||||
|  | @ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { | ||||||
| 
 | 
 | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back( | ||||||
|                                                entries { |       Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628, | ||||||
|                                                  categories { |                                        /*category_name=*/"cheeseburger"}}, | ||||||
|                                                    index: 934 |                       /*head_index=*/0, | ||||||
|                                                    score: 0.725648628 |                       /*head_name=*/"probability"}); | ||||||
|                                                    category_name: "cheeseburger" |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { | TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { | ||||||
|  | @ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( | ||||||
|                                             image, image_processing_options)); |                                             image, image_processing_options)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0)); |   ExpectApproximatelyEqual(results, GenerateSoccerBallResults()); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithRotation) { | TEST_F(ImageModeTest, SucceedsWithRotation) { | ||||||
|  | @ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) { | ||||||
|   // Results differ slightly from the non-rotated image, but that's expected
 |   // Results differ slightly from the non-rotated image, but that's expected
 | ||||||
|   // as models are very sensitive to the slightest numerical differences
 |   // as models are very sensitive to the slightest numerical differences
 | ||||||
|   // introduced by the rotation and JPG encoding.
 |   // introduced by the rotation and JPG encoding.
 | ||||||
|   ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>( |   ImageClassifierResult expected; | ||||||
|                                         R"pb(classifications { |   expected.classifications.emplace_back(Classifications{ | ||||||
|                                                entries { |       /*categories=*/{ | ||||||
|                                                  categories { |           {/*index=*/934, /*score=*/0.6371766, | ||||||
|                                                    index: 934 |            /*category_name=*/"cheeseburger"}, | ||||||
|                                                    score: 0.6371766 |           {/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"}, | ||||||
|                                                    category_name: "cheeseburger" |           {/*index=*/925, /*score=*/0.047918003, | ||||||
|                                                  } |            /*category_name=*/"guacamole"}}, | ||||||
|                                                  categories { |       /*head_index=*/0, | ||||||
|                                                    index: 963 |       /*head_name=*/"probability"}); | ||||||
|                                                    score: 0.049443405 |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                                    category_name: "meat loaf" |  | ||||||
|                                                  } |  | ||||||
|                                                  categories { |  | ||||||
|                                                    index: 925 |  | ||||||
|                                                    score: 0.047918003 |  | ||||||
|                                                    category_name: "guacamole" |  | ||||||
|                                                  } |  | ||||||
|                                                  timestamp_ms: 0 |  | ||||||
|                                                } |  | ||||||
|                                                head_index: 0 |  | ||||||
|                                                head_name: "probability" |  | ||||||
|                                              })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { | TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { | ||||||
|  | @ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( |   MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify( | ||||||
|                                             image, image_processing_options)); |                                             image, image_processing_options)); | ||||||
| 
 | 
 | ||||||
|   ExpectApproximatelyEqual(results, |   ImageClassifierResult expected; | ||||||
|                            ParseTextProtoOrDie<ClassificationResult>( |   expected.classifications.emplace_back( | ||||||
|                                R"pb(classifications { |       Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213, | ||||||
|                                       entries { |                                        /*category_name=*/"folding chair"}}, | ||||||
|                                         categories { |                       /*head_index=*/0, | ||||||
|                                           index: 560 |                       /*head_name=*/"probability"}); | ||||||
|                                           score: 0.6522213 |   ExpectApproximatelyEqual(results, expected); | ||||||
|                                           category_name: "folding chair" |  | ||||||
|                                         } |  | ||||||
|                                         timestamp_ms: 0 |  | ||||||
|                                       } |  | ||||||
|                                       head_index: 0 |  | ||||||
|                                       head_name: "probability" |  | ||||||
|                                     })pb")); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Testing all these once with ImageClassifier.
 | // Testing all these once with ImageClassifier.
 | ||||||
|  | @ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) { | ||||||
|   for (int i = 0; i < iterations; ++i) { |   for (int i = 0; i < iterations; ++i) { | ||||||
|     MP_ASSERT_OK_AND_ASSIGN(auto results, |     MP_ASSERT_OK_AND_ASSIGN(auto results, | ||||||
|                             image_classifier->ClassifyForVideo(image, i)); |                             image_classifier->ClassifyForVideo(image, i)); | ||||||
|     ExpectApproximatelyEqual(results, GenerateBurgerResults(i)); |     ExpectApproximatelyEqual(results, GenerateBurgerResults()); | ||||||
|   } |   } | ||||||
|   MP_ASSERT_OK(image_classifier->Close()); |   MP_ASSERT_OK(image_classifier->Close()); | ||||||
| } | } | ||||||
|  | @ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) { | ||||||
|     MP_ASSERT_OK_AND_ASSIGN( |     MP_ASSERT_OK_AND_ASSIGN( | ||||||
|         auto results, |         auto results, | ||||||
|         image_classifier->ClassifyForVideo(image, i, image_processing_options)); |         image_classifier->ClassifyForVideo(image, i, image_processing_options)); | ||||||
|     ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i)); |     ExpectApproximatelyEqual(results, GenerateSoccerBallResults()); | ||||||
|   } |   } | ||||||
|   MP_ASSERT_OK(image_classifier->Close()); |   MP_ASSERT_OK(image_classifier->Close()); | ||||||
| } | } | ||||||
|  | @ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { | ||||||
|   options->base_options.model_asset_path = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); |       JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); | ||||||
|   options->running_mode = core::RunningMode::LIVE_STREAM; |   options->running_mode = core::RunningMode::LIVE_STREAM; | ||||||
|   options->result_callback = [](absl::StatusOr<ClassificationResult>, |   options->result_callback = [](absl::StatusOr<ImageClassifierResult>, | ||||||
|                                 const Image& image, int64 timestamp_ms) {}; |                                 const Image& image, int64 timestamp_ms) {}; | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, | ||||||
|                           ImageClassifier::Create(std::move(options))); |                           ImageClassifier::Create(std::move(options))); | ||||||
|  | @ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { | ||||||
|   options->base_options.model_asset_path = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); |       JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata); | ||||||
|   options->running_mode = core::RunningMode::LIVE_STREAM; |   options->running_mode = core::RunningMode::LIVE_STREAM; | ||||||
|   options->result_callback = [](absl::StatusOr<ClassificationResult>, |   options->result_callback = [](absl::StatusOr<ImageClassifierResult>, | ||||||
|                                 const Image& image, int64 timestamp_ms) {}; |                                 const Image& image, int64 timestamp_ms) {}; | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier, | ||||||
|                           ImageClassifier::Create(std::move(options))); |                           ImageClassifier::Create(std::move(options))); | ||||||
|  | @ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct LiveStreamModeResults { | struct LiveStreamModeResults { | ||||||
|   ClassificationResult classification_result; |   ImageClassifierResult classification_result; | ||||||
|   std::pair<int, int> image_size; |   std::pair<int, int> image_size; | ||||||
|   int64 timestamp_ms; |   int64 timestamp_ms; | ||||||
| }; | }; | ||||||
|  | @ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { | ||||||
|   options->running_mode = core::RunningMode::LIVE_STREAM; |   options->running_mode = core::RunningMode::LIVE_STREAM; | ||||||
|   options->classifier_options.max_results = 3; |   options->classifier_options.max_results = 3; | ||||||
|   options->result_callback = |   options->result_callback = | ||||||
|       [&results](absl::StatusOr<ClassificationResult> classification_result, |       [&results](absl::StatusOr<ImageClassifierResult> classification_result, | ||||||
|                  const Image& image, int64 timestamp_ms) { |                  const Image& image, int64 timestamp_ms) { | ||||||
|         MP_ASSERT_OK(classification_result.status()); |         MP_ASSERT_OK(classification_result.status()); | ||||||
|         results.push_back( |         results.push_back( | ||||||
|  | @ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { | ||||||
|     EXPECT_EQ(result.image_size.first, image.width()); |     EXPECT_EQ(result.image_size.first, image.width()); | ||||||
|     EXPECT_EQ(result.image_size.second, image.height()); |     EXPECT_EQ(result.image_size.second, image.height()); | ||||||
|     ExpectApproximatelyEqual(result.classification_result, |     ExpectApproximatelyEqual(result.classification_result, | ||||||
|                              GenerateBurgerResults(timestamp_ms)); |                              GenerateBurgerResults()); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { | ||||||
|   options->running_mode = core::RunningMode::LIVE_STREAM; |   options->running_mode = core::RunningMode::LIVE_STREAM; | ||||||
|   options->classifier_options.max_results = 1; |   options->classifier_options.max_results = 1; | ||||||
|   options->result_callback = |   options->result_callback = | ||||||
|       [&results](absl::StatusOr<ClassificationResult> classification_result, |       [&results](absl::StatusOr<ImageClassifierResult> classification_result, | ||||||
|                  const Image& image, int64 timestamp_ms) { |                  const Image& image, int64 timestamp_ms) { | ||||||
|         MP_ASSERT_OK(classification_result.status()); |         MP_ASSERT_OK(classification_result.status()); | ||||||
|         results.push_back( |         results.push_back( | ||||||
|  | @ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) { | ||||||
|     EXPECT_EQ(result.image_size.first, image.width()); |     EXPECT_EQ(result.image_size.first, image.width()); | ||||||
|     EXPECT_EQ(result.image_size.second, image.height()); |     EXPECT_EQ(result.image_size.second, image.height()); | ||||||
|     ExpectApproximatelyEqual(result.classification_result, |     ExpectApproximatelyEqual(result.classification_result, | ||||||
|                              GenerateSoccerBallResults(timestamp_ms)); |                              GenerateSoccerBallResults()); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user