Add classification result structs and use them in ImageClassifier & TextClassifier.
PiperOrigin-RevId: 485567133
This commit is contained in:
		
							parent
							
								
									aab5f84aae
								
							
						
					
					
						commit
						d61ab92b90
					
				| 
						 | 
				
			
			@ -29,3 +29,23 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/formats:landmark_cc_proto",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "category",
 | 
			
		||||
    srcs = ["category.cc"],
 | 
			
		||||
    hdrs = ["category.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//mediapipe/framework/formats:classification_cc_proto",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "classification_result",
 | 
			
		||||
    srcs = ["classification_result.cc"],
 | 
			
		||||
    hdrs = ["classification_result.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":category",
 | 
			
		||||
        "//mediapipe/framework/formats:classification_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										38
									
								
								mediapipe/tasks/cc/components/containers/category.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								mediapipe/tasks/cc/components/containers/category.cc
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,38 @@
 | 
			
		|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/category.h"
 | 
			
		||||
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/formats/classification.pb.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::components::containers {
 | 
			
		||||
 | 
			
		||||
// Builds a Category struct from a mediapipe::Classification proto. The index
// and score are always copied; the category name and display name are filled
// in only when the proto has its `label` / `display_name` fields set.
Category ConvertToCategory(const mediapipe::Classification& proto) {
  Category result;
  result.index = proto.index();
  result.score = proto.score();
  if (proto.has_label()) result.category_name = proto.label();
  if (proto.has_display_name()) result.display_name = proto.display_name();
  return result;
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::components::containers
 | 
			
		||||
							
								
								
									
										52
									
								
								mediapipe/tasks/cc/components/containers/category.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								mediapipe/tasks/cc/components/containers/category.h
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,52 @@
 | 
			
		|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
 | 
			
		||||
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/formats/classification.pb.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::components::containers {
 | 
			
		||||
 | 
			
		||||
// Defines a single classification result.
 | 
			
		||||
//
 | 
			
		||||
// The label maps packed into the TFLite Model Metadata [1] are used to populate
 | 
			
		||||
// the 'category_name' and 'display_name' fields.
 | 
			
		||||
//
 | 
			
		||||
// [1]: https://www.tensorflow.org/lite/convert/metadata
 | 
			
		||||
struct Category {
  // The index of the category in the classification model output.
  // Defaults to 0 so that a default-constructed Category never exposes an
  // indeterminate value.
  int index = 0;
  // The score for this category, e.g. (but not necessarily) a probability in
  // [0,1]. Defaults to 0 for the same reason as `index`.
  float score = 0.0f;
  // The optional ID for the category, read from the label map packed in the
  // TFLite Model Metadata if present. Not necessarily human-readable.
  std::optional<std::string> category_name = std::nullopt;
  // The optional human-readable name for the category, read from the label map
  // packed in the TFLite Model Metadata if present.
  std::optional<std::string> display_name = std::nullopt;
};
 | 
			
		||||
 | 
			
		||||
// Utility function to convert from mediapipe::Classification proto to Category
 | 
			
		||||
// struct.
 | 
			
		||||
Category ConvertToCategory(const mediapipe::Classification& proto);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::components::containers
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CATEGORY_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,57 @@
 | 
			
		|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/framework/formats/classification.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/category.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::components::containers {
 | 
			
		||||
 | 
			
		||||
// Converts a proto::Classifications message into a Classifications struct:
// every entry of the proto's classification list is turned into a Category,
// the head index is always copied, and the head name is copied only when it is
// set on the proto.
Classifications ConvertToClassifications(const proto::Classifications& proto) {
  const auto& entries = proto.classification_list().classification();
  Classifications result;
  result.categories.reserve(entries.size());
  for (const auto& entry : entries) {
    result.categories.push_back(ConvertToCategory(entry));
  }
  result.head_index = proto.head_index();
  if (proto.has_head_name()) {
    result.head_name = proto.head_name();
  }
  return result;
}
 | 
			
		||||
 | 
			
		||||
// Converts a proto::ClassificationResult message into a ClassificationResult
// struct: each contained Classifications proto is converted in order, and the
// timestamp is copied only when it is set on the proto.
ClassificationResult ConvertToClassificationResult(
    const proto::ClassificationResult& proto) {
  ClassificationResult result;
  auto& heads = result.classifications;
  heads.reserve(proto.classifications_size());
  for (const auto& head_proto : proto.classifications()) {
    heads.push_back(ConvertToClassifications(head_proto));
  }
  if (proto.has_timestamp_ms()) {
    result.timestamp_ms = proto.timestamp_ms();
  }
  return result;
}
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::components::containers
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,68 @@
 | 
			
		|||
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
 | 
			
		||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
 | 
			
		||||
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/category.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
 | 
			
		||||
namespace mediapipe::tasks::components::containers {
 | 
			
		||||
 | 
			
		||||
// Defines classification results for a given classifier head.
 | 
			
		||||
struct Classifications {
 | 
			
		||||
  // The array of predicted categories, usually sorted by descending scores,
 | 
			
		||||
  // e.g. from high to low probability.
 | 
			
		||||
  std::vector<Category> categories;
 | 
			
		||||
  // The index of the classifier head (i.e. output tensor) these categories
 | 
			
		||||
  // refer to. This is useful for multi-head models.
 | 
			
		||||
  int head_index;
 | 
			
		||||
  // The optional name of the classifier head, as provided in the TFLite Model
 | 
			
		||||
  // Metadata [1] if present. This is useful for multi-head models.
 | 
			
		||||
  //
 | 
			
		||||
  // [1]: https://www.tensorflow.org/lite/convert/metadata
 | 
			
		||||
  std::optional<std::string> head_name = std::nullopt;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Defines classification results of a model.
 | 
			
		||||
struct ClassificationResult {
 | 
			
		||||
  // The classification results for each head of the model.
 | 
			
		||||
  std::vector<Classifications> classifications;
 | 
			
		||||
  // The optional timestamp (in milliseconds) of the start of the chunk of data
 | 
			
		||||
  // corresponding to these results.
 | 
			
		||||
  //
 | 
			
		||||
  // This is only used for classification on time series (e.g. audio
 | 
			
		||||
  // classification). In these use cases, the amount of data to process might
 | 
			
		||||
  // exceed the maximum size that the model can process: to solve this, the
 | 
			
		||||
  // input data is split into multiple chunks starting at different timestamps.
 | 
			
		||||
  std::optional<int64_t> timestamp_ms = std::nullopt;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Utility function to convert from Classifications proto to
 | 
			
		||||
// Classifications struct.
 | 
			
		||||
Classifications ConvertToClassifications(const proto::Classifications& proto);
 | 
			
		||||
 | 
			
		||||
// Utility function to convert from ClassificationResult proto to
 | 
			
		||||
// ClassificationResult struct.
 | 
			
		||||
ClassificationResult ConvertToClassificationResult(
 | 
			
		||||
    const proto::ClassificationResult& proto);
 | 
			
		||||
 | 
			
		||||
}  // namespace mediapipe::tasks::components::containers
 | 
			
		||||
 | 
			
		||||
#endif  // MEDIAPIPE_TASKS_CC_COMPONENTS_CONTAINERS_CLASSIFICATION_RESULT_H_
 | 
			
		||||
| 
						 | 
				
			
			@ -49,6 +49,8 @@ cc_library(
 | 
			
		|||
        ":text_classifier_graph",
 | 
			
		||||
        "//mediapipe/framework:packet",
 | 
			
		||||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:category",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:classification_result",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			@ -76,7 +78,8 @@ cc_test(
 | 
			
		|||
        "//mediapipe/framework/deps:file_path",
 | 
			
		||||
        "//mediapipe/framework/port:gtest_main",
 | 
			
		||||
        "//mediapipe/tasks/cc:common",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:category",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:classification_result",
 | 
			
		||||
        "@com_google_absl//absl/flags:flag",
 | 
			
		||||
        "@com_google_absl//absl/status",
 | 
			
		||||
        "@com_google_absl//absl/status:statusor",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -24,6 +24,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/framework/api2/builder.h"
 | 
			
		||||
#include "mediapipe/framework/packet.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/task_api_factory.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -37,12 +38,13 @@ namespace text_classifier {
 | 
			
		|||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 | 
			
		||||
 | 
			
		||||
constexpr char kTextStreamName[] = "text_in";
 | 
			
		||||
constexpr char kTextTag[] = "TEXT";
 | 
			
		||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
 | 
			
		||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
 | 
			
		||||
constexpr char kClassificationsStreamName[] = "classifications_out";
 | 
			
		||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 | 
			
		||||
constexpr char kSubgraphTypeName[] =
 | 
			
		||||
    "mediapipe.tasks.text.text_classifier.TextClassifierGraph";
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -54,9 +56,8 @@ CalculatorGraphConfig CreateGraphConfig(
 | 
			
		|||
  auto& subgraph = graph.AddNode(kSubgraphTypeName);
 | 
			
		||||
  subgraph.GetOptions<proto::TextClassifierGraphOptions>().Swap(options.get());
 | 
			
		||||
  graph.In(kTextTag).SetName(kTextStreamName) >> subgraph.In(kTextTag);
 | 
			
		||||
  subgraph.Out(kClassificationResultTag)
 | 
			
		||||
          .SetName(kClassificationResultStreamName) >>
 | 
			
		||||
      graph.Out(kClassificationResultTag);
 | 
			
		||||
  subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
 | 
			
		||||
      graph.Out(kClassificationsTag);
 | 
			
		||||
  return graph.GetConfig();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -88,14 +89,14 @@ absl::StatusOr<std::unique_ptr<TextClassifier>> TextClassifier::Create(
 | 
			
		|||
      std::move(options->base_options.op_resolver));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<ClassificationResult> TextClassifier::Classify(
 | 
			
		||||
absl::StatusOr<TextClassifierResult> TextClassifier::Classify(
 | 
			
		||||
    absl::string_view text) {
 | 
			
		||||
  ASSIGN_OR_RETURN(
 | 
			
		||||
      auto output_packets,
 | 
			
		||||
      runner_->Process(
 | 
			
		||||
          {{kTextStreamName, MakePacket<std::string>(std::string(text))}}));
 | 
			
		||||
  return output_packets[kClassificationResultStreamName]
 | 
			
		||||
      .Get<ClassificationResult>();
 | 
			
		||||
  return ConvertToClassificationResult(
 | 
			
		||||
      output_packets[kClassificationsStreamName].Get<ClassificationResult>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace text_classifier
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -21,7 +21,7 @@ limitations under the License.
 | 
			
		|||
#include "absl/status/status.h"
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_task_api.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -31,6 +31,10 @@ namespace tasks {
 | 
			
		|||
namespace text {
 | 
			
		||||
namespace text_classifier {
 | 
			
		||||
 | 
			
		||||
// Alias the shared ClassificationResult struct as result type.
 | 
			
		||||
using TextClassifierResult =
 | 
			
		||||
    ::mediapipe::tasks::components::containers::ClassificationResult;
 | 
			
		||||
 | 
			
		||||
// The options for configuring a MediaPipe text classifier task.
 | 
			
		||||
struct TextClassifierOptions {
 | 
			
		||||
  // Base options for configuring MediaPipe Tasks, such as specifying the model
 | 
			
		||||
| 
						 | 
				
			
			@ -81,8 +85,7 @@ class TextClassifier : core::BaseTaskApi {
 | 
			
		|||
      std::unique_ptr<TextClassifierOptions> options);
 | 
			
		||||
 | 
			
		||||
  // Performs classification on the input `text`.
 | 
			
		||||
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
 | 
			
		||||
      absl::string_view text);
 | 
			
		||||
  absl::StatusOr<TextClassifierResult> Classify(absl::string_view text);
 | 
			
		||||
 | 
			
		||||
  // Shuts down the TextClassifier when all the work is done.
 | 
			
		||||
  absl::Status Close() { return runner_->Close(); }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -47,10 +47,18 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 | 
			
		|||
using ::mediapipe::tasks::core::ModelResources;
 | 
			
		||||
 | 
			
		||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
 | 
			
		||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 | 
			
		||||
constexpr char kTextTag[] = "TEXT";
 | 
			
		||||
constexpr char kMetadataExtractorTag[] = "METADATA_EXTRACTOR";
 | 
			
		||||
constexpr char kTensorsTag[] = "TENSORS";
 | 
			
		||||
 | 
			
		||||
// TODO: remove once Java API migration is over.
 | 
			
		||||
// Struct holding the different output streams produced by the text classifier.
 | 
			
		||||
struct TextClassifierOutputStreams {
  // Stream of ClassificationResult protos emitted under the
  // CLASSIFICATION_RESULT tag.
  Source<ClassificationResult> classification_result;
  // Stream of ClassificationResult protos emitted under the CLASSIFICATIONS
  // tag.
  Source<ClassificationResult> classifications;
};
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
// A "TextClassifierGraph" performs Natural Language classification (including
 | 
			
		||||
| 
						 | 
				
			
			@ -62,7 +70,10 @@ constexpr char kTensorsTag[] = "TENSORS";
 | 
			
		|||
//     Input text to perform classification on.
 | 
			
		||||
//
 | 
			
		||||
// Outputs:
 | 
			
		||||
//   CLASSIFICATION_RESULT - ClassificationResult
 | 
			
		||||
//   CLASSIFICATIONS - ClassificationResult @Optional
 | 
			
		||||
//     The classification results aggregated by classifier head.
 | 
			
		||||
// TODO: remove once Java API migration is over.
 | 
			
		||||
//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
 | 
			
		||||
//     The aggregated classification result object that has 3 dimensions:
 | 
			
		||||
//     (classification head, classification timestamp, classification category).
 | 
			
		||||
//
 | 
			
		||||
| 
						 | 
				
			
			@ -70,7 +81,7 @@ constexpr char kTensorsTag[] = "TENSORS";
 | 
			
		|||
// node {
 | 
			
		||||
//   calculator: "mediapipe.tasks.text.text_classifier.TextClassifierGraph"
 | 
			
		||||
//   input_stream: "TEXT:text_in"
 | 
			
		||||
//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
 | 
			
		||||
//   output_stream: "CLASSIFICATIONS:classifications_out"
 | 
			
		||||
//   options {
 | 
			
		||||
//     [mediapipe.tasks.text.text_classifier.proto.TextClassifierGraphOptions.ext]
 | 
			
		||||
//     {
 | 
			
		||||
| 
						 | 
				
			
			@ -91,12 +102,14 @@ class TextClassifierGraph : public core::ModelTaskGraph {
 | 
			
		|||
        CreateModelResources<proto::TextClassifierGraphOptions>(sc));
 | 
			
		||||
    Graph graph;
 | 
			
		||||
    ASSIGN_OR_RETURN(
 | 
			
		||||
        Source<ClassificationResult> classification_result_out,
 | 
			
		||||
        auto output_streams,
 | 
			
		||||
        BuildTextClassifierTask(
 | 
			
		||||
            sc->Options<proto::TextClassifierGraphOptions>(), *model_resources,
 | 
			
		||||
            graph[Input<std::string>(kTextTag)], graph));
 | 
			
		||||
    classification_result_out >>
 | 
			
		||||
    output_streams.classification_result >>
 | 
			
		||||
        graph[Output<ClassificationResult>(kClassificationResultTag)];
 | 
			
		||||
    output_streams.classifications >>
 | 
			
		||||
        graph[Output<ClassificationResult>(kClassificationsTag)];
 | 
			
		||||
    return graph.GetConfig();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -111,7 +124,7 @@ class TextClassifierGraph : public core::ModelTaskGraph {
 | 
			
		|||
  //   TextClassifier model file with model metadata.
 | 
			
		||||
  // text_in: (std::string) stream to run text classification on.
 | 
			
		||||
  // graph: the mediapipe builder::Graph instance to be updated.
 | 
			
		||||
  absl::StatusOr<Source<ClassificationResult>> BuildTextClassifierTask(
 | 
			
		||||
  absl::StatusOr<TextClassifierOutputStreams> BuildTextClassifierTask(
 | 
			
		||||
      const proto::TextClassifierGraphOptions& task_options,
 | 
			
		||||
      const ModelResources& model_resources, Source<std::string> text_in,
 | 
			
		||||
      Graph& graph) {
 | 
			
		||||
| 
						 | 
				
			
			@ -148,8 +161,11 @@ class TextClassifierGraph : public core::ModelTaskGraph {
 | 
			
		|||
 | 
			
		||||
    // Outputs the aggregated classification result as the subgraph output
 | 
			
		||||
    // stream.
 | 
			
		||||
    return postprocessing[Output<ClassificationResult>(
 | 
			
		||||
        kClassificationResultTag)];
 | 
			
		||||
    return TextClassifierOutputStreams{
 | 
			
		||||
        /*classification_result=*/postprocessing[Output<ClassificationResult>(
 | 
			
		||||
            kClassificationResultTag)],
 | 
			
		||||
        /*classifications=*/postprocessing[Output<ClassificationResult>(
 | 
			
		||||
            kClassificationsTag)]};
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -33,7 +33,8 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/framework/port/gtest.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_matchers.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/common.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/category.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/text/text_classifier/text_classifier_test_utils.h"
 | 
			
		||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -43,14 +44,13 @@ namespace text {
 | 
			
		|||
namespace text_classifier {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::EqualsProto;
 | 
			
		||||
using ::mediapipe::file::JoinPath;
 | 
			
		||||
using ::mediapipe::tasks::kMediaPipeTasksPayload;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::Category;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::Classifications;
 | 
			
		||||
using ::testing::HasSubstr;
 | 
			
		||||
using ::testing::Optional;
 | 
			
		||||
 | 
			
		||||
constexpr float kEpsilon = 0.001;
 | 
			
		||||
constexpr int kMaxSeqLen = 128;
 | 
			
		||||
constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/text/";
 | 
			
		||||
constexpr char kTestBertModelPath[] = "bert_text_classifier.tflite";
 | 
			
		||||
| 
						 | 
				
			
			@ -64,6 +64,30 @@ std::string GetFullPath(absl::string_view file_name) {
 | 
			
		|||
  return JoinPath("./", kTestDataDirectory, file_name);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Checks that the two provided `TextClassifierResult` are equal, with a
 | 
			
		||||
// tolerancy on floating-point score to account for numerical instabilities.
 | 
			
		||||
// TODO: create shared matcher for ClassificationResult.
 | 
			
		||||
void ExpectApproximatelyEqual(const TextClassifierResult& actual,
 | 
			
		||||
                              const TextClassifierResult& expected) {
 | 
			
		||||
  const float kPrecision = 1e-6;
 | 
			
		||||
  ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
 | 
			
		||||
  for (int i = 0; i < actual.classifications.size(); ++i) {
 | 
			
		||||
    const Classifications& a = actual.classifications[i];
 | 
			
		||||
    const Classifications& b = expected.classifications[i];
 | 
			
		||||
    EXPECT_EQ(a.head_index, b.head_index);
 | 
			
		||||
    EXPECT_EQ(a.head_name, b.head_name);
 | 
			
		||||
    EXPECT_EQ(a.categories.size(), b.categories.size());
 | 
			
		||||
    for (int j = 0; j < a.categories.size(); ++j) {
 | 
			
		||||
      const Category& x = a.categories[j];
 | 
			
		||||
      const Category& y = b.categories[j];
 | 
			
		||||
      EXPECT_EQ(x.index, y.index);
 | 
			
		||||
      EXPECT_NEAR(x.score, y.score, kPrecision);
 | 
			
		||||
      EXPECT_EQ(x.category_name, y.category_name);
 | 
			
		||||
      EXPECT_EQ(x.display_name, y.display_name);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class TextClassifierTest : public tflite_shims::testing::Test {};
 | 
			
		||||
 | 
			
		||||
TEST_F(TextClassifierTest, CreateSucceedsWithBertModel) {
 | 
			
		||||
| 
						 | 
				
			
			@ -107,6 +131,92 @@ TEST_F(TextClassifierTest, CreateSucceedsWithRegexModel) {
 | 
			
		|||
  MP_ASSERT_OK(TextClassifier::Create(std::move(options)));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Runs the classifier created from `kTestBertModelPath` on two sentences and
// compares each result against hard-coded expected categories and scores.
TEST_F(TextClassifierTest, TextClassifierWithBert) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestBertModelPath);
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextClassifier> text_classifier,
      TextClassifier::Create(std::move(classifier_options)));

  // First sentence: "negative" is expected to be the top-ranked category.
  TextClassifierResult want_negative;
  want_negative.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.956317, /*category_name=*/"negative"},
          {/*index=*/1, /*score=*/0.043683, /*category_name=*/"positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_negative,
      text_classifier->Classify("unflinchingly bleak and desperate"));
  ExpectApproximatelyEqual(got_negative, want_negative);

  // Second sentence: "positive" is expected to be the top-ranked category.
  TextClassifierResult want_positive;
  want_positive.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.999945, /*category_name=*/"positive"},
          {/*index=*/0, /*score=*/0.000056, /*category_name=*/"negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_positive,
      text_classifier->Classify(
          "it's a charming and often affecting journey"));
  ExpectApproximatelyEqual(got_positive, want_positive);

  MP_ASSERT_OK(text_classifier->Close());
}
 | 
			
		||||
 | 
			
		||||
// Runs the classifier created from `kTestRegexModelPath` on two reviews and
// compares each result against hard-coded expected categories and scores.
TEST_F(TextClassifierTest, TextClassifierWithIntInputs) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kTestRegexModelPath);
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextClassifier> text_classifier,
      TextClassifier::Create(std::move(classifier_options)));

  // First review: "Negative" is expected to be the top-ranked category.
  TextClassifierResult want_negative;
  want_negative.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/0, /*score=*/0.813130, /*category_name=*/"Negative"},
          {/*index=*/1, /*score=*/0.186870, /*category_name=*/"Positive"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_negative,
      text_classifier->Classify("What a waste of my time."));
  ExpectApproximatelyEqual(got_negative, want_negative);

  // Second review: "Positive" is expected to be the top-ranked category.
  TextClassifierResult want_positive;
  want_positive.classifications.push_back(Classifications{
      /*categories=*/{
          {/*index=*/1, /*score=*/0.513427, /*category_name=*/"Positive"},
          {/*index=*/0, /*score=*/0.486573, /*category_name=*/"Negative"}},
      /*head_index=*/0,
      /*head_name=*/"probability"});
  // NOTE(review): the two adjacent literals below are concatenated without a
  // separating space ("...years.Strongly..."). The expected scores above
  // presumably were obtained with exactly this input, so it is kept as-is;
  // confirm before "fixing" the text.
  MP_ASSERT_OK_AND_ASSIGN(
      TextClassifierResult got_positive,
      text_classifier->Classify(
          "This is the best movie I’ve seen in recent years."
          "Strongly recommend it!"));
  ExpectApproximatelyEqual(got_positive, want_positive);

  MP_ASSERT_OK(text_classifier->Close());
}
 | 
			
		||||
 | 
			
		||||
// Runs the classifier created from `kStringToBoolModelPath`, using the
// resolver returned by CreateCustomResolver(), on a single input and checks
// the returned categories field by field.
TEST_F(TextClassifierTest, TextClassifierWithStringToBool) {
  auto classifier_options = std::make_unique<TextClassifierOptions>();
  classifier_options->base_options.model_asset_path =
      GetFullPath(kStringToBoolModelPath);
  classifier_options->base_options.op_resolver = CreateCustomResolver();
  MP_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<TextClassifier> text_classifier,
      TextClassifier::Create(std::move(classifier_options)));
  MP_ASSERT_OK_AND_ASSIGN(TextClassifierResult result,
                          text_classifier->Classify("hello"));

  // Binary outputs cause flaky ordering between the two categories that
  // score 1, so the result is checked manually instead of against a fixed
  // expected value.
  ASSERT_EQ(result.classifications.size(), 1);
  // Safe to bind only after the size check above has passed.
  const Classifications& head = result.classifications[0];
  ASSERT_EQ(head.head_index, 0);
  ASSERT_EQ(head.categories.size(), 3);
  // The first two entries both score 1; each index must be 0 or 1.
  ASSERT_EQ(head.categories[0].score, 1);
  ASSERT_LT(head.categories[0].index, 2);
  ASSERT_EQ(head.categories[1].score, 1);
  ASSERT_LT(head.categories[1].index, 2);
  // The last entry is the category at index 2, with score 0.
  ASSERT_EQ(head.categories[2].score, 0);
  ASSERT_EQ(head.categories[2].index, 2);
  MP_ASSERT_OK(text_classifier->Close());
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace text_classifier
 | 
			
		||||
}  // namespace text
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -50,6 +50,7 @@ cc_library(
 | 
			
		|||
        "//mediapipe/framework/api2:builder",
 | 
			
		||||
        "//mediapipe/framework/formats:image",
 | 
			
		||||
        "//mediapipe/framework/formats:rect_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers:classification_result",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors:classifier_options",
 | 
			
		||||
        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -26,6 +26,7 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/framework/formats/rect.pb.h"
 | 
			
		||||
#include "mediapipe/framework/packet.h"
 | 
			
		||||
#include "mediapipe/framework/timestamp.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -46,8 +47,8 @@ namespace image_classifier {
 | 
			
		|||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
 | 
			
		||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
 | 
			
		||||
constexpr char kClassificationsStreamName[] = "classifications_out";
 | 
			
		||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 | 
			
		||||
constexpr char kImageInStreamName[] = "image_in";
 | 
			
		||||
constexpr char kImageOutStreamName[] = "image_out";
 | 
			
		||||
constexpr char kImageTag[] = "IMAGE";
 | 
			
		||||
| 
						 | 
				
			
			@ -57,6 +58,7 @@ constexpr char kSubgraphTypeName[] =
 | 
			
		|||
    "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
 | 
			
		||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::tasks::components::containers::ConvertToClassificationResult;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 | 
			
		||||
using ::mediapipe::tasks::core::PacketMap;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -73,15 +75,13 @@ CalculatorGraphConfig CreateGraphConfig(
 | 
			
		|||
  auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
 | 
			
		||||
  task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
 | 
			
		||||
      options_proto.get());
 | 
			
		||||
  task_subgraph.Out(kClassificationResultTag)
 | 
			
		||||
          .SetName(kClassificationResultStreamName) >>
 | 
			
		||||
      graph.Out(kClassificationResultTag);
 | 
			
		||||
  task_subgraph.Out(kClassificationsTag).SetName(kClassificationsStreamName) >>
 | 
			
		||||
      graph.Out(kClassificationsTag);
 | 
			
		||||
  task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >>
 | 
			
		||||
      graph.Out(kImageTag);
 | 
			
		||||
  if (enable_flow_limiting) {
 | 
			
		||||
    return tasks::core::AddFlowLimiterCalculator(graph, task_subgraph,
 | 
			
		||||
                                                 {kImageTag, kNormRectTag},
 | 
			
		||||
                                                 kClassificationResultTag);
 | 
			
		||||
    return tasks::core::AddFlowLimiterCalculator(
 | 
			
		||||
        graph, task_subgraph, {kImageTag, kNormRectTag}, kClassificationsTag);
 | 
			
		||||
  }
 | 
			
		||||
  graph.In(kImageTag) >> task_subgraph.In(kImageTag);
 | 
			
		||||
  graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag);
 | 
			
		||||
| 
						 | 
				
			
			@ -125,13 +125,14 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
 | 
			
		|||
          if (status_or_packets.value()[kImageOutStreamName].IsEmpty()) {
 | 
			
		||||
            return;
 | 
			
		||||
          }
 | 
			
		||||
          Packet classification_result_packet =
 | 
			
		||||
              status_or_packets.value()[kClassificationResultStreamName];
 | 
			
		||||
          Packet classifications_packet =
 | 
			
		||||
              status_or_packets.value()[kClassificationsStreamName];
 | 
			
		||||
          Packet image_packet = status_or_packets.value()[kImageOutStreamName];
 | 
			
		||||
          result_callback(
 | 
			
		||||
              classification_result_packet.Get<ClassificationResult>(),
 | 
			
		||||
              ConvertToClassificationResult(
 | 
			
		||||
                  classifications_packet.Get<ClassificationResult>()),
 | 
			
		||||
              image_packet.Get<Image>(),
 | 
			
		||||
              classification_result_packet.Timestamp().Value() /
 | 
			
		||||
              classifications_packet.Timestamp().Value() /
 | 
			
		||||
                  kMicroSecondsPerMilliSecond);
 | 
			
		||||
        };
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -144,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
 | 
			
		|||
      std::move(packets_callback));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
 | 
			
		||||
absl::StatusOr<ImageClassifierResult> ImageClassifier::Classify(
 | 
			
		||||
    Image image,
 | 
			
		||||
    std::optional<core::ImageProcessingOptions> image_processing_options) {
 | 
			
		||||
  if (image.UsesGpu()) {
 | 
			
		||||
| 
						 | 
				
			
			@ -160,11 +161,11 @@ absl::StatusOr<ClassificationResult> ImageClassifier::Classify(
 | 
			
		|||
      ProcessImageData(
 | 
			
		||||
          {{kImageInStreamName, MakePacket<Image>(std::move(image))},
 | 
			
		||||
           {kNormRectName, MakePacket<NormalizedRect>(std::move(norm_rect))}}));
 | 
			
		||||
  return output_packets[kClassificationResultStreamName]
 | 
			
		||||
      .Get<ClassificationResult>();
 | 
			
		||||
  return ConvertToClassificationResult(
 | 
			
		||||
      output_packets[kClassificationsStreamName].Get<ClassificationResult>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
 | 
			
		||||
absl::StatusOr<ImageClassifierResult> ImageClassifier::ClassifyForVideo(
 | 
			
		||||
    Image image, int64 timestamp_ms,
 | 
			
		||||
    std::optional<core::ImageProcessingOptions> image_processing_options) {
 | 
			
		||||
  if (image.UsesGpu()) {
 | 
			
		||||
| 
						 | 
				
			
			@ -184,8 +185,8 @@ absl::StatusOr<ClassificationResult> ImageClassifier::ClassifyForVideo(
 | 
			
		|||
           {kNormRectName,
 | 
			
		||||
            MakePacket<NormalizedRect>(std::move(norm_rect))
 | 
			
		||||
                .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}));
 | 
			
		||||
  return output_packets[kClassificationResultStreamName]
 | 
			
		||||
      .Get<ClassificationResult>();
 | 
			
		||||
  return ConvertToClassificationResult(
 | 
			
		||||
      output_packets[kClassificationsStreamName].Get<ClassificationResult>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
absl::Status ImageClassifier::ClassifyAsync(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,7 +22,7 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
#include "absl/status/statusor.h"
 | 
			
		||||
#include "mediapipe/framework/formats/image.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/core/base_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +34,10 @@ namespace tasks {
 | 
			
		|||
namespace vision {
 | 
			
		||||
namespace image_classifier {
 | 
			
		||||
 | 
			
		||||
// Alias the shared ClassificationResult struct as result type.
 | 
			
		||||
using ImageClassifierResult =
 | 
			
		||||
    ::mediapipe::tasks::components::containers::ClassificationResult;
 | 
			
		||||
 | 
			
		||||
// The options for configuring a Mediapipe image classifier task.
 | 
			
		||||
struct ImageClassifierOptions {
 | 
			
		||||
  // Base options for configuring MediaPipe Tasks, such as specifying the model
 | 
			
		||||
| 
						 | 
				
			
			@ -56,9 +60,8 @@ struct ImageClassifierOptions {
 | 
			
		|||
  // The user-defined result callback for processing live stream data.
 | 
			
		||||
  // The result callback should only be specified when the running mode is set
 | 
			
		||||
  // to RunningMode::LIVE_STREAM.
 | 
			
		||||
  std::function<void(
 | 
			
		||||
      absl::StatusOr<components::containers::proto::ClassificationResult>,
 | 
			
		||||
      const Image&, int64)>
 | 
			
		||||
  std::function<void(absl::StatusOr<ImageClassifierResult>, const Image&,
 | 
			
		||||
                     int64)>
 | 
			
		||||
      result_callback = nullptr;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -122,7 +125,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
 | 
			
		|||
  // The image can be of any size with format RGB or RGBA.
 | 
			
		||||
  // TODO: describe exact preprocessing steps once
 | 
			
		||||
  // YUVToImageCalculator is integrated.
 | 
			
		||||
  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
 | 
			
		||||
  absl::StatusOr<ImageClassifierResult> Classify(
 | 
			
		||||
      mediapipe::Image image,
 | 
			
		||||
      std::optional<core::ImageProcessingOptions> image_processing_options =
 | 
			
		||||
          std::nullopt);
 | 
			
		||||
| 
						 | 
				
			
			@ -144,10 +147,10 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
 | 
			
		|||
  // The image can be of any size with format RGB or RGBA. It's required to
 | 
			
		||||
  // provide the video frame's timestamp (in milliseconds). The input timestamps
 | 
			
		||||
  // must be monotonically increasing.
 | 
			
		||||
  absl::StatusOr<components::containers::proto::ClassificationResult>
 | 
			
		||||
  ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
 | 
			
		||||
                   std::optional<core::ImageProcessingOptions>
 | 
			
		||||
                       image_processing_options = std::nullopt);
 | 
			
		||||
  absl::StatusOr<ImageClassifierResult> ClassifyForVideo(
 | 
			
		||||
      mediapipe::Image image, int64 timestamp_ms,
 | 
			
		||||
      std::optional<core::ImageProcessingOptions> image_processing_options =
 | 
			
		||||
          std::nullopt);
 | 
			
		||||
 | 
			
		||||
  // Sends live image data to image classification, and the results will be
 | 
			
		||||
  // available via the "result_callback" provided in the ImageClassifierOptions.
 | 
			
		||||
| 
						 | 
				
			
			@ -170,7 +173,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
 | 
			
		|||
  // increasing.
 | 
			
		||||
  //
 | 
			
		||||
  // The "result_callback" provides:
 | 
			
		||||
  //   - The classification results as a ClassificationResult object.
 | 
			
		||||
  //   - The classification results as an ImageClassifierResult object.
 | 
			
		||||
  //   - The const reference to the corresponding input image that the image
 | 
			
		||||
  //     classifier runs on. Note that the const reference to the image will no
 | 
			
		||||
  //     longer be valid when the callback returns. To access the image data
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -48,6 +48,7 @@ using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 | 
			
		|||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
 | 
			
		||||
 | 
			
		||||
constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
 | 
			
		||||
constexpr char kClassificationsTag[] = "CLASSIFICATIONS";
 | 
			
		||||
constexpr char kImageTag[] = "IMAGE";
 | 
			
		||||
constexpr char kNormRectTag[] = "NORM_RECT";
 | 
			
		||||
constexpr char kTensorsTag[] = "TENSORS";
 | 
			
		||||
| 
						 | 
				
			
			@ -56,6 +57,7 @@ constexpr char kTensorsTag[] = "TENSORS";
 | 
			
		|||
// subgraph.
 | 
			
		||||
struct ImageClassifierOutputStreams {
 | 
			
		||||
  Source<ClassificationResult> classification_result;
 | 
			
		||||
  Source<ClassificationResult> classifications;
 | 
			
		||||
  Source<Image> image;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -71,17 +73,19 @@ struct ImageClassifierOutputStreams {
 | 
			
		|||
//     Describes region of image to perform classification on.
 | 
			
		||||
//     @Optional: rect covering the whole image is used if not specified.
 | 
			
		||||
// Outputs:
 | 
			
		||||
//   CLASSIFICATION_RESULT - ClassificationResult
 | 
			
		||||
//     The aggregated classification result object has two dimensions:
 | 
			
		||||
//     (classification head, classification category)
 | 
			
		||||
//   CLASSIFICATIONS - ClassificationResult @Optional
 | 
			
		||||
//     The classification results aggregated by classifier head.
 | 
			
		||||
//   IMAGE - Image
 | 
			
		||||
//     The image that object detection runs on.
 | 
			
		||||
// TODO: remove this output once Java API migration is over.
 | 
			
		||||
//   CLASSIFICATION_RESULT - (DEPRECATED) ClassificationResult @Optional
 | 
			
		||||
//     The aggregated classification result.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
// node {
 | 
			
		||||
//   calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
 | 
			
		||||
//   input_stream: "IMAGE:image_in"
 | 
			
		||||
//   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
 | 
			
		||||
//   output_stream: "CLASSIFICATIONS:classifications_out"
 | 
			
		||||
//   output_stream: "IMAGE:image_out"
 | 
			
		||||
//   options {
 | 
			
		||||
//     [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
 | 
			
		||||
| 
						 | 
				
			
			@ -115,6 +119,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
 | 
			
		|||
            graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
 | 
			
		||||
    output_streams.classification_result >>
 | 
			
		||||
        graph[Output<ClassificationResult>(kClassificationResultTag)];
 | 
			
		||||
    output_streams.classifications >>
 | 
			
		||||
        graph[Output<ClassificationResult>(kClassificationsTag)];
 | 
			
		||||
    output_streams.image >> graph[Output<Image>(kImageTag)];
 | 
			
		||||
    return graph.GetConfig();
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			@ -170,6 +176,8 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
 | 
			
		|||
    return ImageClassifierOutputStreams{
 | 
			
		||||
        /*classification_result=*/postprocessing[Output<ClassificationResult>(
 | 
			
		||||
            kClassificationResultTag)],
 | 
			
		||||
        /*classifications=*/
 | 
			
		||||
        postprocessing[Output<ClassificationResult>(kClassificationsTag)],
 | 
			
		||||
        /*image=*/preprocessing[Output<Image>(kImageTag)]};
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -32,8 +32,8 @@ limitations under the License.
 | 
			
		|||
#include "mediapipe/framework/port/parse_text_proto.h"
 | 
			
		||||
#include "mediapipe/framework/port/status_matchers.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/common.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/category.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/classification_result.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/components/containers/rect.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/image_processing_options.h"
 | 
			
		||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
 | 
			
		||||
| 
						 | 
				
			
			@ -50,10 +50,9 @@ namespace image_classifier {
 | 
			
		|||
namespace {
 | 
			
		||||
 | 
			
		||||
using ::mediapipe::file::JoinPath;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::Category;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::Classifications;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::Rect;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 | 
			
		||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
 | 
			
		||||
using ::mediapipe::tasks::vision::core::ImageProcessingOptions;
 | 
			
		||||
using ::testing::HasSubstr;
 | 
			
		||||
using ::testing::Optional;
 | 
			
		||||
| 
						 | 
				
			
			@ -65,83 +64,56 @@ constexpr char kMobileNetQuantizedWithMetadata[] =
 | 
			
		|||
constexpr char kMobileNetQuantizedWithDummyScoreCalibration[] =
 | 
			
		||||
    "mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite";
 | 
			
		||||
 | 
			
		||||
// Checks that the two provided `ClassificationResult` are equal, with a
 | 
			
		||||
// Checks that the two provided `ImageClassifierResult` are equal, with a
 | 
			
		||||
// tolerancy on floating-point score to account for numerical instabilities.
 | 
			
		||||
void ExpectApproximatelyEqual(const ClassificationResult& actual,
 | 
			
		||||
                              const ClassificationResult& expected) {
 | 
			
		||||
void ExpectApproximatelyEqual(const ImageClassifierResult& actual,
 | 
			
		||||
                              const ImageClassifierResult& expected) {
 | 
			
		||||
  const float kPrecision = 1e-6;
 | 
			
		||||
  ASSERT_EQ(actual.classifications_size(), expected.classifications_size());
 | 
			
		||||
  for (int i = 0; i < actual.classifications_size(); ++i) {
 | 
			
		||||
    const Classifications& a = actual.classifications(i);
 | 
			
		||||
    const Classifications& b = expected.classifications(i);
 | 
			
		||||
    EXPECT_EQ(a.head_index(), b.head_index());
 | 
			
		||||
    EXPECT_EQ(a.head_name(), b.head_name());
 | 
			
		||||
    EXPECT_EQ(a.entries_size(), b.entries_size());
 | 
			
		||||
    for (int j = 0; j < a.entries_size(); ++j) {
 | 
			
		||||
      const ClassificationEntry& x = a.entries(j);
 | 
			
		||||
      const ClassificationEntry& y = b.entries(j);
 | 
			
		||||
      EXPECT_EQ(x.timestamp_ms(), y.timestamp_ms());
 | 
			
		||||
      EXPECT_EQ(x.categories_size(), y.categories_size());
 | 
			
		||||
      for (int k = 0; k < x.categories_size(); ++k) {
 | 
			
		||||
        EXPECT_EQ(x.categories(k).index(), y.categories(k).index());
 | 
			
		||||
        EXPECT_EQ(x.categories(k).category_name(),
 | 
			
		||||
                  y.categories(k).category_name());
 | 
			
		||||
        EXPECT_EQ(x.categories(k).display_name(),
 | 
			
		||||
                  y.categories(k).display_name());
 | 
			
		||||
        EXPECT_NEAR(x.categories(k).score(), y.categories(k).score(),
 | 
			
		||||
                    kPrecision);
 | 
			
		||||
      }
 | 
			
		||||
  ASSERT_EQ(actual.classifications.size(), expected.classifications.size());
 | 
			
		||||
  for (int i = 0; i < actual.classifications.size(); ++i) {
 | 
			
		||||
    const Classifications& a = actual.classifications[i];
 | 
			
		||||
    const Classifications& b = expected.classifications[i];
 | 
			
		||||
    EXPECT_EQ(a.head_index, b.head_index);
 | 
			
		||||
    EXPECT_EQ(a.head_name, b.head_name);
 | 
			
		||||
    EXPECT_EQ(a.categories.size(), b.categories.size());
 | 
			
		||||
    for (int j = 0; j < a.categories.size(); ++j) {
 | 
			
		||||
      const Category& x = a.categories[j];
 | 
			
		||||
      const Category& y = b.categories[j];
 | 
			
		||||
      EXPECT_EQ(x.index, y.index);
 | 
			
		||||
      EXPECT_NEAR(x.score, y.score, kPrecision);
 | 
			
		||||
      EXPECT_EQ(x.category_name, y.category_name);
 | 
			
		||||
      EXPECT_EQ(x.display_name, y.display_name);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Generates expected results for "burger.jpg" using kMobileNetFloatWithMetadata
 | 
			
		||||
// with max_results set to 3.
 | 
			
		||||
ClassificationResult GenerateBurgerResults(int64 timestamp) {
 | 
			
		||||
  return ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
      absl::StrFormat(R"pb(classifications {
 | 
			
		||||
                             entries {
 | 
			
		||||
                               categories {
 | 
			
		||||
                                 index: 934
 | 
			
		||||
                                 score: 0.7939592
 | 
			
		||||
                                 category_name: "cheeseburger"
 | 
			
		||||
                               }
 | 
			
		||||
                               categories {
 | 
			
		||||
                                 index: 932
 | 
			
		||||
                                 score: 0.027392805
 | 
			
		||||
                                 category_name: "bagel"
 | 
			
		||||
                               }
 | 
			
		||||
                               categories {
 | 
			
		||||
                                 index: 925
 | 
			
		||||
                                 score: 0.019340655
 | 
			
		||||
                                 category_name: "guacamole"
 | 
			
		||||
                               }
 | 
			
		||||
                               timestamp_ms: %d
 | 
			
		||||
                             }
 | 
			
		||||
                             head_index: 0
 | 
			
		||||
                             head_name: "probability"
 | 
			
		||||
                           })pb",
 | 
			
		||||
                      timestamp));
 | 
			
		||||
ImageClassifierResult GenerateBurgerResults() {
 | 
			
		||||
  ImageClassifierResult result;
 | 
			
		||||
  result.classifications.emplace_back(Classifications{
 | 
			
		||||
      /*categories=*/{
 | 
			
		||||
          {/*index=*/934, /*score=*/0.793959200,
 | 
			
		||||
           /*category_name=*/"cheeseburger"},
 | 
			
		||||
          {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"},
 | 
			
		||||
          {/*index=*/925, /*score=*/0.019340655,
 | 
			
		||||
           /*category_name=*/"guacamole"}},
 | 
			
		||||
      /*head_index=*/0,
 | 
			
		||||
      /*head_name=*/"probability"});
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Generates expected results for "multi_objects.jpg" using
 | 
			
		||||
// kMobileNetFloatWithMetadata with max_results set to 1 and the right bounding
 | 
			
		||||
// box set around the soccer ball.
 | 
			
		||||
ClassificationResult GenerateSoccerBallResults(int64 timestamp) {
 | 
			
		||||
  return ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
      absl::StrFormat(R"pb(classifications {
 | 
			
		||||
                             entries {
 | 
			
		||||
                               categories {
 | 
			
		||||
                                 index: 806
 | 
			
		||||
                                 score: 0.996527493
 | 
			
		||||
                                 category_name: "soccer ball"
 | 
			
		||||
                               }
 | 
			
		||||
                               timestamp_ms: %d
 | 
			
		||||
                             }
 | 
			
		||||
                             head_index: 0
 | 
			
		||||
                             head_name: "probability"
 | 
			
		||||
                           })pb",
 | 
			
		||||
                      timestamp));
 | 
			
		||||
ImageClassifierResult GenerateSoccerBallResults() {
 | 
			
		||||
  ImageClassifierResult result;
 | 
			
		||||
  result.classifications.emplace_back(
 | 
			
		||||
      Classifications{/*categories=*/{{/*index=*/806, /*score=*/0.996527493,
 | 
			
		||||
                                       /*category_name=*/"soccer ball"}},
 | 
			
		||||
                      /*head_index=*/0,
 | 
			
		||||
                      /*head_name=*/"probability"});
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A custom OpResolver only containing the Ops required by the test model.
 | 
			
		||||
| 
						 | 
				
			
			@ -260,7 +232,7 @@ TEST_F(CreateTest, FailsWithIllegalCallbackInImageOrVideoMode) {
 | 
			
		|||
    options->base_options.model_asset_path =
 | 
			
		||||
        JoinPath("./", kTestDataDirectory, kMobileNetQuantizedWithMetadata);
 | 
			
		||||
    options->running_mode = running_mode;
 | 
			
		||||
    options->result_callback = [](absl::StatusOr<ClassificationResult>,
 | 
			
		||||
    options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
 | 
			
		||||
                                  const Image& image, int64 timestamp_ms) {};
 | 
			
		||||
 | 
			
		||||
    auto image_classifier = ImageClassifier::Create(std::move(options));
 | 
			
		||||
| 
						 | 
				
			
			@ -336,7 +308,7 @@ TEST_F(ImageModeTest, SucceedsWithFloatModel) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, GenerateBurgerResults(0));
 | 
			
		||||
  ExpectApproximatelyEqual(results, GenerateBurgerResults());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
 | 
			
		||||
| 
						 | 
				
			
			@ -355,19 +327,13 @@ TEST_F(ImageModeTest, SucceedsWithQuantizedModel) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.97265625
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(
 | 
			
		||||
      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.97265625,
 | 
			
		||||
                                       /*category_name=*/"cheeseburger"}},
 | 
			
		||||
                      /*head_index=*/0,
 | 
			
		||||
                      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
 | 
			
		||||
| 
						 | 
				
			
			@ -383,19 +349,13 @@ TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.7939592
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(
 | 
			
		||||
      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.7939592,
 | 
			
		||||
                                       /*category_name=*/"cheeseburger"}},
 | 
			
		||||
                      /*head_index=*/0,
 | 
			
		||||
                      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
 | 
			
		||||
| 
						 | 
				
			
			@ -411,24 +371,15 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.7939592
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 932
 | 
			
		||||
                                                   score: 0.027392805
 | 
			
		||||
                                                   category_name: "bagel"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(Classifications{
 | 
			
		||||
      /*categories=*/{
 | 
			
		||||
          {/*index=*/934, /*score=*/0.7939592,
 | 
			
		||||
           /*category_name=*/"cheeseburger"},
 | 
			
		||||
          {/*index=*/932, /*score=*/0.027392805, /*category_name=*/"bagel"}},
 | 
			
		||||
      /*head_index=*/0,
 | 
			
		||||
      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
 | 
			
		||||
| 
						 | 
				
			
			@ -445,29 +396,17 @@ TEST_F(ImageModeTest, SucceedsWithAllowlistOption) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.7939592
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 925
 | 
			
		||||
                                                   score: 0.019340655
 | 
			
		||||
                                                   category_name: "guacamole"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 963
 | 
			
		||||
                                                   score: 0.0063278517
 | 
			
		||||
                                                   category_name: "meat loaf"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(Classifications{
 | 
			
		||||
      /*categories=*/{
 | 
			
		||||
          {/*index=*/934, /*score=*/0.7939592,
 | 
			
		||||
           /*category_name=*/"cheeseburger"},
 | 
			
		||||
          {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
 | 
			
		||||
          {/*index=*/963, /*score=*/0.0063278517,
 | 
			
		||||
           /*category_name=*/"meat loaf"}},
 | 
			
		||||
      /*head_index=*/0,
 | 
			
		||||
      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
 | 
			
		||||
| 
						 | 
				
			
			@ -484,29 +423,17 @@ TEST_F(ImageModeTest, SucceedsWithDenylistOption) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.7939592
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 925
 | 
			
		||||
                                                   score: 0.019340655
 | 
			
		||||
                                                   category_name: "guacamole"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 963
 | 
			
		||||
                                                   score: 0.0063278517
 | 
			
		||||
                                                   category_name: "meat loaf"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(Classifications{
 | 
			
		||||
      /*categories=*/{
 | 
			
		||||
          {/*index=*/934, /*score=*/0.7939592,
 | 
			
		||||
           /*category_name=*/"cheeseburger"},
 | 
			
		||||
          {/*index=*/925, /*score=*/0.019340655, /*category_name=*/"guacamole"},
 | 
			
		||||
          {/*index=*/963, /*score=*/0.0063278517,
 | 
			
		||||
           /*category_name=*/"meat loaf"}},
 | 
			
		||||
      /*head_index=*/0,
 | 
			
		||||
      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
 | 
			
		||||
| 
						 | 
				
			
			@ -525,19 +452,13 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) {
 | 
			
		|||
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(image));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.725648628
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(
 | 
			
		||||
      Classifications{/*categories=*/{{/*index=*/934, /*score=*/0.725648628,
 | 
			
		||||
                                       /*category_name=*/"cheeseburger"}},
 | 
			
		||||
                      /*head_index=*/0,
 | 
			
		||||
                      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
 | 
			
		||||
| 
						 | 
				
			
			@ -557,7 +478,7 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterest) {
 | 
			
		|||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
 | 
			
		||||
                                            image, image_processing_options));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results, GenerateSoccerBallResults(0));
 | 
			
		||||
  ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithRotation) {
 | 
			
		||||
| 
						 | 
				
			
			@ -581,29 +502,17 @@ TEST_F(ImageModeTest, SucceedsWithRotation) {
 | 
			
		|||
  // Results differ slightly from the non-rotated image, but that's expected
 | 
			
		||||
  // as models are very sensitive to the slightest numerical differences
 | 
			
		||||
  // introduced by the rotation and JPG encoding.
 | 
			
		||||
  ExpectApproximatelyEqual(results, ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                                        R"pb(classifications {
 | 
			
		||||
                                               entries {
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 934
 | 
			
		||||
                                                   score: 0.6371766
 | 
			
		||||
                                                   category_name: "cheeseburger"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 963
 | 
			
		||||
                                                   score: 0.049443405
 | 
			
		||||
                                                   category_name: "meat loaf"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 categories {
 | 
			
		||||
                                                   index: 925
 | 
			
		||||
                                                   score: 0.047918003
 | 
			
		||||
                                                   category_name: "guacamole"
 | 
			
		||||
                                                 }
 | 
			
		||||
                                                 timestamp_ms: 0
 | 
			
		||||
                                               }
 | 
			
		||||
                                               head_index: 0
 | 
			
		||||
                                               head_name: "probability"
 | 
			
		||||
                                             })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(Classifications{
 | 
			
		||||
      /*categories=*/{
 | 
			
		||||
          {/*index=*/934, /*score=*/0.6371766,
 | 
			
		||||
           /*category_name=*/"cheeseburger"},
 | 
			
		||||
          {/*index=*/963, /*score=*/0.049443405, /*category_name=*/"meat loaf"},
 | 
			
		||||
          {/*index=*/925, /*score=*/0.047918003,
 | 
			
		||||
           /*category_name=*/"guacamole"}},
 | 
			
		||||
      /*head_index=*/0,
 | 
			
		||||
      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
 | 
			
		||||
| 
						 | 
				
			
			@ -624,20 +533,13 @@ TEST_F(ImageModeTest, SucceedsWithRegionOfInterestAndRotation) {
 | 
			
		|||
  MP_ASSERT_OK_AND_ASSIGN(auto results, image_classifier->Classify(
 | 
			
		||||
                                            image, image_processing_options));
 | 
			
		||||
 | 
			
		||||
  ExpectApproximatelyEqual(results,
 | 
			
		||||
                           ParseTextProtoOrDie<ClassificationResult>(
 | 
			
		||||
                               R"pb(classifications {
 | 
			
		||||
                                      entries {
 | 
			
		||||
                                        categories {
 | 
			
		||||
                                          index: 560
 | 
			
		||||
                                          score: 0.6522213
 | 
			
		||||
                                          category_name: "folding chair"
 | 
			
		||||
                                        }
 | 
			
		||||
                                        timestamp_ms: 0
 | 
			
		||||
                                      }
 | 
			
		||||
                                      head_index: 0
 | 
			
		||||
                                      head_name: "probability"
 | 
			
		||||
                                    })pb"));
 | 
			
		||||
  ImageClassifierResult expected;
 | 
			
		||||
  expected.classifications.emplace_back(
 | 
			
		||||
      Classifications{/*categories=*/{{/*index=*/560, /*score=*/0.6522213,
 | 
			
		||||
                                       /*category_name=*/"folding chair"}},
 | 
			
		||||
                      /*head_index=*/0,
 | 
			
		||||
                      /*head_name=*/"probability"});
 | 
			
		||||
  ExpectApproximatelyEqual(results, expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Testing all these once with ImageClassifier.
 | 
			
		||||
| 
						 | 
				
			
			@ -774,7 +676,7 @@ TEST_F(VideoModeTest, Succeeds) {
 | 
			
		|||
  for (int i = 0; i < iterations; ++i) {
 | 
			
		||||
    MP_ASSERT_OK_AND_ASSIGN(auto results,
 | 
			
		||||
                            image_classifier->ClassifyForVideo(image, i));
 | 
			
		||||
    ExpectApproximatelyEqual(results, GenerateBurgerResults(i));
 | 
			
		||||
    ExpectApproximatelyEqual(results, GenerateBurgerResults());
 | 
			
		||||
  }
 | 
			
		||||
  MP_ASSERT_OK(image_classifier->Close());
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -800,7 +702,7 @@ TEST_F(VideoModeTest, SucceedsWithRegionOfInterest) {
 | 
			
		|||
    MP_ASSERT_OK_AND_ASSIGN(
 | 
			
		||||
        auto results,
 | 
			
		||||
        image_classifier->ClassifyForVideo(image, i, image_processing_options));
 | 
			
		||||
    ExpectApproximatelyEqual(results, GenerateSoccerBallResults(i));
 | 
			
		||||
    ExpectApproximatelyEqual(results, GenerateSoccerBallResults());
 | 
			
		||||
  }
 | 
			
		||||
  MP_ASSERT_OK(image_classifier->Close());
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			@ -815,7 +717,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
 | 
			
		||||
  options->running_mode = core::RunningMode::LIVE_STREAM;
 | 
			
		||||
  options->result_callback = [](absl::StatusOr<ClassificationResult>,
 | 
			
		||||
  options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
 | 
			
		||||
                                const Image& image, int64 timestamp_ms) {};
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
 | 
			
		||||
                          ImageClassifier::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -846,7 +748,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
 | 
			
		|||
  options->base_options.model_asset_path =
 | 
			
		||||
      JoinPath("./", kTestDataDirectory, kMobileNetFloatWithMetadata);
 | 
			
		||||
  options->running_mode = core::RunningMode::LIVE_STREAM;
 | 
			
		||||
  options->result_callback = [](absl::StatusOr<ClassificationResult>,
 | 
			
		||||
  options->result_callback = [](absl::StatusOr<ImageClassifierResult>,
 | 
			
		||||
                                const Image& image, int64 timestamp_ms) {};
 | 
			
		||||
  MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageClassifier> image_classifier,
 | 
			
		||||
                          ImageClassifier::Create(std::move(options)));
 | 
			
		||||
| 
						 | 
				
			
			@ -864,7 +766,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) {
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
struct LiveStreamModeResults {
 | 
			
		||||
  ClassificationResult classification_result;
 | 
			
		||||
  ImageClassifierResult classification_result;
 | 
			
		||||
  std::pair<int, int> image_size;
 | 
			
		||||
  int64 timestamp_ms;
 | 
			
		||||
};
 | 
			
		||||
| 
						 | 
				
			
			@ -881,7 +783,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
 | 
			
		|||
  options->running_mode = core::RunningMode::LIVE_STREAM;
 | 
			
		||||
  options->classifier_options.max_results = 3;
 | 
			
		||||
  options->result_callback =
 | 
			
		||||
      [&results](absl::StatusOr<ClassificationResult> classification_result,
 | 
			
		||||
      [&results](absl::StatusOr<ImageClassifierResult> classification_result,
 | 
			
		||||
                 const Image& image, int64 timestamp_ms) {
 | 
			
		||||
        MP_ASSERT_OK(classification_result.status());
 | 
			
		||||
        results.push_back(
 | 
			
		||||
| 
						 | 
				
			
			@ -908,7 +810,7 @@ TEST_F(LiveStreamModeTest, Succeeds) {
 | 
			
		|||
    EXPECT_EQ(result.image_size.first, image.width());
 | 
			
		||||
    EXPECT_EQ(result.image_size.second, image.height());
 | 
			
		||||
    ExpectApproximatelyEqual(result.classification_result,
 | 
			
		||||
                             GenerateBurgerResults(timestamp_ms));
 | 
			
		||||
                             GenerateBurgerResults());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -924,7 +826,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
 | 
			
		|||
  options->running_mode = core::RunningMode::LIVE_STREAM;
 | 
			
		||||
  options->classifier_options.max_results = 1;
 | 
			
		||||
  options->result_callback =
 | 
			
		||||
      [&results](absl::StatusOr<ClassificationResult> classification_result,
 | 
			
		||||
      [&results](absl::StatusOr<ImageClassifierResult> classification_result,
 | 
			
		||||
                 const Image& image, int64 timestamp_ms) {
 | 
			
		||||
        MP_ASSERT_OK(classification_result.status());
 | 
			
		||||
        results.push_back(
 | 
			
		||||
| 
						 | 
				
			
			@ -955,7 +857,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
 | 
			
		|||
    EXPECT_EQ(result.image_size.first, image.width());
 | 
			
		||||
    EXPECT_EQ(result.image_size.second, image.height());
 | 
			
		||||
    ExpectApproximatelyEqual(result.classification_result,
 | 
			
		||||
                             GenerateSoccerBallResults(timestamp_ms));
 | 
			
		||||
                             GenerateSoccerBallResults());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user