Refactor ClassificationResult and ClassificationPostprocessing.
PiperOrigin-RevId: 478444264
This commit is contained in:
parent
1e5cccdc73
commit
03c8ac3641
|
@ -35,9 +35,10 @@ cc_library(
|
|||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
|
@ -64,8 +65,9 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||
"//mediapipe/tasks/cc/components:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||
|
|
|
@ -24,8 +24,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -37,6 +38,8 @@ namespace audio_classifier {
|
|||
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr char kAudioStreamName[] = "audio_in";
|
||||
constexpr char kAudioTag[] = "AUDIO";
|
||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
||||
|
@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
|||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
||||
auto classifier_options_proto =
|
||||
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
||||
components::ConvertClassifierOptionsToProto(
|
||||
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||
components::processors::ConvertClassifierOptionsToProto(
|
||||
&(options->classifier_options)));
|
||||
options_proto->mutable_classifier_options()->Swap(
|
||||
classifier_options_proto.get());
|
||||
|
|
|
@ -23,8 +23,8 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -40,7 +40,7 @@ struct AudioClassifierOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
components::ClassifierOptions classifier_options;
|
||||
components::processors::ClassifierOptions classifier_options;
|
||||
|
||||
// The running mode of the audio classifier. Default to the audio clips mode.
|
||||
// Audio classifier has two running modes:
|
||||
|
@ -59,8 +59,9 @@ struct AudioClassifierOptions {
|
|||
// The user-defined result callback for processing audio stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::AUDIO_STREAM.
|
||||
std::function<void(absl::StatusOr<ClassificationResult>)> result_callback =
|
||||
nullptr;
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
// Performs audio classification on audio clips or audio stream.
|
||||
|
@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|||
// framed audio clip.
|
||||
// TODO: Use `sample_rate` in AudioClassifierOptions by default
|
||||
// and makes `audio_sample_rate` optional.
|
||||
absl::StatusOr<ClassificationResult> Classify(mediapipe::Matrix audio_clip,
|
||||
double audio_sample_rate);
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||
mediapipe::Matrix audio_clip, double audio_sample_rate);
|
||||
|
||||
// Sends audio data (a block in a continuous audio stream) to perform audio
|
||||
// classification. Only use this method when the AudioClassifier is created
|
||||
|
|
|
@ -31,9 +31,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -53,6 +53,7 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
|
||||
constexpr char kAudioTag[] = "AUDIO";
|
||||
|
@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
|
||||
// Adds postprocessing calculators and connects them to the graph output.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing.GetOptions<
|
||||
tasks::components::ClassificationPostprocessingOptions>()));
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Time aggregation is only needed for performing audio classification on
|
||||
|
|
|
@ -37,8 +37,8 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -49,6 +49,7 @@ namespace {
|
|||
|
||||
using ::absl::StatusOr;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.audio.audio_classifier.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message AudioClassifierGraphOptions {
|
||||
|
@ -31,7 +31,7 @@ message AudioClassifierGraphOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
|
||||
// The default sample rate of the input audio. Must be set when the
|
||||
// AudioClassifier is configured to process audio stream data.
|
||||
|
|
|
@ -58,65 +58,6 @@ cc_library(
|
|||
|
||||
# TODO: Enable this test
|
||||
|
||||
cc_library(
|
||||
name = "classifier_options",
|
||||
srcs = ["classifier_options.cc"],
|
||||
hdrs = ["classifier_options.h"],
|
||||
deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classification_postprocessing_options_proto",
|
||||
srcs = ["classification_postprocessing_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_postprocessing",
|
||||
srcs = ["classification_postprocessing.cc"],
|
||||
hdrs = ["classification_postprocessing.h"],
|
||||
deps = [
|
||||
":classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:label_map_util",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "embedder_options",
|
||||
srcs = ["embedder_options.cc"],
|
||||
|
|
|
@ -37,8 +37,8 @@ cc_library(
|
|||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:category_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -128,7 +128,7 @@ cc_library(
|
|||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -25,15 +25,15 @@
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
||||
using ::mediapipe::tasks::ClassificationResult;
|
||||
using ::mediapipe::tasks::Classifications;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
|
||||
// Aggregates ClassificationLists into a single ClassificationResult that has
|
||||
// 3 dimensions: (classification head, classification timestamp, classification
|
||||
|
|
|
@ -17,12 +17,13 @@ limitations under the License.
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
|
||||
// Specialized EndLoopCalculator for Tasks specific types.
|
||||
namespace mediapipe::tasks {
|
||||
|
||||
typedef EndLoopCalculator<std::vector<ClassificationResult>>
|
||||
typedef EndLoopCalculator<
|
||||
std::vector<components::containers::proto::ClassificationResult>>
|
||||
EndLoopClassificationResultCalculator;
|
||||
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);
|
||||
|
||||
|
|
|
@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "category_proto",
|
||||
srcs = ["category.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classifications_proto",
|
||||
srcs = ["classifications.proto"],
|
||||
deps = [
|
||||
":category_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embeddings_proto",
|
||||
srcs = ["embeddings.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "landmarks_detection_result_proto",
|
||||
srcs = [
|
||||
|
@ -29,8 +47,3 @@ mediapipe_proto_library(
|
|||
"//mediapipe/framework/formats:rect_proto",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embeddings_proto",
|
||||
srcs = ["embeddings.proto"],
|
||||
)
|
||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks;
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
// A single classification result.
|
||||
message Category {
|
|
@ -15,9 +15,9 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks;
|
||||
package mediapipe.tasks.components.containers.proto;
|
||||
|
||||
import "mediapipe/tasks/cc/components/containers/category.proto";
|
||||
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||
|
||||
// List of predicted categories with an optional timestamp.
|
||||
message ClassificationEntry {
|
64
mediapipe/tasks/cc/components/processors/BUILD
Normal file
64
mediapipe/tasks/cc/components/processors/BUILD
Normal file
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
cc_library(
|
||||
name = "classifier_options",
|
||||
srcs = ["classifier_options.cc"],
|
||||
hdrs = ["classifier_options.h"],
|
||||
deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "classification_postprocessing_graph",
|
||||
srcs = ["classification_postprocessing_graph.cc"],
|
||||
hdrs = ["classification_postprocessing_graph.h"],
|
||||
deps = [
|
||||
"//mediapipe/calculators/core:split_vector_calculator",
|
||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
"//mediapipe/util:label_map_util",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
|
@ -37,9 +37,9 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||
using ::tflite::ProcessUnit;
|
||||
|
@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS";
|
|||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
|
||||
// Performs sanity checks on provided ClassifierOptions.
|
||||
absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) {
|
||||
absl::Status SanityCheckClassifierOptions(
|
||||
const proto::ClassifierOptions& options) {
|
||||
if (options.max_results() == 0) {
|
||||
return CreateStatusWithPayload(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
|
@ -203,7 +205,7 @@ absl::StatusOr<float> GetScoreThreshold(
|
|||
|
||||
// Gets the category allowlist or denylist (if any) as a set of indices.
|
||||
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
||||
const ClassifierOptions& options, const LabelItems& label_items) {
|
||||
const proto::ClassifierOptions& options, const LabelItems& label_items) {
|
||||
absl::flat_hash_set<int> category_indices;
|
||||
// Exit early if no denylist/allowlist.
|
||||
if (options.category_denylist_size() == 0 &&
|
||||
|
@ -239,7 +241,7 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
|||
|
||||
absl::Status ConfigureScoreCalibrationIfAny(
|
||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||
ClassificationPostprocessingOptions* options) {
|
||||
proto::ClassificationPostprocessingGraphOptions* options) {
|
||||
const auto* tensor_metadata =
|
||||
metadata_extractor.GetOutputTensorMetadata(tensor_index);
|
||||
if (tensor_metadata == nullptr) {
|
||||
|
@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
|||
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
||||
// classifier options and the (optional) output tensor metadata.
|
||||
absl::Status ConfigureTensorsToClassificationCalculator(
|
||||
const ClassifierOptions& options,
|
||||
const proto::ClassifierOptions& options,
|
||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||
TensorsToClassificationCalculatorOptions* calculator_options) {
|
||||
const auto* tensor_metadata =
|
||||
|
@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
absl::Status ConfigureClassificationPostprocessing(
|
||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||
const ModelResources& model_resources,
|
||||
const ClassifierOptions& classifier_options,
|
||||
ClassificationPostprocessingOptions* options) {
|
||||
const proto::ClassifierOptions& classifier_options,
|
||||
proto::ClassificationPostprocessingGraphOptions* options) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
|
||||
ASSIGN_OR_RETURN(const auto heads_properties,
|
||||
GetClassificationHeadsProperties(model_resources));
|
||||
|
@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
|
||||
// raw tensors into ClassificationResult objects.
|
||||
// A "ClassificationPostprocessingGraph" converts raw tensors into
|
||||
// ClassificationResult objects.
|
||||
// - Accepts CPU input tensors.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing(
|
|||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
//
|
||||
// The recommended way of using this subgraph is through the GraphBuilder API
|
||||
// using the 'ConfigureClassificationPostprocessing()' function. See header file
|
||||
// for more details.
|
||||
class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||
// The recommended way of using this graph is through the GraphBuilder API
|
||||
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
||||
// file for more details.
|
||||
class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
||||
public:
|
||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||
mediapipe::SubgraphContext* sc) override {
|
||||
|
@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
|||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
BuildClassificationPostprocessing(
|
||||
sc->Options<ClassificationPostprocessingOptions>(),
|
||||
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||
classification_result_out >>
|
||||
|
@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
|||
}
|
||||
|
||||
private:
|
||||
// Adds an on-device classification postprocessing subgraph into the provided
|
||||
// builder::Graph instance. The classification postprocessing subgraph takes
|
||||
// Adds an on-device classification postprocessing graph into the provided
|
||||
// builder::Graph instance. The classification postprocessing graph takes
|
||||
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
|
||||
// stream containing the output classification results (ClassificationResult).
|
||||
//
|
||||
// options: the on-device ClassificationPostprocessingOptions.
|
||||
// options: the on-device ClassificationPostprocessingGraphOptions.
|
||||
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
|
||||
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||
// timestamps that a single ClassificationResult should aggregate.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>>
|
||||
BuildClassificationPostprocessing(
|
||||
const ClassificationPostprocessingOptions& options,
|
||||
const proto::ClassificationPostprocessingGraphOptions& options,
|
||||
Source<std::vector<Tensor>> tensors_in,
|
||||
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
||||
const int num_heads = options.tensors_to_classifications_options_size();
|
||||
|
@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
|||
kClassificationResultTag)];
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors::
|
||||
ClassificationPostprocessingGraph); // NOLINT
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,32 +13,33 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
// Configures a ClassificationPostprocessing subgraph using the provided model
|
||||
// Configures a ClassificationPostprocessingGraph using the provided model
|
||||
// resources and ClassifierOptions.
|
||||
// - Accepts CPU input tensors.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// auto& postprocessing =
|
||||
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
|
||||
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||
// model_resources,
|
||||
// classifier_options,
|
||||
// &preprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
||||
// &preprocessing.GetOptions<ClassificationPostprocessingGraphOptions>()));
|
||||
//
|
||||
// The resulting ClassificationPostprocessing subgraph has the following I/O:
|
||||
// The resulting ClassificationPostprocessingGraph has the following I/O:
|
||||
// Inputs:
|
||||
// TENSORS - std::vector<Tensor>
|
||||
// The output tensors of an InferenceCalculator.
|
||||
|
@ -49,13 +50,14 @@ namespace components {
|
|||
// Outputs:
|
||||
// CLASSIFICATION_RESULT - ClassificationResult
|
||||
// The output aggregated classification results.
|
||||
absl::Status ConfigureClassificationPostprocessing(
|
||||
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||
const tasks::core::ModelResources& model_resources,
|
||||
const tasks::components::proto::ClassifierOptions& classifier_options,
|
||||
ClassificationPostprocessingOptions* options);
|
||||
const proto::ClassifierOptions& classifier_options,
|
||||
proto::ClassificationPostprocessingGraphOptions* options);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -42,9 +42,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::api2::Input;
|
||||
|
@ -60,7 +61,7 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::ModelResources;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::proto::Approximately;
|
||||
|
@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.set_max_results(0);
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out);
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out);
|
||||
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
|
||||
|
@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_allowlist("foo");
|
||||
options_in.add_category_denylist("bar");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out);
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out);
|
||||
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
|
||||
|
@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_allowlist("foo");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out);
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
auto status = ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out);
|
||||
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||
EXPECT_THAT(
|
||||
|
@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||
R"pb(score_calibration_options: []
|
||||
|
@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.set_max_results(3);
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||
R"pb(score_calibration_options: []
|
||||
|
@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.set_score_threshold(0.5);
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||
R"pb(score_calibration_options: []
|
||||
|
@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Check label map size and two first elements.
|
||||
EXPECT_EQ(
|
||||
|
@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_allowlist("tench");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Clear label map and compare the rest of the options.
|
||||
options_out.mutable_tensors_to_classifications_options(0)
|
||||
|
@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
options_in.add_category_denylist("background");
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Clear label map and compare the rest of the options.
|
||||
options_out.mutable_tensors_to_classifications_options(0)
|
||||
|
@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
|
|||
auto model_resources,
|
||||
CreateModelResourcesForModel(
|
||||
kQuantizedImageClassifierWithDummyScoreCalibration));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
|
||||
// Check label map size and two first elements.
|
||||
EXPECT_EQ(
|
||||
|
@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
|||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto model_resources,
|
||||
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
|
||||
ClassifierOptions options_in;
|
||||
proto::ClassifierOptions options_in;
|
||||
|
||||
ClassificationPostprocessingOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
||||
options_in, &options_out));
|
||||
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options_in, &options_out));
|
||||
// Check label maps sizes and first two elements.
|
||||
EXPECT_EQ(
|
||||
options_out.tensors_to_classifications_options(0).label_items_size(),
|
||||
|
@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
|||
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||
protected:
|
||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||
absl::string_view model_name, const ClassifierOptions& options,
|
||||
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||
bool connect_timestamps = false) {
|
||||
ASSIGN_OR_RETURN(auto model_resources,
|
||||
CreateModelResourcesForModel(model_name));
|
||||
|
||||
Graph graph;
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||
*model_resources, options,
|
||||
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
||||
&postprocessing
|
||||
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
|
||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
||||
postprocessing.In(kTensorsTag);
|
||||
if (connect_timestamps) {
|
||||
|
@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
options.set_score_threshold(0.5);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
|
@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
||||
|
@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(3);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
|
@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller,
|
||||
|
@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
|||
|
||||
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||
// Build graph.
|
||||
ClassifierOptions options;
|
||||
proto::ClassifierOptions options;
|
||||
options.set_max_results(2);
|
||||
MP_ASSERT_OK_AND_ASSIGN(
|
||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
||||
|
@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
ClassifierOptions* options) {
|
||||
tasks::components::proto::ClassifierOptions options_proto;
|
||||
proto::ClassifierOptions options_proto;
|
||||
options_proto.set_display_names_locale(options->display_names_locale);
|
||||
options_proto.set_max_results(options->max_results);
|
||||
options_proto.set_score_threshold(options->score_threshold);
|
||||
|
@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
|||
return options_proto;
|
||||
}
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
|
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
||||
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace components {
|
||||
namespace processors {
|
||||
|
||||
// Classifier options for MediaPipe C++ classification Tasks.
|
||||
struct ClassifierOptions {
|
||||
|
@ -49,11 +50,12 @@ struct ClassifierOptions {
|
|||
};
|
||||
|
||||
// Converts a ClassifierOptions to a ClassifierOptionsProto.
|
||||
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||
ClassifierOptions* classifier_options);
|
||||
|
||||
} // namespace processors
|
||||
} // namespace components
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
|
@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "category_proto",
|
||||
srcs = ["category.proto"],
|
||||
name = "classifier_options_proto",
|
||||
srcs = ["classifier_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classifications_proto",
|
||||
srcs = ["classifications.proto"],
|
||||
name = "classification_postprocessing_graph_options_proto",
|
||||
srcs = ["classification_postprocessing_graph_options.proto"],
|
||||
deps = [
|
||||
":category_proto",
|
||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||
],
|
||||
)
|
|
@ -15,16 +15,16 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.components;
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
|
||||
|
||||
message ClassificationPostprocessingOptions {
|
||||
message ClassificationPostprocessingGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ClassificationPostprocessingOptions ext = 460416950;
|
||||
optional ClassificationPostprocessingGraphOptions ext = 460416950;
|
||||
}
|
||||
|
||||
// Optional mapping between output tensor index and corresponding score
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
|||
|
||||
syntax = "proto2";
|
||||
|
||||
package mediapipe.tasks.components.proto;
|
||||
package mediapipe.tasks.components.processors.proto;
|
||||
|
||||
// Shared options used by all classification tasks.
|
||||
message ClassifierOptions {
|
|
@ -23,11 +23,6 @@ mediapipe_proto_library(
|
|||
srcs = ["segmenter_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "classifier_options_proto",
|
||||
srcs = ["classifier_options.proto"],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "embedder_options_proto",
|
||||
srcs = ["embedder_options.proto"],
|
||||
|
|
|
@ -54,10 +54,10 @@ cc_library(
|
|||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
|
|
|
@ -27,9 +27,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -49,6 +49,7 @@ using ::mediapipe::api2::Input;
|
|||
using ::mediapipe::api2::Output;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
|
||||
HandGestureRecognizerSubgraphOptions;
|
||||
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
|
||||
|
@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
|||
auto inference_output_tensors = inference.Out(kTensorsTag);
|
||||
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
model_resources, graph_options.classifier_options(),
|
||||
&postprocessing.GetOptions<
|
||||
tasks::components::ClassificationPostprocessingOptions>()));
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, graph_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference_output_tensors >> postprocessing.In(kTensorsTag);
|
||||
auto classification_result =
|
||||
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
|
||||
|
|
|
@ -26,7 +26,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
@ -37,7 +37,5 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message HandGestureRecognizerSubgraphOptions {
|
||||
|
@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions {
|
|||
|
||||
// Options for configuring the gesture classifier behavior, such as score
|
||||
// threshold, number of results, etc.
|
||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
|
||||
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
|
||||
// considered tracked successfully
|
||||
|
|
|
@ -26,11 +26,11 @@ cc_library(
|
|||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||
|
@ -50,9 +50,9 @@ cc_library(
|
|||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:image",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/tasks/cc/components:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:base_options",
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/tasks/cc/core:utils",
|
||||
|
|
|
@ -26,9 +26,9 @@ limitations under the License.
|
|||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||
|
@ -56,6 +56,7 @@ constexpr char kSubgraphTypeName[] =
|
|||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
|
||||
// Builds a NormalizedRect covering the entire image.
|
||||
|
@ -107,8 +108,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
|
|||
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||
options->running_mode != core::RunningMode::IMAGE);
|
||||
auto classifier_options_proto =
|
||||
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
||||
components::ConvertClassifierOptionsToProto(
|
||||
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||
components::processors::ConvertClassifierOptionsToProto(
|
||||
&(options->classifier_options)));
|
||||
options_proto->mutable_classifier_options()->Swap(
|
||||
classifier_options_proto.get());
|
||||
|
|
|
@ -23,8 +23,8 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
|
@ -51,12 +51,14 @@ struct ImageClassifierOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
components::ClassifierOptions classifier_options;
|
||||
components::processors::ClassifierOptions classifier_options;
|
||||
|
||||
// The user-defined result callback for processing live stream data.
|
||||
// The result callback should only be specified when the running mode is set
|
||||
// to RunningMode::LIVE_STREAM.
|
||||
std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)>
|
||||
std::function<void(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>,
|
||||
const Image&, int64)>
|
||||
result_callback = nullptr;
|
||||
};
|
||||
|
||||
|
@ -113,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA.
|
||||
// TODO: describe exact preprocessing steps once
|
||||
// YUVToImageCalculator is integrated.
|
||||
absl::StatusOr<ClassificationResult> Classify(
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||
mediapipe::Image image,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
|
||||
|
@ -127,9 +129,9 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
// The image can be of any size with format RGB or RGBA. It's required to
|
||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||
// must be monotonically increasing.
|
||||
absl::StatusOr<ClassificationResult> ClassifyForVideo(
|
||||
mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
absl::StatusOr<components::containers::proto::ClassificationResult>
|
||||
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
|
||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||
|
||||
// Sends live image data to image classification, and the results will be
|
||||
// available via the "result_callback" provided in the ImageClassifierOptions.
|
||||
|
|
|
@ -22,11 +22,11 @@ limitations under the License.
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||
|
@ -43,6 +43,7 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
|
||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||
|
||||
|
@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
|
||||
// Adds postprocessing calculators and connects them to the graph output.
|
||||
auto& postprocessing = graph.AddNode(
|
||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing.GetOptions<
|
||||
tasks::components::ClassificationPostprocessingOptions>()));
|
||||
"mediapipe.tasks.components.processors."
|
||||
"ClassificationPostprocessingGraph");
|
||||
MP_RETURN_IF_ERROR(
|
||||
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||
model_resources, task_options.classifier_options(),
|
||||
&postprocessing
|
||||
.GetOptions<components::processors::proto::
|
||||
ClassificationPostprocessingGraphOptions>()));
|
||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||
|
||||
// Outputs the aggregated classification result as the subgraph output
|
||||
|
|
|
@ -32,8 +32,8 @@ limitations under the License.
|
|||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
@ -48,6 +48,9 @@ namespace image_classifier {
|
|||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
|
||||
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||
using ::testing::HasSubstr;
|
||||
using ::testing::Optional;
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
|||
package mediapipe.tasks.vision.image_classifier.proto;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message ImageClassifierGraphOptions {
|
||||
|
@ -31,5 +31,5 @@ message ImageClassifierGraphOptions {
|
|||
|
||||
// Options for configuring the classifier behavior, such as score threshold,
|
||||
// number of results, etc.
|
||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
||||
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ py_library(
|
|||
name = "category",
|
||||
srcs = ["category.py"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/components/containers:category_py_pb2",
|
||||
"//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
|
||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
import dataclasses
|
||||
from typing import Any
|
||||
|
||||
from mediapipe.tasks.cc.components.containers import category_pb2
|
||||
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||
|
||||
_CategoryProto = category_pb2.Category
|
||||
|
|
Loading…
Reference in New Issue
Block a user