Refactor ClassificationResult and ClassificationPostprocessing.

PiperOrigin-RevId: 478444264
This commit is contained in:
MediaPipe Team 2022-10-03 01:58:41 -07:00 committed by Copybara-Service
parent 1e5cccdc73
commit 03c8ac3641
37 changed files with 329 additions and 274 deletions

View File

@ -35,9 +35,10 @@ cc_library(
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
@ -64,8 +65,9 @@ cc_library(
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode",
"//mediapipe/tasks/cc/components:classifier_options",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",

View File

@ -24,8 +24,9 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -37,6 +38,8 @@ namespace audio_classifier {
namespace {
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultStreamName[] = "classification_result_out";
@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode == core::RunningMode::AUDIO_STREAM);
auto classifier_options_proto =
std::make_unique<tasks::components::proto::ClassifierOptions>(
components::ConvertClassifierOptionsToProto(
std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto(
&(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get());

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe {
@ -40,7 +40,7 @@ struct AudioClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
components::ClassifierOptions classifier_options;
components::processors::ClassifierOptions classifier_options;
// The running mode of the audio classifier. Default to the audio clips mode.
// Audio classifier has two running modes:
@ -59,8 +59,9 @@ struct AudioClassifierOptions {
// The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM.
std::function<void(absl::StatusOr<ClassificationResult>)> result_callback =
nullptr;
std::function<void(
absl::StatusOr<components::containers::proto::ClassificationResult>)>
result_callback = nullptr;
};
// Performs audio classification on audio clips or audio stream.
@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// framed audio clip.
// TODO: Use `sample_rate` in AudioClassifierOptions by default
// and makes `audio_sample_rate` optional.
absl::StatusOr<ClassificationResult> Classify(mediapipe::Matrix audio_clip,
double audio_sample_rate);
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Matrix audio_clip, double audio_sample_rate);
// Sends audio data (a block in a continuous audio stream) to perform audio
// classification. Only use this method when the AudioClassifier is created

View File

@ -31,9 +31,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -53,6 +53,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
constexpr char kAudioTag[] = "AUDIO";
@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, task_options.classifier_options(),
&postprocessing.GetOptions<
tasks::components::ClassificationPostprocessingOptions>()));
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio classification on

View File

@ -37,8 +37,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe {
@ -49,6 +49,7 @@ namespace {
using ::absl::StatusOr;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr;
using ::testing::Optional;

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message AudioClassifierGraphOptions {
@ -31,7 +31,7 @@ message AudioClassifierGraphOptions {
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2;
optional components.processors.proto.ClassifierOptions classifier_options = 2;
// The default sample rate of the input audio. Must be set when the
// AudioClassifier is configured to process audio stream data.

View File

@ -58,65 +58,6 @@ cc_library(
# TODO: Enable this test
cc_library(
name = "classifier_options",
srcs = ["classifier_options.cc"],
hdrs = ["classifier_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
)
mediapipe_proto_library(
name = "classification_postprocessing_options_proto",
srcs = ["classification_postprocessing_options.proto"],
deps = [
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
],
)
cc_library(
name = "classification_postprocessing",
srcs = ["classification_postprocessing.cc"],
hdrs = ["classification_postprocessing.h"],
deps = [
":classification_postprocessing_options_cc_proto",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
cc_library(
name = "embedder_options",
srcs = ["embedder_options.cc"],

View File

@ -37,8 +37,8 @@ cc_library(
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers:category_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status",
],
alwayslink = 1,
@ -128,7 +128,7 @@ cc_library(
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
],
alwayslink = 1,
)

View File

@ -25,15 +25,15 @@
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe {
namespace api2 {
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
using ::mediapipe::tasks::ClassificationResult;
using ::mediapipe::tasks::Classifications;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into a single ClassificationResult that has
// 3 dimensions: (classification head, classification timestamp, classification

View File

@ -17,12 +17,13 @@ limitations under the License.
#include <vector>
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
// Specialized EndLoopCalculator for Tasks specific types.
namespace mediapipe::tasks {
typedef EndLoopCalculator<std::vector<ClassificationResult>>
typedef EndLoopCalculator<
std::vector<components::containers::proto::ClassificationResult>>
EndLoopClassificationResultCalculator;
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);

View File

@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "category_proto",
srcs = ["category.proto"],
)
mediapipe_proto_library(
name = "classifications_proto",
srcs = ["classifications.proto"],
deps = [
":category_proto",
],
)
mediapipe_proto_library(
name = "embeddings_proto",
srcs = ["embeddings.proto"],
)
mediapipe_proto_library(
name = "landmarks_detection_result_proto",
srcs = [
@ -29,8 +47,3 @@ mediapipe_proto_library(
"//mediapipe/framework/formats:rect_proto",
],
)
mediapipe_proto_library(
name = "embeddings_proto",
srcs = ["embeddings.proto"],
)

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks;
package mediapipe.tasks.components.containers.proto;
// A single classification result.
message Category {

View File

@ -15,9 +15,9 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks;
package mediapipe.tasks.components.containers.proto;
import "mediapipe/tasks/cc/components/containers/category.proto";
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
// List of predicted categories with an optional timestamp.
message ClassificationEntry {

View File

@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "classifier_options",
srcs = ["classifier_options.cc"],
hdrs = ["classifier_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"],
)
cc_library(
name = "classification_postprocessing_graph",
srcs = ["classification_postprocessing_graph.cc"],
hdrs = ["classification_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include <stdint.h>
@ -37,9 +37,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -51,6 +51,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::ClassifierOptions;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit;
@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Performs sanity checks on provided ClassifierOptions.
absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) {
absl::Status SanityCheckClassifierOptions(
const proto::ClassifierOptions& options) {
if (options.max_results() == 0) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
@ -203,7 +205,7 @@ absl::StatusOr<float> GetScoreThreshold(
// Gets the category allowlist or denylist (if any) as a set of indices.
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
const ClassifierOptions& options, const LabelItems& label_items) {
const proto::ClassifierOptions& options, const LabelItems& label_items) {
absl::flat_hash_set<int> category_indices;
// Exit early if no denylist/allowlist.
if (options.category_denylist_size() == 0 &&
@ -239,7 +241,7 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
absl::Status ConfigureScoreCalibrationIfAny(
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
ClassificationPostprocessingOptions* options) {
proto::ClassificationPostprocessingGraphOptions* options) {
const auto* tensor_metadata =
metadata_extractor.GetOutputTensorMetadata(tensor_index);
if (tensor_metadata == nullptr) {
@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
// Fills in the TensorsToClassificationCalculatorOptions based on the
// classifier options and the (optional) output tensor metadata.
absl::Status ConfigureTensorsToClassificationCalculator(
const ClassifierOptions& options,
const proto::ClassifierOptions& options,
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
TensorsToClassificationCalculatorOptions* calculator_options) {
const auto* tensor_metadata =
@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator(
} // namespace
absl::Status ConfigureClassificationPostprocessing(
absl::Status ConfigureClassificationPostprocessingGraph(
const ModelResources& model_resources,
const ClassifierOptions& classifier_options,
ClassificationPostprocessingOptions* options) {
const proto::ClassifierOptions& classifier_options,
proto::ClassificationPostprocessingGraphOptions* options) {
MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
ASSIGN_OR_RETURN(const auto heads_properties,
GetClassificationHeadsProperties(model_resources));
@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing(
return absl::OkStatus();
}
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
// raw tensors into ClassificationResult objects.
// A "ClassificationPostprocessingGraph" converts raw tensors into
// ClassificationResult objects.
// - Accepts CPU input tensors.
//
// Inputs:
@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing(
// CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results.
//
// The recommended way of using this subgraph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessing()' function. See header file
// for more details.
class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
// The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
// file for more details.
class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override {
@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
ASSIGN_OR_RETURN(
auto classification_result_out,
BuildClassificationPostprocessing(
sc->Options<ClassificationPostprocessingOptions>(),
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
classification_result_out >>
@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
}
private:
// Adds an on-device classification postprocessing subgraph into the provided
// builder::Graph instance. The classification postprocessing subgraph takes
// Adds an on-device classification postprocessing graph into the provided
// builder::Graph instance. The classification postprocessing graph takes
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
// stream containing the output classification results (ClassificationResult).
//
// options: the on-device ClassificationPostprocessingOptions.
// options: the on-device ClassificationPostprocessingGraphOptions.
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that a single ClassificationResult should aggregate.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>>
BuildClassificationPostprocessing(
const ClassificationPostprocessingOptions& options,
const proto::ClassificationPostprocessingGraphOptions& options,
Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
const int num_heads = options.tensors_to_classifications_options_size();
@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
kClassificationResultTag)];
}
};
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors::
ClassificationPostprocessingGraph); // NOLINT
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,32 +13,33 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Configures a ClassificationPostprocessing subgraph using the provided model
// Configures a ClassificationPostprocessingGraph using the provided model
// resources and ClassifierOptions.
// - Accepts CPU input tensors.
//
// Example usage:
//
// auto& postprocessing =
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
// model_resources,
// classifier_options,
// &preprocessing.GetOptions<ClassificationPostprocessingOptions>()));
// &preprocessing.GetOptions<ClassificationPostprocessingGraphOptions>()));
//
// The resulting ClassificationPostprocessing subgraph has the following I/O:
// The resulting ClassificationPostprocessingGraph has the following I/O:
// Inputs:
// TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator.
@ -49,13 +50,14 @@ namespace components {
// Outputs:
// CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results.
absl::Status ConfigureClassificationPostprocessing(
absl::Status ConfigureClassificationPostprocessingGraph(
const tasks::core::ModelResources& model_resources,
const tasks::components::proto::ClassifierOptions& classifier_options,
ClassificationPostprocessingOptions* options);
const proto::ClassifierOptions& classifier_options,
proto::ClassificationPostprocessingGraphOptions* options);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include <map>
#include <memory>
@ -42,9 +42,9 @@ limitations under the License.
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/util/label_map.pb.h"
@ -53,6 +53,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
namespace {
using ::mediapipe::api2::Input;
@ -60,7 +61,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::ClassifierOptions;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources;
using ::testing::HasSubstr;
using ::testing::proto::Approximately;
@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.set_max_results(0);
ClassificationPostprocessingOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out);
proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.add_category_allowlist("foo");
options_in.add_category_denylist("bar");
ClassificationPostprocessingOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out);
proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.add_category_allowlist("foo");
ClassificationPostprocessingOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out);
proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(
@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.set_max_results(3);
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.set_score_threshold(0.5);
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: []
@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
// Check label map size and two first elements.
EXPECT_EQ(
@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.add_category_allowlist("tench");
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
// Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0)
@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
options_in.add_category_denylist("background");
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
// Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0)
@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
auto model_resources,
CreateModelResourcesForModel(
kQuantizedImageClassifierWithDummyScoreCalibration));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
// Check label map size and two first elements.
EXPECT_EQ(
@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
MP_ASSERT_OK_AND_ASSIGN(
auto model_resources,
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
ClassifierOptions options_in;
proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
options_in, &options_out));
proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
*model_resources, options_in, &options_out));
// Check label maps sizes and first two elements.
EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(),
@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
class PostprocessingTest : public tflite_shims::testing::Test {
protected:
absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const ClassifierOptions& options,
absl::string_view model_name, const proto::ClassifierOptions& options,
bool connect_timestamps = false) {
ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name));
Graph graph;
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
*model_resources, options,
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
&postprocessing
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag);
if (connect_timestamps) {
@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
// Build graph.
ClassifierOptions options;
proto::ClassifierOptions options;
options.set_max_results(3);
options.set_score_threshold(0.5);
MP_ASSERT_OK_AND_ASSIGN(
@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
// Build graph.
ClassifierOptions options;
proto::ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph.
ClassifierOptions options;
proto::ClassifierOptions options;
options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph.
ClassifierOptions options;
proto::ClassifierOptions options;
options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN(
auto poller,
@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
// Build graph.
ClassifierOptions options;
proto::ClassifierOptions options;
options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
}
} // namespace
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
proto::ClassifierOptions ConvertClassifierOptionsToProto(
ClassifierOptions* options) {
tasks::components::proto::ClassifierOptions options_proto;
proto::ClassifierOptions options_proto;
options_proto.set_display_names_locale(options->display_names_locale);
options_proto.set_max_results(options->max_results);
options_proto.set_score_threshold(options->score_threshold);
@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
return options_proto;
}
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe

View File

@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace processors {
// Classifier options for MediaPipe C++ classification Tasks.
struct ClassifierOptions {
@ -49,11 +50,12 @@ struct ClassifierOptions {
};
// Converts a ClassifierOptions to a ClassifierOptionsProto.
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
proto::ClassifierOptions ConvertClassifierOptionsToProto(
ClassifierOptions* classifier_options);
} // namespace processors
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_

View File

@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "category_proto",
srcs = ["category.proto"],
name = "classifier_options_proto",
srcs = ["classifier_options.proto"],
)
mediapipe_proto_library(
name = "classifications_proto",
srcs = ["classifications.proto"],
name = "classification_postprocessing_graph_options_proto",
srcs = ["classification_postprocessing_graph_options.proto"],
deps = [
":category_proto",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
],
)

View File

@ -15,16 +15,16 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.components;
package mediapipe.tasks.components.processors.proto;
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
message ClassificationPostprocessingOptions {
message ClassificationPostprocessingGraphOptions {
extend mediapipe.CalculatorOptions {
optional ClassificationPostprocessingOptions ext = 460416950;
optional ClassificationPostprocessingGraphOptions ext = 460416950;
}
// Optional mapping between output tensor index and corresponding score

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2";
package mediapipe.tasks.components.proto;
package mediapipe.tasks.components.processors.proto;
// Shared options used by all classification tasks.
message ClassifierOptions {

View File

@ -23,11 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"],
)
mediapipe_proto_library(
name = "classifier_options_proto",
srcs = ["classifier_options.proto"],
)
mediapipe_proto_library(
name = "embedder_options_proto",
srcs = ["embedder_options.proto"],

View File

@ -54,10 +54,10 @@ cc_library(
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils",

View File

@ -27,9 +27,9 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -49,6 +49,7 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
HandGestureRecognizerSubgraphOptions;
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
auto inference_output_tensors = inference.Out(kTensorsTag);
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, graph_options.classifier_options(),
&postprocessing.GetOptions<
tasks::components::ClassificationPostprocessingOptions>()));
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference_output_tensors >> postprocessing.In(kTensorsTag);
auto classification_result =
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];

View File

@ -26,7 +26,7 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)
@ -37,7 +37,5 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message HandGestureRecognizerSubgraphOptions {
@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions {
// Options for configuring the gesture classifier behavior, such as score
// threshold, number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2;
optional components.processors.proto.ClassifierOptions classifier_options = 2;
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
// considered tracked successfully

View File

@ -26,11 +26,11 @@ cc_library(
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
@ -50,9 +50,9 @@ cc_library(
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:classifier_options",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:utils",

View File

@ -26,9 +26,9 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -56,6 +56,7 @@ constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap;
// Builds a NormalizedRect covering the entire image.
@ -107,8 +108,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE);
auto classifier_options_proto =
std::make_unique<tasks::components::proto::ClassifierOptions>(
components::ConvertClassifierOptionsToProto(
std::make_unique<components::processors::proto::ClassifierOptions>(
components::processors::ConvertClassifierOptionsToProto(
&(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get());

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -51,12 +51,14 @@ struct ImageClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
components::ClassifierOptions classifier_options;
components::processors::ClassifierOptions classifier_options;
// The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)>
std::function<void(
absl::StatusOr<components::containers::proto::ClassificationResult>,
const Image&, int64)>
result_callback = nullptr;
};
@ -113,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA.
// TODO: describe exact preprocessing steps once
// YUVToImageCalculator is integrated.
absl::StatusOr<ClassificationResult> Classify(
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
@ -127,8 +129,8 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing.
absl::StatusOr<ClassificationResult> ClassifyForVideo(
mediapipe::Image image, int64 timestamp_ms,
absl::StatusOr<components::containers::proto::ClassificationResult>
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
// Sends live image data to image classification, and the results will be

View File

@ -22,11 +22,11 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
@ -43,6 +43,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
"mediapipe.tasks.components.processors."
"ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, task_options.classifier_options(),
&postprocessing.GetOptions<
tasks::components::ClassificationPostprocessingOptions>()));
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the aggregated classification result as the subgraph output

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@ -48,6 +48,9 @@ namespace image_classifier {
namespace {
using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::testing::HasSubstr;
using ::testing::Optional;

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
],
)

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_classifier.proto;
import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageClassifierGraphOptions {
@ -31,5 +31,5 @@ message ImageClassifierGraphOptions {
// Options for configuring the classifier behavior, such as score threshold,
// number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2;
optional components.processors.proto.ClassifierOptions classifier_options = 2;
}

View File

@ -31,7 +31,7 @@ py_library(
name = "category",
srcs = ["category.py"],
deps = [
"//mediapipe/tasks/cc/components/containers:category_py_pb2",
"//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies",
],
)

View File

@ -16,7 +16,7 @@
import dataclasses
from typing import Any
from mediapipe.tasks.cc.components.containers import category_pb2
from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_CategoryProto = category_pb2.Category