Refactor ClassificationResult and ClassificationPostprocessing.
PiperOrigin-RevId: 478444264
This commit is contained in:
parent
1e5cccdc73
commit
03c8ac3641
|
@ -35,9 +35,10 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||||
|
@ -64,8 +65,9 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||||
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||||
"//mediapipe/tasks/cc/audio/core:running_mode",
|
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||||
"//mediapipe/tasks/cc/components:classifier_options",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
"//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",
|
||||||
|
|
|
@ -24,8 +24,9 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -37,6 +38,8 @@ namespace audio_classifier {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
|
||||||
constexpr char kAudioStreamName[] = "audio_in";
|
constexpr char kAudioStreamName[] = "audio_in";
|
||||||
constexpr char kAudioTag[] = "AUDIO";
|
constexpr char kAudioTag[] = "AUDIO";
|
||||||
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
constexpr char kClassificationResultStreamName[] = "classification_result_out";
|
||||||
|
@ -77,8 +80,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
||||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||||
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
options->running_mode == core::RunningMode::AUDIO_STREAM);
|
||||||
auto classifier_options_proto =
|
auto classifier_options_proto =
|
||||||
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||||
components::ConvertClassifierOptionsToProto(
|
components::processors::ConvertClassifierOptionsToProto(
|
||||||
&(options->classifier_options)));
|
&(options->classifier_options)));
|
||||||
options_proto->mutable_classifier_options()->Swap(
|
options_proto->mutable_classifier_options()->Swap(
|
||||||
classifier_options_proto.get());
|
classifier_options_proto.get());
|
||||||
|
|
|
@ -23,8 +23,8 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
#include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -40,7 +40,7 @@ struct AudioClassifierOptions {
|
||||||
|
|
||||||
// Options for configuring the classifier behavior, such as score threshold,
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
// number of results, etc.
|
// number of results, etc.
|
||||||
components::ClassifierOptions classifier_options;
|
components::processors::ClassifierOptions classifier_options;
|
||||||
|
|
||||||
// The running mode of the audio classifier. Default to the audio clips mode.
|
// The running mode of the audio classifier. Default to the audio clips mode.
|
||||||
// Audio classifier has two running modes:
|
// Audio classifier has two running modes:
|
||||||
|
@ -59,8 +59,9 @@ struct AudioClassifierOptions {
|
||||||
// The user-defined result callback for processing audio stream data.
|
// The user-defined result callback for processing audio stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::AUDIO_STREAM.
|
// to RunningMode::AUDIO_STREAM.
|
||||||
std::function<void(absl::StatusOr<ClassificationResult>)> result_callback =
|
std::function<void(
|
||||||
nullptr;
|
absl::StatusOr<components::containers::proto::ClassificationResult>)>
|
||||||
|
result_callback = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Performs audio classification on audio clips or audio stream.
|
// Performs audio classification on audio clips or audio stream.
|
||||||
|
@ -132,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
||||||
// framed audio clip.
|
// framed audio clip.
|
||||||
// TODO: Use `sample_rate` in AudioClassifierOptions by default
|
// TODO: Use `sample_rate` in AudioClassifierOptions by default
|
||||||
// and makes `audio_sample_rate` optional.
|
// and makes `audio_sample_rate` optional.
|
||||||
absl::StatusOr<ClassificationResult> Classify(mediapipe::Matrix audio_clip,
|
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||||
double audio_sample_rate);
|
mediapipe::Matrix audio_clip, double audio_sample_rate);
|
||||||
|
|
||||||
// Sends audio data (a block in a continuous audio stream) to perform audio
|
// Sends audio data (a block in a continuous audio stream) to perform audio
|
||||||
// classification. Only use this method when the AudioClassifier is created
|
// classification. Only use this method when the AudioClassifier is created
|
||||||
|
|
|
@ -31,9 +31,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
|
@ -53,6 +53,7 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::GenericNode;
|
using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
|
||||||
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
|
constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
|
||||||
constexpr char kAudioTag[] = "AUDIO";
|
constexpr char kAudioTag[] = "AUDIO";
|
||||||
|
@ -238,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
||||||
|
|
||||||
// Adds postprocessing calculators and connects them to the graph output.
|
// Adds postprocessing calculators and connects them to the graph output.
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
"mediapipe.tasks.components.processors."
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
"ClassificationPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||||
model_resources, task_options.classifier_options(),
|
model_resources, task_options.classifier_options(),
|
||||||
&postprocessing.GetOptions<
|
&postprocessing
|
||||||
tasks::components::ClassificationPostprocessingOptions>()));
|
.GetOptions<components::processors::proto::
|
||||||
|
ClassificationPostprocessingGraphOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
// Time aggregation is only needed for performing audio classification on
|
// Time aggregation is only needed for performing audio classification on
|
||||||
|
|
|
@ -37,8 +37,8 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
#include "mediapipe/tasks/cc/audio/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
#include "mediapipe/tasks/cc/audio/utils/test_utils.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
#include "tensorflow/lite/core/shims/cc/shims_test_util.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
@ -49,6 +49,7 @@ namespace {
|
||||||
|
|
||||||
using ::absl::StatusOr;
|
using ::absl::StatusOr;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.audio.audio_classifier.proto;
|
package mediapipe.tasks.audio.audio_classifier.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message AudioClassifierGraphOptions {
|
message AudioClassifierGraphOptions {
|
||||||
|
@ -31,7 +31,7 @@ message AudioClassifierGraphOptions {
|
||||||
|
|
||||||
// Options for configuring the classifier behavior, such as score threshold,
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
// number of results, etc.
|
// number of results, etc.
|
||||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||||
|
|
||||||
// The default sample rate of the input audio. Must be set when the
|
// The default sample rate of the input audio. Must be set when the
|
||||||
// AudioClassifier is configured to process audio stream data.
|
// AudioClassifier is configured to process audio stream data.
|
||||||
|
|
|
@ -58,65 +58,6 @@ cc_library(
|
||||||
|
|
||||||
# TODO: Enable this test
|
# TODO: Enable this test
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "classifier_options",
|
|
||||||
srcs = ["classifier_options.cc"],
|
|
||||||
hdrs = ["classifier_options.h"],
|
|
||||||
deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "classification_postprocessing_options_proto",
|
|
||||||
srcs = ["classification_postprocessing_options.proto"],
|
|
||||||
deps = [
|
|
||||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
|
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
|
||||||
"//mediapipe/framework:calculator_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "classification_postprocessing",
|
|
||||||
srcs = ["classification_postprocessing.cc"],
|
|
||||||
hdrs = ["classification_postprocessing.h"],
|
|
||||||
deps = [
|
|
||||||
":classification_postprocessing_options_cc_proto",
|
|
||||||
"//mediapipe/calculators/core:split_vector_calculator",
|
|
||||||
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
|
||||||
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
|
||||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
|
||||||
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_framework",
|
|
||||||
"//mediapipe/framework/api2:builder",
|
|
||||||
"//mediapipe/framework/api2:packet",
|
|
||||||
"//mediapipe/framework/api2:port",
|
|
||||||
"//mediapipe/framework/formats:tensor",
|
|
||||||
"//mediapipe/tasks/cc:common",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
|
||||||
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
|
||||||
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
|
||||||
"//mediapipe/util:label_map_cc_proto",
|
|
||||||
"//mediapipe/util:label_map_util",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
|
||||||
"@com_google_absl//absl/status",
|
|
||||||
"@com_google_absl//absl/status:statusor",
|
|
||||||
"@com_google_absl//absl/strings",
|
|
||||||
"@com_google_absl//absl/strings:str_format",
|
|
||||||
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "embedder_options",
|
name = "embedder_options",
|
||||||
srcs = ["embedder_options.cc"],
|
srcs = ["embedder_options.cc"],
|
||||||
|
|
|
@ -37,8 +37,8 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:packet",
|
"//mediapipe/framework/api2:packet",
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:classification_cc_proto",
|
"//mediapipe/framework/formats:classification_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:category_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
|
@ -128,7 +128,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:integral_types",
|
"//mediapipe/framework/port:integral_types",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,15 +25,15 @@
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/classification.pb.h"
|
#include "mediapipe/framework/formats/classification.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
namespace api2 {
|
||||||
|
|
||||||
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
|
||||||
using ::mediapipe::tasks::ClassificationResult;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::Classifications;
|
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||||
|
|
||||||
// Aggregates ClassificationLists into a single ClassificationResult that has
|
// Aggregates ClassificationLists into a single ClassificationResult that has
|
||||||
// 3 dimensions: (classification head, classification timestamp, classification
|
// 3 dimensions: (classification head, classification timestamp, classification
|
||||||
|
|
|
@ -17,12 +17,13 @@ limitations under the License.
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
|
|
||||||
// Specialized EndLoopCalculator for Tasks specific types.
|
// Specialized EndLoopCalculator for Tasks specific types.
|
||||||
namespace mediapipe::tasks {
|
namespace mediapipe::tasks {
|
||||||
|
|
||||||
typedef EndLoopCalculator<std::vector<ClassificationResult>>
|
typedef EndLoopCalculator<
|
||||||
|
std::vector<components::containers::proto::ClassificationResult>>
|
||||||
EndLoopClassificationResultCalculator;
|
EndLoopClassificationResultCalculator;
|
||||||
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);
|
REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "category_proto",
|
||||||
|
srcs = ["category.proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "classifications_proto",
|
||||||
|
srcs = ["classifications.proto"],
|
||||||
|
deps = [
|
||||||
|
":category_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "embeddings_proto",
|
||||||
|
srcs = ["embeddings.proto"],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "landmarks_detection_result_proto",
|
name = "landmarks_detection_result_proto",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
@ -29,8 +47,3 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/framework/formats:rect_proto",
|
"//mediapipe/framework/formats:rect_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "embeddings_proto",
|
|
||||||
srcs = ["embeddings.proto"],
|
|
||||||
)
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
// A single classification result.
|
// A single classification result.
|
||||||
message Category {
|
message Category {
|
|
@ -15,9 +15,9 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks;
|
package mediapipe.tasks.components.containers.proto;
|
||||||
|
|
||||||
import "mediapipe/tasks/cc/components/containers/category.proto";
|
import "mediapipe/tasks/cc/components/containers/proto/category.proto";
|
||||||
|
|
||||||
// List of predicted categories with an optional timestamp.
|
// List of predicted categories with an optional timestamp.
|
||||||
message ClassificationEntry {
|
message ClassificationEntry {
|
64
mediapipe/tasks/cc/components/processors/BUILD
Normal file
64
mediapipe/tasks/cc/components/processors/BUILD
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "classifier_options",
|
||||||
|
srcs = ["classifier_options.cc"],
|
||||||
|
hdrs = ["classifier_options.h"],
|
||||||
|
deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "classification_postprocessing_graph",
|
||||||
|
srcs = ["classification_postprocessing_graph.cc"],
|
||||||
|
hdrs = ["classification_postprocessing_graph.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/calculators/core:split_vector_calculator",
|
||||||
|
"//mediapipe/calculators/core:split_vector_calculator_cc_proto",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_dequantization_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_classification_calculator",
|
||||||
|
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:builder",
|
||||||
|
"//mediapipe/framework/api2:packet",
|
||||||
|
"//mediapipe/framework/api2:port",
|
||||||
|
"//mediapipe/framework/formats:tensor",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
|
||||||
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/utils:source_or_node_output",
|
||||||
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
|
"//mediapipe/tasks/cc/metadata:metadata_extractor",
|
||||||
|
"//mediapipe/tasks/metadata:metadata_schema_cc",
|
||||||
|
"//mediapipe/util:label_map_cc_proto",
|
||||||
|
"//mediapipe/util:label_map_util",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/status:statusor",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
|
"@org_tensorflow//tensorflow/lite/schema:schema_fbs",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
|
@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
|
@ -37,9 +37,9 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
#include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
#include "mediapipe/tasks/cc/metadata/metadata_extractor.h"
|
||||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp;
|
||||||
using ::mediapipe::api2::builder::GenericNode;
|
using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
|
||||||
using ::tflite::ProcessUnit;
|
using ::tflite::ProcessUnit;
|
||||||
|
@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS";
|
||||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||||
|
|
||||||
// Performs sanity checks on provided ClassifierOptions.
|
// Performs sanity checks on provided ClassifierOptions.
|
||||||
absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) {
|
absl::Status SanityCheckClassifierOptions(
|
||||||
|
const proto::ClassifierOptions& options) {
|
||||||
if (options.max_results() == 0) {
|
if (options.max_results() == 0) {
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
absl::StatusCode::kInvalidArgument,
|
absl::StatusCode::kInvalidArgument,
|
||||||
|
@ -203,7 +205,7 @@ absl::StatusOr<float> GetScoreThreshold(
|
||||||
|
|
||||||
// Gets the category allowlist or denylist (if any) as a set of indices.
|
// Gets the category allowlist or denylist (if any) as a set of indices.
|
||||||
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
||||||
const ClassifierOptions& options, const LabelItems& label_items) {
|
const proto::ClassifierOptions& options, const LabelItems& label_items) {
|
||||||
absl::flat_hash_set<int> category_indices;
|
absl::flat_hash_set<int> category_indices;
|
||||||
// Exit early if no denylist/allowlist.
|
// Exit early if no denylist/allowlist.
|
||||||
if (options.category_denylist_size() == 0 &&
|
if (options.category_denylist_size() == 0 &&
|
||||||
|
@ -239,7 +241,7 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
|
||||||
|
|
||||||
absl::Status ConfigureScoreCalibrationIfAny(
|
absl::Status ConfigureScoreCalibrationIfAny(
|
||||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||||
ClassificationPostprocessingOptions* options) {
|
proto::ClassificationPostprocessingGraphOptions* options) {
|
||||||
const auto* tensor_metadata =
|
const auto* tensor_metadata =
|
||||||
metadata_extractor.GetOutputTensorMetadata(tensor_index);
|
metadata_extractor.GetOutputTensorMetadata(tensor_index);
|
||||||
if (tensor_metadata == nullptr) {
|
if (tensor_metadata == nullptr) {
|
||||||
|
@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
|
||||||
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
// Fills in the TensorsToClassificationCalculatorOptions based on the
|
||||||
// classifier options and the (optional) output tensor metadata.
|
// classifier options and the (optional) output tensor metadata.
|
||||||
absl::Status ConfigureTensorsToClassificationCalculator(
|
absl::Status ConfigureTensorsToClassificationCalculator(
|
||||||
const ClassifierOptions& options,
|
const proto::ClassifierOptions& options,
|
||||||
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
const ModelMetadataExtractor& metadata_extractor, int tensor_index,
|
||||||
TensorsToClassificationCalculatorOptions* calculator_options) {
|
TensorsToClassificationCalculatorOptions* calculator_options) {
|
||||||
const auto* tensor_metadata =
|
const auto* tensor_metadata =
|
||||||
|
@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator(
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::Status ConfigureClassificationPostprocessing(
|
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||||
const ModelResources& model_resources,
|
const ModelResources& model_resources,
|
||||||
const ClassifierOptions& classifier_options,
|
const proto::ClassifierOptions& classifier_options,
|
||||||
ClassificationPostprocessingOptions* options) {
|
proto::ClassificationPostprocessingGraphOptions* options) {
|
||||||
MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
|
MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
|
||||||
ASSIGN_OR_RETURN(const auto heads_properties,
|
ASSIGN_OR_RETURN(const auto heads_properties,
|
||||||
GetClassificationHeadsProperties(model_resources));
|
GetClassificationHeadsProperties(model_resources));
|
||||||
|
@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing(
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
|
// A "ClassificationPostprocessingGraph" converts raw tensors into
|
||||||
// raw tensors into ClassificationResult objects.
|
// ClassificationResult objects.
|
||||||
// - Accepts CPU input tensors.
|
// - Accepts CPU input tensors.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
|
@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing(
|
||||||
// CLASSIFICATION_RESULT - ClassificationResult
|
// CLASSIFICATION_RESULT - ClassificationResult
|
||||||
// The output aggregated classification results.
|
// The output aggregated classification results.
|
||||||
//
|
//
|
||||||
// The recommended way of using this subgraph is through the GraphBuilder API
|
// The recommended way of using this graph is through the GraphBuilder API
|
||||||
// using the 'ConfigureClassificationPostprocessing()' function. See header file
|
// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
|
||||||
// for more details.
|
// file for more details.
|
||||||
class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
|
||||||
public:
|
public:
|
||||||
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
|
||||||
mediapipe::SubgraphContext* sc) override {
|
mediapipe::SubgraphContext* sc) override {
|
||||||
|
@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto classification_result_out,
|
auto classification_result_out,
|
||||||
BuildClassificationPostprocessing(
|
BuildClassificationPostprocessing(
|
||||||
sc->Options<ClassificationPostprocessingOptions>(),
|
sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
|
||||||
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
graph[Input<std::vector<Tensor>>(kTensorsTag)],
|
||||||
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
|
||||||
classification_result_out >>
|
classification_result_out >>
|
||||||
|
@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Adds an on-device classification postprocessing subgraph into the provided
|
// Adds an on-device classification postprocessing graph into the provided
|
||||||
// builder::Graph instance. The classification postprocessing subgraph takes
|
// builder::Graph instance. The classification postprocessing graph takes
|
||||||
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
|
// tensors (std::vector<mediapipe::Tensor>) as input and returns one output
|
||||||
// stream containing the output classification results (ClassificationResult).
|
// stream containing the output classification results (ClassificationResult).
|
||||||
//
|
//
|
||||||
// options: the on-device ClassificationPostprocessingOptions.
|
// options: the on-device ClassificationPostprocessingGraphOptions.
|
||||||
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
|
// tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
|
||||||
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
// timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
|
||||||
// timestamps that a single ClassificationResult should aggregate.
|
// timestamps that a single ClassificationResult should aggregate.
|
||||||
// graph: the mediapipe builder::Graph instance to be updated.
|
// graph: the mediapipe builder::Graph instance to be updated.
|
||||||
absl::StatusOr<Source<ClassificationResult>>
|
absl::StatusOr<Source<ClassificationResult>>
|
||||||
BuildClassificationPostprocessing(
|
BuildClassificationPostprocessing(
|
||||||
const ClassificationPostprocessingOptions& options,
|
const proto::ClassificationPostprocessingGraphOptions& options,
|
||||||
Source<std::vector<Tensor>> tensors_in,
|
Source<std::vector<Tensor>> tensors_in,
|
||||||
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
|
||||||
const int num_heads = options.tensors_to_classifications_options_size();
|
const int num_heads = options.tensors_to_classifications_options_size();
|
||||||
|
@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
|
||||||
kClassificationResultTag)];
|
kClassificationResultTag)];
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
REGISTER_MEDIAPIPE_GRAPH(
|
|
||||||
::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
|
|
||||||
|
|
||||||
|
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors::
|
||||||
|
ClassificationPostprocessingGraph); // NOLINT
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -13,32 +13,33 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
||||||
|
|
||||||
#include "absl/status/status.h"
|
#include "absl/status/status.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
// Configures a ClassificationPostprocessing subgraph using the provided model
|
// Configures a ClassificationPostprocessingGraph using the provided model
|
||||||
// resources and ClassifierOptions.
|
// resources and ClassifierOptions.
|
||||||
// - Accepts CPU input tensors.
|
// - Accepts CPU input tensors.
|
||||||
//
|
//
|
||||||
// Example usage:
|
// Example usage:
|
||||||
//
|
//
|
||||||
// auto& postprocessing =
|
// auto& postprocessing =
|
||||||
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
// graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
|
||||||
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||||
// model_resources,
|
// model_resources,
|
||||||
// classifier_options,
|
// classifier_options,
|
||||||
// &preprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
// &preprocessing.GetOptions<ClassificationPostprocessingGraphOptions>()));
|
||||||
//
|
//
|
||||||
// The resulting ClassificationPostprocessing subgraph has the following I/O:
|
// The resulting ClassificationPostprocessingGraph has the following I/O:
|
||||||
// Inputs:
|
// Inputs:
|
||||||
// TENSORS - std::vector<Tensor>
|
// TENSORS - std::vector<Tensor>
|
||||||
// The output tensors of an InferenceCalculator.
|
// The output tensors of an InferenceCalculator.
|
||||||
|
@ -49,13 +50,14 @@ namespace components {
|
||||||
// Outputs:
|
// Outputs:
|
||||||
// CLASSIFICATION_RESULT - ClassificationResult
|
// CLASSIFICATION_RESULT - ClassificationResult
|
||||||
// The output aggregated classification results.
|
// The output aggregated classification results.
|
||||||
absl::Status ConfigureClassificationPostprocessing(
|
absl::Status ConfigureClassificationPostprocessingGraph(
|
||||||
const tasks::core::ModelResources& model_resources,
|
const tasks::core::ModelResources& model_resources,
|
||||||
const tasks::components::proto::ClassifierOptions& classifier_options,
|
const proto::ClassifierOptions& classifier_options,
|
||||||
ClassificationPostprocessingOptions* options);
|
proto::ClassificationPostprocessingGraphOptions* options);
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -42,9 +42,9 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/util/label_map.pb.h"
|
#include "mediapipe/util/label_map.pb.h"
|
||||||
|
@ -53,6 +53,7 @@ limitations under the License.
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::api2::Input;
|
using ::mediapipe::api2::Input;
|
||||||
|
@ -60,7 +61,7 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
using ::mediapipe::tasks::components::proto::ClassifierOptions;
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::core::ModelResources;
|
using ::mediapipe::tasks::core::ModelResources;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::proto::Approximately;
|
using ::testing::proto::Approximately;
|
||||||
|
@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.set_max_results(0);
|
options_in.set_max_results(0);
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
auto status = ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out);
|
*model_resources, options_in, &options_out);
|
||||||
|
|
||||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
|
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
|
||||||
|
@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.add_category_allowlist("foo");
|
options_in.add_category_allowlist("foo");
|
||||||
options_in.add_category_denylist("bar");
|
options_in.add_category_denylist("bar");
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
auto status = ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out);
|
*model_resources, options_in, &options_out);
|
||||||
|
|
||||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
|
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
|
||||||
|
@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.add_category_allowlist("foo");
|
options_in.add_category_allowlist("foo");
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
auto status = ConfigureClassificationPostprocessing(*model_resources,
|
auto status = ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out);
|
*model_resources, options_in, &options_out);
|
||||||
|
|
||||||
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
|
||||||
EXPECT_THAT(
|
EXPECT_THAT(
|
||||||
|
@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(score_calibration_options: []
|
R"pb(score_calibration_options: []
|
||||||
|
@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.set_max_results(3);
|
options_in.set_max_results(3);
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(score_calibration_options: []
|
R"pb(score_calibration_options: []
|
||||||
|
@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.set_score_threshold(0.5);
|
options_in.set_score_threshold(0.5);
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
EXPECT_THAT(options_out, Approximately(EqualsProto(
|
||||||
R"pb(score_calibration_options: []
|
R"pb(score_calibration_options: []
|
||||||
|
@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
// Check label map size and two first elements.
|
// Check label map size and two first elements.
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
|
@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.add_category_allowlist("tench");
|
options_in.add_category_allowlist("tench");
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
// Clear label map and compare the rest of the options.
|
// Clear label map and compare the rest of the options.
|
||||||
options_out.mutable_tensors_to_classifications_options(0)
|
options_out.mutable_tensors_to_classifications_options(0)
|
||||||
|
@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
options_in.add_category_denylist("background");
|
options_in.add_category_denylist("background");
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
// Clear label map and compare the rest of the options.
|
// Clear label map and compare the rest of the options.
|
||||||
options_out.mutable_tensors_to_classifications_options(0)
|
options_out.mutable_tensors_to_classifications_options(0)
|
||||||
|
@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(
|
CreateModelResourcesForModel(
|
||||||
kQuantizedImageClassifierWithDummyScoreCalibration));
|
kQuantizedImageClassifierWithDummyScoreCalibration));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
|
|
||||||
// Check label map size and two first elements.
|
// Check label map size and two first elements.
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
|
@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto model_resources,
|
auto model_resources,
|
||||||
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
|
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
|
||||||
ClassifierOptions options_in;
|
proto::ClassifierOptions options_in;
|
||||||
|
|
||||||
ClassificationPostprocessingOptions options_out;
|
proto::ClassificationPostprocessingGraphOptions options_out;
|
||||||
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources,
|
MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
|
||||||
options_in, &options_out));
|
*model_resources, options_in, &options_out));
|
||||||
// Check label maps sizes and first two elements.
|
// Check label maps sizes and first two elements.
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
options_out.tensors_to_classifications_options(0).label_items_size(),
|
options_out.tensors_to_classifications_options(0).label_items_size(),
|
||||||
|
@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
|
||||||
class PostprocessingTest : public tflite_shims::testing::Test {
|
class PostprocessingTest : public tflite_shims::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
absl::StatusOr<OutputStreamPoller> BuildGraph(
|
||||||
absl::string_view model_name, const ClassifierOptions& options,
|
absl::string_view model_name, const proto::ClassifierOptions& options,
|
||||||
bool connect_timestamps = false) {
|
bool connect_timestamps = false) {
|
||||||
ASSIGN_OR_RETURN(auto model_resources,
|
ASSIGN_OR_RETURN(auto model_resources,
|
||||||
CreateModelResourcesForModel(model_name));
|
CreateModelResourcesForModel(model_name));
|
||||||
|
|
||||||
Graph graph;
|
Graph graph;
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
"mediapipe.tasks.components.processors."
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
"ClassificationPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
|
||||||
*model_resources, options,
|
*model_resources, options,
|
||||||
&postprocessing.GetOptions<ClassificationPostprocessingOptions>()));
|
&postprocessing
|
||||||
|
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
|
||||||
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
|
||||||
postprocessing.In(kTensorsTag);
|
postprocessing.In(kTensorsTag);
|
||||||
if (connect_timestamps) {
|
if (connect_timestamps) {
|
||||||
|
@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
|
||||||
|
|
||||||
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||||
// Build graph.
|
// Build graph.
|
||||||
ClassifierOptions options;
|
proto::ClassifierOptions options;
|
||||||
options.set_max_results(3);
|
options.set_max_results(3);
|
||||||
options.set_score_threshold(0.5);
|
options.set_score_threshold(0.5);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
|
@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
|
||||||
|
|
||||||
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||||
// Build graph.
|
// Build graph.
|
||||||
ClassifierOptions options;
|
proto::ClassifierOptions options;
|
||||||
options.set_max_results(3);
|
options.set_max_results(3);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
|
||||||
|
@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
|
||||||
|
|
||||||
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||||
// Build graph.
|
// Build graph.
|
||||||
ClassifierOptions options;
|
proto::ClassifierOptions options;
|
||||||
options.set_max_results(3);
|
options.set_max_results(3);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto poller,
|
auto poller,
|
||||||
|
@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
|
||||||
|
|
||||||
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||||
// Build graph.
|
// Build graph.
|
||||||
ClassifierOptions options;
|
proto::ClassifierOptions options;
|
||||||
options.set_max_results(2);
|
options.set_max_results(2);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto poller,
|
auto poller,
|
||||||
|
@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
|
||||||
|
|
||||||
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||||
// Build graph.
|
// Build graph.
|
||||||
ClassifierOptions options;
|
proto::ClassifierOptions options;
|
||||||
options.set_max_results(2);
|
options.set_max_results(2);
|
||||||
MP_ASSERT_OK_AND_ASSIGN(
|
MP_ASSERT_OK_AND_ASSIGN(
|
||||||
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
|
||||||
|
@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||||
ClassifierOptions* options) {
|
ClassifierOptions* options) {
|
||||||
tasks::components::proto::ClassifierOptions options_proto;
|
proto::ClassifierOptions options_proto;
|
||||||
options_proto.set_display_names_locale(options->display_names_locale);
|
options_proto.set_display_names_locale(options->display_names_locale);
|
||||||
options_proto.set_max_results(options->max_results);
|
options_proto.set_max_results(options->max_results);
|
||||||
options_proto.set_score_threshold(options->score_threshold);
|
options_proto.set_score_threshold(options->score_threshold);
|
||||||
|
@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||||
return options_proto;
|
return options_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
|
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
||||||
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
#define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
||||||
|
|
||||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace tasks {
|
namespace tasks {
|
||||||
namespace components {
|
namespace components {
|
||||||
|
namespace processors {
|
||||||
|
|
||||||
// Classifier options for MediaPipe C++ classification Tasks.
|
// Classifier options for MediaPipe C++ classification Tasks.
|
||||||
struct ClassifierOptions {
|
struct ClassifierOptions {
|
||||||
|
@ -49,11 +50,12 @@ struct ClassifierOptions {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Converts a ClassifierOptions to a ClassifierOptionsProto.
|
// Converts a ClassifierOptions to a ClassifierOptionsProto.
|
||||||
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
proto::ClassifierOptions ConvertClassifierOptionsToProto(
|
||||||
ClassifierOptions* classifier_options);
|
ClassifierOptions* classifier_options);
|
||||||
|
|
||||||
|
} // namespace processors
|
||||||
} // namespace components
|
} // namespace components
|
||||||
} // namespace tasks
|
} // namespace tasks
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_
|
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
|
|
@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "category_proto",
|
name = "classifier_options_proto",
|
||||||
srcs = ["category.proto"],
|
srcs = ["classifier_options.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "classifications_proto",
|
name = "classification_postprocessing_graph_options_proto",
|
||||||
srcs = ["classifications.proto"],
|
srcs = ["classification_postprocessing_graph_options.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
":category_proto",
|
"//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
|
@ -15,16 +15,16 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
|
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
|
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
|
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
|
||||||
|
|
||||||
message ClassificationPostprocessingOptions {
|
message ClassificationPostprocessingGraphOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional ClassificationPostprocessingOptions ext = 460416950;
|
optional ClassificationPostprocessingGraphOptions ext = 460416950;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional mapping between output tensor index and corresponding score
|
// Optional mapping between output tensor index and corresponding score
|
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
syntax = "proto2";
|
syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.components.proto;
|
package mediapipe.tasks.components.processors.proto;
|
||||||
|
|
||||||
// Shared options used by all classification tasks.
|
// Shared options used by all classification tasks.
|
||||||
message ClassifierOptions {
|
message ClassifierOptions {
|
|
@ -23,11 +23,6 @@ mediapipe_proto_library(
|
||||||
srcs = ["segmenter_options.proto"],
|
srcs = ["segmenter_options.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
|
||||||
name = "classifier_options_proto",
|
|
||||||
srcs = ["classifier_options.proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "embedder_options_proto",
|
name = "embedder_options_proto",
|
||||||
srcs = ["embedder_options.proto"],
|
srcs = ["embedder_options.proto"],
|
||||||
|
|
|
@ -54,10 +54,10 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
"//mediapipe/framework/formats:tensor",
|
"//mediapipe/framework/formats:tensor",
|
||||||
"//mediapipe/tasks/cc:common",
|
"//mediapipe/tasks/cc:common",
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
|
|
@ -27,9 +27,9 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/matrix.h"
|
#include "mediapipe/framework/formats/matrix.h"
|
||||||
#include "mediapipe/framework/formats/tensor.h"
|
#include "mediapipe/framework/formats/tensor.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
|
@ -49,6 +49,7 @@ using ::mediapipe::api2::Input;
|
||||||
using ::mediapipe::api2::Output;
|
using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
|
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
|
||||||
HandGestureRecognizerSubgraphOptions;
|
HandGestureRecognizerSubgraphOptions;
|
||||||
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
|
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
|
||||||
|
@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
|
||||||
auto inference_output_tensors = inference.Out(kTensorsTag);
|
auto inference_output_tensors = inference.Out(kTensorsTag);
|
||||||
|
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
"mediapipe.tasks.components.processors."
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
"ClassificationPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||||
model_resources, graph_options.classifier_options(),
|
model_resources, graph_options.classifier_options(),
|
||||||
&postprocessing.GetOptions<
|
&postprocessing
|
||||||
tasks::components::ClassificationPostprocessingOptions>()));
|
.GetOptions<components::processors::proto::
|
||||||
|
ClassificationPostprocessingGraphOptions>()));
|
||||||
inference_output_tensors >> postprocessing.In(kTensorsTag);
|
inference_output_tensors >> postprocessing.In(kTensorsTag);
|
||||||
auto classification_result =
|
auto classification_result =
|
||||||
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
|
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
|
||||||
|
|
|
@ -26,7 +26,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -37,7 +37,5 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message HandGestureRecognizerSubgraphOptions {
|
message HandGestureRecognizerSubgraphOptions {
|
||||||
|
@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions {
|
||||||
|
|
||||||
// Options for configuring the gesture classifier behavior, such as score
|
// Options for configuring the gesture classifier behavior, such as score
|
||||||
// threshold, number of results, etc.
|
// threshold, number of results, etc.
|
||||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||||
|
|
||||||
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
|
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
|
||||||
// considered tracked successfully
|
// considered tracked successfully
|
||||||
|
|
|
@ -26,11 +26,11 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:port",
|
"//mediapipe/framework/api2:port",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
|
||||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing",
|
"//mediapipe/tasks/cc/components:image_preprocessing",
|
||||||
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
|
||||||
|
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:model_resources",
|
"//mediapipe/tasks/cc/core:model_resources",
|
||||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||||
|
@ -50,9 +50,9 @@ cc_library(
|
||||||
"//mediapipe/framework/api2:builder",
|
"//mediapipe/framework/api2:builder",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components:classifier_options",
|
"//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
|
||||||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
"//mediapipe/tasks/cc/components/processors:classifier_options",
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/core:base_options",
|
"//mediapipe/tasks/cc/core:base_options",
|
||||||
"//mediapipe/tasks/cc/core:task_runner",
|
"//mediapipe/tasks/cc/core:task_runner",
|
||||||
"//mediapipe/tasks/cc/core:utils",
|
"//mediapipe/tasks/cc/core:utils",
|
||||||
|
|
|
@ -26,9 +26,9 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/timestamp.h"
|
#include "mediapipe/framework/timestamp.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
|
||||||
|
@ -56,6 +56,7 @@ constexpr char kSubgraphTypeName[] =
|
||||||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||||
|
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
using ::mediapipe::tasks::core::PacketMap;
|
using ::mediapipe::tasks::core::PacketMap;
|
||||||
|
|
||||||
// Builds a NormalizedRect covering the entire image.
|
// Builds a NormalizedRect covering the entire image.
|
||||||
|
@ -107,8 +108,8 @@ ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
|
||||||
options_proto->mutable_base_options()->set_use_stream_mode(
|
options_proto->mutable_base_options()->set_use_stream_mode(
|
||||||
options->running_mode != core::RunningMode::IMAGE);
|
options->running_mode != core::RunningMode::IMAGE);
|
||||||
auto classifier_options_proto =
|
auto classifier_options_proto =
|
||||||
std::make_unique<tasks::components::proto::ClassifierOptions>(
|
std::make_unique<components::processors::proto::ClassifierOptions>(
|
||||||
components::ConvertClassifierOptionsToProto(
|
components::processors::ConvertClassifierOptionsToProto(
|
||||||
&(options->classifier_options)));
|
&(options->classifier_options)));
|
||||||
options_proto->mutable_classifier_options()->Swap(
|
options_proto->mutable_classifier_options()->Swap(
|
||||||
classifier_options_proto.get());
|
classifier_options_proto.get());
|
||||||
|
|
|
@ -23,8 +23,8 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
|
||||||
#include "mediapipe/tasks/cc/core/base_options.h"
|
#include "mediapipe/tasks/cc/core/base_options.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
|
@ -51,12 +51,14 @@ struct ImageClassifierOptions {
|
||||||
|
|
||||||
// Options for configuring the classifier behavior, such as score threshold,
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
// number of results, etc.
|
// number of results, etc.
|
||||||
components::ClassifierOptions classifier_options;
|
components::processors::ClassifierOptions classifier_options;
|
||||||
|
|
||||||
// The user-defined result callback for processing live stream data.
|
// The user-defined result callback for processing live stream data.
|
||||||
// The result callback should only be specified when the running mode is set
|
// The result callback should only be specified when the running mode is set
|
||||||
// to RunningMode::LIVE_STREAM.
|
// to RunningMode::LIVE_STREAM.
|
||||||
std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)>
|
std::function<void(
|
||||||
|
absl::StatusOr<components::containers::proto::ClassificationResult>,
|
||||||
|
const Image&, int64)>
|
||||||
result_callback = nullptr;
|
result_callback = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -113,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA.
|
// The image can be of any size with format RGB or RGBA.
|
||||||
// TODO: describe exact preprocessing steps once
|
// TODO: describe exact preprocessing steps once
|
||||||
// YUVToImageCalculator is integrated.
|
// YUVToImageCalculator is integrated.
|
||||||
absl::StatusOr<ClassificationResult> Classify(
|
absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
|
||||||
mediapipe::Image image,
|
mediapipe::Image image,
|
||||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||||
|
|
||||||
|
@ -127,8 +129,8 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
||||||
// The image can be of any size with format RGB or RGBA. It's required to
|
// The image can be of any size with format RGB or RGBA. It's required to
|
||||||
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
// provide the video frame's timestamp (in milliseconds). The input timestamps
|
||||||
// must be monotonically increasing.
|
// must be monotonically increasing.
|
||||||
absl::StatusOr<ClassificationResult> ClassifyForVideo(
|
absl::StatusOr<components::containers::proto::ClassificationResult>
|
||||||
mediapipe::Image image, int64 timestamp_ms,
|
ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
|
||||||
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
|
||||||
|
|
||||||
// Sends live image data to image classification, and the results will be
|
// Sends live image data to image classification, and the results will be
|
||||||
|
|
|
@ -22,11 +22,11 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/image.h"
|
#include "mediapipe/framework/formats/image.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing.h"
|
||||||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
|
||||||
|
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||||
|
@ -43,6 +43,7 @@ using ::mediapipe::api2::Output;
|
||||||
using ::mediapipe::api2::builder::GenericNode;
|
using ::mediapipe::api2::builder::GenericNode;
|
||||||
using ::mediapipe::api2::builder::Graph;
|
using ::mediapipe::api2::builder::Graph;
|
||||||
using ::mediapipe::api2::builder::Source;
|
using ::mediapipe::api2::builder::Source;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
|
||||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||||
|
|
||||||
|
@ -152,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
||||||
|
|
||||||
// Adds postprocessing calculators and connects them to the graph output.
|
// Adds postprocessing calculators and connects them to the graph output.
|
||||||
auto& postprocessing = graph.AddNode(
|
auto& postprocessing = graph.AddNode(
|
||||||
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
|
"mediapipe.tasks.components.processors."
|
||||||
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
|
"ClassificationPostprocessingGraph");
|
||||||
|
MP_RETURN_IF_ERROR(
|
||||||
|
components::processors::ConfigureClassificationPostprocessingGraph(
|
||||||
model_resources, task_options.classifier_options(),
|
model_resources, task_options.classifier_options(),
|
||||||
&postprocessing.GetOptions<
|
&postprocessing
|
||||||
tasks::components::ClassificationPostprocessingOptions>()));
|
.GetOptions<components::processors::proto::
|
||||||
|
ClassificationPostprocessingGraphOptions>()));
|
||||||
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
|
||||||
|
|
||||||
// Outputs the aggregated classification result as the subgraph output
|
// Outputs the aggregated classification result as the subgraph output
|
||||||
|
|
|
@ -32,8 +32,8 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/tasks/cc/common.h"
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/category.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
|
||||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||||
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"
|
||||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||||
|
@ -48,6 +48,9 @@ namespace image_classifier {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::file::JoinPath;
|
using ::mediapipe::file::JoinPath;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
|
||||||
|
using ::mediapipe::tasks::components::containers::proto::Classifications;
|
||||||
using ::testing::HasSubstr;
|
using ::testing::HasSubstr;
|
||||||
using ::testing::Optional;
|
using ::testing::Optional;
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ mediapipe_proto_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
|
"//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,7 +18,7 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.vision.image_classifier.proto;
|
package mediapipe.tasks.vision.image_classifier.proto;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator.proto";
|
import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
|
|
||||||
message ImageClassifierGraphOptions {
|
message ImageClassifierGraphOptions {
|
||||||
|
@ -31,5 +31,5 @@ message ImageClassifierGraphOptions {
|
||||||
|
|
||||||
// Options for configuring the classifier behavior, such as score threshold,
|
// Options for configuring the classifier behavior, such as score threshold,
|
||||||
// number of results, etc.
|
// number of results, etc.
|
||||||
optional components.proto.ClassifierOptions classifier_options = 2;
|
optional components.processors.proto.ClassifierOptions classifier_options = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ py_library(
|
||||||
name = "category",
|
name = "category",
|
||||||
srcs = ["category.py"],
|
srcs = ["category.py"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/tasks/cc/components/containers:category_py_pb2",
|
"//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
|
||||||
"//mediapipe/tasks/python/core:optional_dependencies",
|
"//mediapipe/tasks/python/core:optional_dependencies",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from mediapipe.tasks.cc.components.containers import category_pb2
|
from mediapipe.tasks.cc.components.containers.proto import category_pb2
|
||||||
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
|
||||||
|
|
||||||
_CategoryProto = category_pb2.Category
|
_CategoryProto = category_pb2.Category
|
||||||
|
|
Loading…
Reference in New Issue
Block a user