Refactor ClassificationResult and ClassificationPostprocessing.

PiperOrigin-RevId: 478444264
This commit is contained in:
MediaPipe Team 2022-10-03 01:58:41 -07:00 committed by Copybara-Service
parent 1e5cccdc73
commit 03c8ac3641
37 changed files with 329 additions and 274 deletions

View File

@ -35,9 +35,10 @@ cc_library(
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs", "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
"//mediapipe/tasks/cc/components:classification_postprocessing", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto", "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
@ -64,8 +65,9 @@ cc_library(
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory", "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api", "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode", "//mediapipe/tasks/cc/audio/core:running_mode",
"//mediapipe/tasks/cc/components:classifier_options", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto", "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",

View File

@ -24,8 +24,9 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h" #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -37,6 +38,8 @@ namespace audio_classifier {
namespace { namespace {
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAudioStreamName[] = "audio_in"; constexpr char kAudioStreamName[] = "audio_in";
constexpr char kAudioTag[] = "AUDIO"; constexpr char kAudioTag[] = "AUDIO";
constexpr char kClassificationResultStreamName[] = "classification_result_out"; constexpr char kClassificationResultStreamName[] = "classification_result_out";
@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode == core::RunningMode::AUDIO_STREAM); options->running_mode == core::RunningMode::AUDIO_STREAM);
auto classifier_options_proto = auto classifier_options_proto =
std::make_unique<tasks::components::proto::ClassifierOptions>( std::make_unique<components::processors::proto::ClassifierOptions>(
components::ConvertClassifierOptionsToProto( components::processors::ConvertClassifierOptionsToProto(
&(options->classifier_options))); &(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap( options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get()); classifier_options_proto.get());

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h" #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
namespace mediapipe { namespace mediapipe {
@ -40,7 +40,7 @@ struct AudioClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
components::ClassifierOptions classifier_options; components::processors::ClassifierOptions classifier_options;
// The running mode of the audio classifier. Default to the audio clips mode. // The running mode of the audio classifier. Default to the audio clips mode.
// Audio classifier has two running modes: // Audio classifier has two running modes:
@ -59,8 +59,9 @@ struct AudioClassifierOptions {
// The user-defined result callback for processing audio stream data. // The user-defined result callback for processing audio stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::AUDIO_STREAM. // to RunningMode::AUDIO_STREAM.
std::function<void(absl::StatusOr<ClassificationResult>)> result_callback = std::function<void(
nullptr; absl::StatusOr<components::containers::proto::ClassificationResult>)>
result_callback = nullptr;
}; };
// Performs audio classification on audio clips or audio stream. // Performs audio classification on audio clips or audio stream.
@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
// framed audio clip. // framed audio clip.
// TODO: Use `sample_rate` in AudioClassifierOptions by default // TODO: Use `sample_rate` in AudioClassifierOptions by default
// and makes `audio_sample_rate` optional. // and makes `audio_sample_rate` optional.
absl::StatusOr<ClassificationResult> Classify(mediapipe::Matrix audio_clip, absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
double audio_sample_rate); mediapipe::Matrix audio_clip, double audio_sample_rate);
// Sends audio data (a block in a continuous audio stream) to perform audio // Sends audio data (a block in a continuous audio stream) to perform audio
// classification. Only use this method when the AudioClassifier is created // classification. Only use this method when the AudioClassifier is created

View File

@ -31,9 +31,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h" #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -53,6 +53,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM"; constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
constexpr char kAudioTag[] = "AUDIO"; constexpr char kAudioTag[] = "AUDIO";
@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects them to the graph output. // Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
model_resources, task_options.classifier_options(), MP_RETURN_IF_ERROR(
&postprocessing.GetOptions< components::processors::ConfigureClassificationPostprocessingGraph(
tasks::components::ClassificationPostprocessingOptions>())); model_resources, task_options.classifier_options(),
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Time aggregation is only needed for performing audio classification on // Time aggregation is only needed for performing audio classification on

View File

@ -37,8 +37,8 @@ limitations under the License.
#include "mediapipe/tasks/cc/audio/core/running_mode.h" #include "mediapipe/tasks/cc/audio/core/running_mode.h"
#include "mediapipe/tasks/cc/audio/utils/test_utils.h" #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
namespace mediapipe { namespace mediapipe {
@ -49,6 +49,7 @@ namespace {
using ::absl::StatusOr; using ::absl::StatusOr;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.audio.audio_classifier.proto; package mediapipe.tasks.audio.audio_classifier.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message AudioClassifierGraphOptions { message AudioClassifierGraphOptions {
@ -31,7 +31,7 @@ message AudioClassifierGraphOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2; optional components.processors.proto.ClassifierOptions classifier_options = 2;
// The default sample rate of the input audio. Must be set when the // The default sample rate of the input audio. Must be set when the
// AudioClassifier is configured to process audio stream data. // AudioClassifier is configured to process audio stream data.

View File

@ -58,65 +58,6 @@ cc_library(
# TODO: Enable this test # TODO: Enable this test
cc_library(
name = "classifier_options",
srcs = ["classifier_options.cc"],
hdrs = ["classifier_options.h"],
deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
)
mediapipe_proto_library(
name = "classification_postprocessing_options_proto",
srcs = ["classification_postprocessing_options.proto"],
deps = [
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
],
)
cc_library(
name = "classification_postprocessing",
srcs = ["classification_postprocessing.cc"],
hdrs = ["classification_postprocessing.h"],
deps = [
":classification_postprocessing_options_cc_proto",
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)
cc_library( cc_library(
name = "embedder_options", name = "embedder_options",
srcs = ["embedder_options.cc"], srcs = ["embedder_options.cc"],

View File

@ -37,8 +37,8 @@ cc_library(
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/tasks/cc/components/containers:category_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
], ],
alwayslink = 1, alwayslink = 1,
@ -128,7 +128,7 @@ cc_library(
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -25,15 +25,15 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions; using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
using ::mediapipe::tasks::ClassificationResult; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::Classifications; using ::mediapipe::tasks::components::containers::proto::Classifications;
// Aggregates ClassificationLists into a single ClassificationResult that has // Aggregates ClassificationLists into a single ClassificationResult that has
// 3 dimensions: (classification head, classification timestamp, classification // 3 dimensions: (classification head, classification timestamp, classification

View File

@ -17,12 +17,13 @@ limitations under the License.
#include <vector> #include <vector>
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
// Specialized EndLoopCalculator for Tasks specific types. // Specialized EndLoopCalculator for Tasks specific types.
namespace mediapipe::tasks { namespace mediapipe::tasks {
typedef EndLoopCalculator<std::vector<ClassificationResult>> typedef EndLoopCalculator<
std::vector<components::containers::proto::ClassificationResult>>
EndLoopClassificationResultCalculator; EndLoopClassificationResultCalculator;
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);

View File

@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library(
name = "category_proto",
srcs = ["category.proto"],
)
mediapipe_proto_library(
name = "classifications_proto",
srcs = ["classifications.proto"],
deps = [
":category_proto",
],
)
mediapipe_proto_library(
name = "embeddings_proto",
srcs = ["embeddings.proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "landmarks_detection_result_proto", name = "landmarks_detection_result_proto",
srcs = [ srcs = [
@ -29,8 +47,3 @@ mediapipe_proto_library(
"//mediapipe/framework/formats:rect_proto", "//mediapipe/framework/formats:rect_proto",
], ],
) )
mediapipe_proto_library(
name = "embeddings_proto",
srcs = ["embeddings.proto"],
)

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks; package mediapipe.tasks.components.containers.proto;
// A single classification result. // A single classification result.
message Category { message Category {

View File

@ -15,9 +15,9 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks; package mediapipe.tasks.components.containers.proto;
import "mediapipe/tasks/cc/components/containers/category.proto"; import "mediapipe/tasks/cc/components/containers/proto/category.proto";
// List of predicted categories with an optional timestamp. // List of predicted categories with an optional timestamp.
message ClassificationEntry { message ClassificationEntry {

View File

@ -0,0 +1,64 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "classifier_options",
srcs = ["classifier_options.cc"],
hdrs = ["classifier_options.h"],
deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"],
)
cc_library(
name = "classification_postprocessing_graph",
srcs = ["classification_postprocessing_graph.cc"],
hdrs = ["classification_postprocessing_graph.h"],
deps = [
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/metadata:metadata_extractor",
"//mediapipe/tasks/metadata:metadata_schema_cc",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:label_map_util",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
],
alwayslink = 1,
)

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include <stdint.h> #include <stdint.h>
@ -37,9 +37,9 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h" #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
@ -51,6 +51,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp;
using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::proto::ClassifierOptions; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::mediapipe::tasks::metadata::ModelMetadataExtractor; using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
using ::tflite::ProcessUnit; using ::tflite::ProcessUnit;
@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS"; constexpr char kTimestampsTag[] = "TIMESTAMPS";
// Performs sanity checks on provided ClassifierOptions. // Performs sanity checks on provided ClassifierOptions.
absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) { absl::Status SanityCheckClassifierOptions(
const proto::ClassifierOptions& options) {
if (options.max_results() == 0) { if (options.max_results() == 0) {
return CreateStatusWithPayload( return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument, absl::StatusCode::kInvalidArgument,
@ -203,7 +205,7 @@ absl::StatusOr<float> GetScoreThreshold(
// Gets the category allowlist or denylist (if any) as a set of indices. // Gets the category allowlist or denylist (if any) as a set of indices.
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny( absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
const ClassifierOptions& options, const LabelItems& label_items) { const proto::ClassifierOptions& options, const LabelItems& label_items) {
absl::flat_hash_set<int> category_indices; absl::flat_hash_set<int> category_indices;
// Exit early if no denylist/allowlist. // Exit early if no denylist/allowlist.
if (options.category_denylist_size() == 0 && if (options.category_denylist_size() == 0 &&
@ -239,7 +241,7 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
absl::Status ConfigureScoreCalibrationIfAny( absl::Status ConfigureScoreCalibrationIfAny(
const ModelMetadataExtractor& metadata_extractor, int tensor_index, const ModelMetadataExtractor& metadata_extractor, int tensor_index,
ClassificationPostprocessingOptions* options) { proto::ClassificationPostprocessingGraphOptions* options) {
const auto* tensor_metadata = const auto* tensor_metadata =
metadata_extractor.GetOutputTensorMetadata(tensor_index); metadata_extractor.GetOutputTensorMetadata(tensor_index);
if (tensor_metadata == nullptr) { if (tensor_metadata == nullptr) {
@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
// Fills in the TensorsToClassificationCalculatorOptions based on the // Fills in the TensorsToClassificationCalculatorOptions based on the
// classifier options and the (optional) output tensor metadata. // classifier options and the (optional) output tensor metadata.
absl::Status ConfigureTensorsToClassificationCalculator( absl::Status ConfigureTensorsToClassificationCalculator(
const ClassifierOptions& options, const proto::ClassifierOptions& options,
const ModelMetadataExtractor& metadata_extractor, int tensor_index, const ModelMetadataExtractor& metadata_extractor, int tensor_index,
TensorsToClassificationCalculatorOptions* calculator_options) { TensorsToClassificationCalculatorOptions* calculator_options) {
const auto* tensor_metadata = const auto* tensor_metadata =
@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator(
} // namespace } // namespace
absl::Status ConfigureClassificationPostprocessing( absl::Status ConfigureClassificationPostprocessingGraph(
const ModelResources& model_resources, const ModelResources& model_resources,
const ClassifierOptions& classifier_options, const proto::ClassifierOptions& classifier_options,
ClassificationPostprocessingOptions* options) { proto::ClassificationPostprocessingGraphOptions* options) {
MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options)); MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
ASSIGN_OR_RETURN(const auto heads_properties, ASSIGN_OR_RETURN(const auto heads_properties,
GetClassificationHeadsProperties(model_resources)); GetClassificationHeadsProperties(model_resources));
@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing(
return absl::OkStatus(); return absl::OkStatus();
} }
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts // A "ClassificationPostprocessingGraph" converts raw tensors into
// raw tensors into ClassificationResult objects. // ClassificationResult objects.
// - Accepts CPU input tensors. // - Accepts CPU input tensors.
// //
// Inputs: // Inputs:
@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing(
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results. // The output aggregated classification results.
// //
// The recommended way of using this subgraph is through the GraphBuilder API // The recommended way of using this graph is through the GraphBuilder API
// using the 'ConfigureClassificationPostprocessing()' function. See header file // using the 'ConfigureClassificationPostprocessingGraph()' function. See header
// for more details. // file for more details.
class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
public: public:
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig( absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
mediapipe::SubgraphContext* sc) override { mediapipe::SubgraphContext* sc) override {
@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto classification_result_out, auto classification_result_out,
BuildClassificationPostprocessing( BuildClassificationPostprocessing(
sc->Options<ClassificationPostprocessingOptions>(), sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
graph[Input<std::vector<Tensor>>(kTensorsTag)], graph[Input<std::vector<Tensor>>(kTensorsTag)],
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph)); graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
classification_result_out >> classification_result_out >>
@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
} }
private: private:
// Adds an on-device classification postprocessing subgraph into the provided // Adds an on-device classification postprocessing graph into the provided
// builder::Graph instance. The classification postprocessing subgraph takes // builder::Graph instance. The classification postprocessing graph takes
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output // tensors (std::vector<mediapipe::Tensor>) as input and returns one output
// stream containing the output classification results (ClassificationResult). // stream containing the output classification results (ClassificationResult).
// //
// options: the on-device ClassificationPostprocessingOptions. // options: the on-device ClassificationPostprocessingGraphOptions.
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess. // tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
// timestamps that a single ClassificationResult should aggregate. // timestamps that a single ClassificationResult should aggregate.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> absl::StatusOr<Source<ClassificationResult>>
BuildClassificationPostprocessing( BuildClassificationPostprocessing(
const ClassificationPostprocessingOptions& options, const proto::ClassificationPostprocessingGraphOptions& options,
Source<std::vector<Tensor>> tensors_in, Source<std::vector<Tensor>> tensors_in,
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) { Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
const int num_heads = options.tensors_to_classifications_options_size(); const int num_heads = options.tensors_to_classifications_options_size();
@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
kClassificationResultTag)]; kClassificationResultTag)];
} }
}; };
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors::
ClassificationPostprocessingGraph); // NOLINT
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,32 +13,33 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Configures a ClassificationPostprocessing subgraph using the provided model // Configures a ClassificationPostprocessingGraph using the provided model
// resources and ClassifierOptions. // resources and ClassifierOptions.
// - Accepts CPU input tensors. // - Accepts CPU input tensors.
// //
// Example usage: // Example usage:
// //
// auto& postprocessing = // auto& postprocessing =
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( // MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
// model_resources, // model_resources,
// classifier_options, // classifier_options,
// &preprocessing.GetOptions<ClassificationPostprocessingOptions>())); // &preprocessing.GetOptions<ClassificationPostprocessingGraphOptions>()));
// //
// The resulting ClassificationPostprocessing subgraph has the following I/O: // The resulting ClassificationPostprocessingGraph has the following I/O:
// Inputs: // Inputs:
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator. // The output tensors of an InferenceCalculator.
@ -49,13 +50,14 @@ namespace components {
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results. // The output aggregated classification results.
absl::Status ConfigureClassificationPostprocessing( absl::Status ConfigureClassificationPostprocessingGraph(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const tasks::components::proto::ClassifierOptions& classifier_options, const proto::ClassifierOptions& classifier_options,
ClassificationPostprocessingOptions* options); proto::ClassificationPostprocessingGraphOptions* options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include <map> #include <map>
#include <memory> #include <memory>
@ -42,9 +42,9 @@ limitations under the License.
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
@ -53,6 +53,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
using ::mediapipe::api2::Input; using ::mediapipe::api2::Input;
@ -60,7 +61,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::ClassifierOptions; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::proto::Approximately; using ::testing::proto::Approximately;
@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.set_max_results(0); options_in.set_max_results(0);
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources, auto status = ConfigureClassificationPostprocessingGraph(
options_in, &options_out); *model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_allowlist("foo"); options_in.add_category_allowlist("foo");
options_in.add_category_denylist("bar"); options_in.add_category_denylist("bar");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources, auto status = ConfigureClassificationPostprocessingGraph(
options_in, &options_out); *model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_allowlist("foo"); options_in.add_category_allowlist("foo");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources, auto status = ConfigureClassificationPostprocessingGraph(
options_in, &options_out); *model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT( EXPECT_THAT(
@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: [] R"pb(score_calibration_options: []
@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.set_max_results(3); options_in.set_max_results(3);
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: [] R"pb(score_calibration_options: []
@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.set_score_threshold(0.5); options_in.set_score_threshold(0.5);
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: [] R"pb(score_calibration_options: []
@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Check label map size and two first elements. // Check label map size and two first elements.
EXPECT_EQ( EXPECT_EQ(
@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_allowlist("tench"); options_in.add_category_allowlist("tench");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Clear label map and compare the rest of the options. // Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_denylist("background"); options_in.add_category_denylist("background");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Clear label map and compare the rest of the options. // Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
auto model_resources, auto model_resources,
CreateModelResourcesForModel( CreateModelResourcesForModel(
kQuantizedImageClassifierWithDummyScoreCalibration)); kQuantizedImageClassifierWithDummyScoreCalibration));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Check label map size and two first elements. // Check label map size and two first elements.
EXPECT_EQ( EXPECT_EQ(
@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata)); CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Check label maps sizes and first two elements. // Check label maps sizes and first two elements.
EXPECT_EQ( EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(), options_out.tensors_to_classifications_options(0).label_items_size(),
@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
class PostprocessingTest : public tflite_shims::testing::Test { class PostprocessingTest : public tflite_shims::testing::Test {
protected: protected:
absl::StatusOr<OutputStreamPoller> BuildGraph( absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const ClassifierOptions& options, absl::string_view model_name, const proto::ClassifierOptions& options,
bool connect_timestamps = false) { bool connect_timestamps = false) {
ASSIGN_OR_RETURN(auto model_resources, ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name)); CreateModelResourcesForModel(model_name));
Graph graph; Graph graph;
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
*model_resources, options, *model_resources, options,
&postprocessing.GetOptions<ClassificationPostprocessingOptions>())); &postprocessing
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >> graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag); postprocessing.In(kTensorsTag);
if (connect_timestamps) { if (connect_timestamps) {
@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(3); options.set_max_results(3);
options.set_score_threshold(0.5); options.set_score_threshold(0.5);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
TEST_F(PostprocessingTest, SucceedsWithMetadata) { TEST_F(PostprocessingTest, SucceedsWithMetadata) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(3); options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(3); options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, auto poller,
@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(2); options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, auto poller,
@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
TEST_F(PostprocessingTest, SucceedsWithTimestamps) { TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(2); options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
} }
} // namespace } // namespace
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( proto::ClassifierOptions ConvertClassifierOptionsToProto(
ClassifierOptions* options) { ClassifierOptions* options) {
tasks::components::proto::ClassifierOptions options_proto; proto::ClassifierOptions options_proto;
options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_display_names_locale(options->display_names_locale);
options_proto.set_max_results(options->max_results); options_proto.set_max_results(options->max_results);
options_proto.set_score_threshold(options->score_threshold); options_proto.set_score_threshold(options->score_threshold);
@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
return options_proto; return options_proto;
} }
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Classifier options for MediaPipe C++ classification Tasks. // Classifier options for MediaPipe C++ classification Tasks.
struct ClassifierOptions { struct ClassifierOptions {
@ -49,11 +50,12 @@ struct ClassifierOptions {
}; };
// Converts a ClassifierOptions to a ClassifierOptionsProto. // Converts a ClassifierOptions to a ClassifierOptionsProto.
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( proto::ClassifierOptions ConvertClassifierOptionsToProto(
ClassifierOptions* classifier_options); ClassifierOptions* classifier_options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_

View File

@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library( mediapipe_proto_library(
name = "category_proto", name = "classifier_options_proto",
srcs = ["category.proto"], srcs = ["classifier_options.proto"],
) )
mediapipe_proto_library( mediapipe_proto_library(
name = "classifications_proto", name = "classification_postprocessing_graph_options_proto",
srcs = ["classifications.proto"], srcs = ["classification_postprocessing_graph_options.proto"],
deps = [ deps = [
":category_proto", "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
], ],
) )

View File

@ -15,16 +15,16 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components; package mediapipe.tasks.components.processors.proto;
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
message ClassificationPostprocessingOptions { message ClassificationPostprocessingGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ClassificationPostprocessingOptions ext = 460416950; optional ClassificationPostprocessingGraphOptions ext = 460416950;
} }
// Optional mapping between output tensor index and corresponding score // Optional mapping between output tensor index and corresponding score

View File

@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
// Shared options used by all classification tasks. // Shared options used by all classification tasks.
message ClassifierOptions { message ClassifierOptions {

View File

@ -23,11 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"], srcs = ["segmenter_options.proto"],
) )
mediapipe_proto_library(
name = "classifier_options_proto",
srcs = ["classifier_options.proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "embedder_options_proto", name = "embedder_options_proto",
srcs = ["embedder_options.proto"], srcs = ["embedder_options.proto"],

View File

@ -54,10 +54,10 @@ cc_library(
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",

View File

@ -27,9 +27,9 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -49,6 +49,7 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
HandGestureRecognizerSubgraphOptions; HandGestureRecognizerSubgraphOptions;
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
auto inference_output_tensors = inference.Out(kTensorsTag); auto inference_output_tensors = inference.Out(kTensorsTag);
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
model_resources, graph_options.classifier_options(), MP_RETURN_IF_ERROR(
&postprocessing.GetOptions< components::processors::ConfigureClassificationPostprocessingGraph(
tasks::components::ClassificationPostprocessingOptions>())); model_resources, graph_options.classifier_options(),
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference_output_tensors >> postprocessing.In(kTensorsTag); inference_output_tensors >> postprocessing.In(kTensorsTag);
auto classification_result = auto classification_result =
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")]; postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];

View File

@ -26,7 +26,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )
@ -37,7 +37,5 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.hand_gesture_recognizer.proto; package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message HandGestureRecognizerSubgraphOptions { message HandGestureRecognizerSubgraphOptions {
@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions {
// Options for configuring the gesture classifier behavior, such as score // Options for configuring the gesture classifier behavior, such as score
// threshold, number of results, etc. // threshold, number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2; optional components.processors.proto.ClassifierOptions classifier_options = 2;
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
// considered tracked successfully // considered tracked successfully

View File

@ -26,11 +26,11 @@ cc_library(
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
@ -50,9 +50,9 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:classifier_options", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",

View File

@ -26,9 +26,9 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -56,6 +56,7 @@ constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
// Builds a NormalizedRect covering the entire image. // Builds a NormalizedRect covering the entire image.
@ -107,8 +108,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE); options->running_mode != core::RunningMode::IMAGE);
auto classifier_options_proto = auto classifier_options_proto =
std::make_unique<tasks::components::proto::ClassifierOptions>( std::make_unique<components::processors::proto::ClassifierOptions>(
components::ConvertClassifierOptionsToProto( components::processors::ConvertClassifierOptionsToProto(
&(options->classifier_options))); &(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap( options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get()); classifier_options_proto.get());

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -51,12 +51,14 @@ struct ImageClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
components::ClassifierOptions classifier_options; components::processors::ClassifierOptions classifier_options;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)> std::function<void(
absl::StatusOr<components::containers::proto::ClassificationResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -113,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: describe exact preprocessing steps once // TODO: describe exact preprocessing steps once
// YUVToImageCalculator is integrated. // YUVToImageCalculator is integrated.
absl::StatusOr<ClassificationResult> Classify( absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
@ -127,9 +129,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<ClassificationResult> ClassifyForVideo( absl::StatusOr<components::containers::proto::ClassificationResult>
mediapipe::Image image, int64 timestamp_ms, ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
// Sends live image data to image classification, and the results will be // Sends live image data to image classification, and the results will be
// available via the "result_callback" provided in the ImageClassifierOptions. // available via the "result_callback" provided in the ImageClassifierOptions.

View File

@ -22,11 +22,11 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
@ -43,6 +43,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects them to the graph output. // Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
model_resources, task_options.classifier_options(), MP_RETURN_IF_ERROR(
&postprocessing.GetOptions< components::processors::ConfigureClassificationPostprocessingGraph(
tasks::components::ClassificationPostprocessingOptions>())); model_resources, task_options.classifier_options(),
&postprocessing
.GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the aggregated classification result as the subgraph output // Outputs the aggregated classification result as the subgraph output

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -48,6 +48,9 @@ namespace image_classifier {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;

View File

@ -24,7 +24,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_classifier.proto; package mediapipe.tasks.vision.image_classifier.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageClassifierGraphOptions { message ImageClassifierGraphOptions {
@ -31,5 +31,5 @@ message ImageClassifierGraphOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2; optional components.processors.proto.ClassifierOptions classifier_options = 2;
} }

View File

@ -31,7 +31,7 @@ py_library(
name = "category", name = "category",
srcs = ["category.py"], srcs = ["category.py"],
deps = [ deps = [
"//mediapipe/tasks/cc/components/containers:category_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
], ],
) )

View File

@ -16,7 +16,7 @@
import dataclasses import dataclasses
from typing import Any from typing import Any
from mediapipe.tasks.cc.components.containers import category_pb2 from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_CategoryProto = category_pb2.Category _CategoryProto = category_pb2.Category