Refactor classifiers public APIs for namespace and naming consistency.

PiperOrigin-RevId: 477705647
This commit is contained in:
MediaPipe Team 2022-09-29 06:03:59 -07:00 committed by Copybara-Service
parent 554e2a9d69
commit a8ca669f05
14 changed files with 80 additions and 66 deletions

View File

@ -33,7 +33,7 @@ cc_library(
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:matrix",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
@ -60,7 +60,7 @@ cc_library(
":audio_classifier_graph",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:matrix",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
"//mediapipe/tasks/cc/audio/core:running_mode",

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
#include "mediapipe/tasks/cc/components/classifier_options.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
@ -33,6 +33,8 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace audio {
namespace audio_classifier {
namespace {
constexpr char kAudioStreamName[] = "audio_in";
@ -42,16 +44,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
constexpr char kSampleRateName[] = "sample_rate_in";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.audio.AudioClassifierGraph";
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using AudioClassifierOptionsProto =
audio_classifier::proto::AudioClassifierOptions;
// Creates a MediaPipe graph config that only contains a single subgraph node of
// "mediapipe.tasks.audio.AudioClassifierGraph".
// type "AudioClassifierGraph".
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<AudioClassifierOptionsProto> options_proto) {
std::unique_ptr<proto::AudioClassifierGraphOptions> options_proto) {
api2::builder::Graph graph;
auto& subgraph = graph.AddNode(kSubgraphTypeName);
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
@ -59,7 +58,8 @@ CalculatorGraphConfig CreateGraphConfig(
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
subgraph.In(kSampleRateTag);
}
subgraph.GetOptions<AudioClassifierOptionsProto>().Swap(options_proto.get());
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
options_proto.get());
subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >>
graph.Out(kClassificationResultTag);
@ -67,10 +67,10 @@ CalculatorGraphConfig CreateGraphConfig(
}
// Converts the user-facing AudioClassifierOptions struct to the internal
// AudioClassifierOptions proto.
std::unique_ptr<AudioClassifierOptionsProto>
// AudioClassifierGraphOptions proto.
std::unique_ptr<proto::AudioClassifierGraphOptions>
ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
auto options_proto = std::make_unique<AudioClassifierOptionsProto>();
auto options_proto = std::make_unique<proto::AudioClassifierGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
@ -119,7 +119,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
};
}
return core::AudioTaskApiFactory::Create<AudioClassifier,
AudioClassifierOptionsProto>(
proto::AudioClassifierGraphOptions>(
CreateGraphConfig(std::move(options_proto)),
std::move(options->base_options.op_resolver), options->running_mode,
std::move(packets_callback));
@ -140,6 +140,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace audio_classifier
} // namespace audio
} // namespace tasks
} // namespace mediapipe

View File

@ -30,6 +30,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace audio {
namespace audio_classifier {
// The options for configuring a MediaPipe audio classifier task.
struct AudioClassifierOptions {
@ -162,6 +163,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
absl::Status Close() { return runner_->Close(); }
};
} // namespace audio_classifier
} // namespace audio
} // namespace tasks
} // namespace mediapipe

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace audio {
namespace audio_classifier {
namespace {
@ -60,10 +61,9 @@ constexpr char kPacketTag[] = "PACKET";
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
constexpr char kTensorsTag[] = "TENSORS";
constexpr char kTimestampsTag[] = "TIMESTAMPS";
using AudioClassifierOptionsProto =
audio_classifier::proto::AudioClassifierOptions;
absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) {
absl::Status SanityCheckOptions(
const proto::AudioClassifierGraphOptions& options) {
if (options.base_options().use_stream_mode() &&
!options.has_default_input_audio_sample_rate()) {
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
@ -111,7 +111,7 @@ void ConfigureAudioToTensorCalculator(
} // namespace
// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification.
// An "AudioClassifierGraph" performs audio classification.
// - Accepts CPU audio buffer and outputs classification results on CPU.
//
// Inputs:
@ -129,12 +129,12 @@ void ConfigureAudioToTensorCalculator(
//
// Example:
// node {
// calculator: "mediapipe.tasks.audio.AudioClassifierGraph"
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
// input_stream: "AUDIO:audio_in"
// input_stream: "SAMPLE_RATE:sample_rate_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// options {
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
// {
// base_options {
// model_asset {
@ -152,16 +152,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<AudioClassifierOptionsProto>(sc));
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
Graph graph;
const bool use_stream_mode = sc->Options<AudioClassifierOptionsProto>()
const bool use_stream_mode =
sc->Options<proto::AudioClassifierGraphOptions>()
.base_options()
.use_stream_mode();
ASSIGN_OR_RETURN(
auto classification_result_out,
BuildAudioClassificationTask(
sc->Options<AudioClassifierOptionsProto>(), *model_resources,
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
graph[Input<Matrix>(kAudioTag)],
use_stream_mode
? absl::nullopt
@ -178,14 +180,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
// buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
// the inputs and returns one classification result per input audio buffer.
//
// task_options: the mediapipe tasks AudioClassifierOptions proto.
// task_options: the mediapipe tasks AudioClassifierGraphOptions proto.
// model_resources: the ModelResources object initialized from an audio
// classifier model file with model metadata.
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
// sample_rate_in: (double) optional stream of the input audio sample rate.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
const AudioClassifierOptionsProto& task_options,
const proto::AudioClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Matrix> audio_in,
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
@ -257,8 +259,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
}
};
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph);
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph);
} // namespace audio_classifier
} // namespace audio
} // namespace tasks
} // namespace mediapipe

View File

@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace audio {
namespace audio_classifier {
namespace {
using ::absl::StatusOr;
@ -557,6 +558,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
}
} // namespace
} // namespace audio_classifier
} // namespace audio
} // namespace tasks
} // namespace mediapipe

View File

@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "audio_classifier_options_proto",
srcs = ["audio_classifier_options.proto"],
name = "audio_classifier_graph_options_proto",
srcs = ["audio_classifier_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",

View File

@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message AudioClassifierOptions {
message AudioClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional AudioClassifierOptions ext = 451755788;
optional AudioClassifierGraphOptions ext = 451755788;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.

View File

@ -33,7 +33,7 @@ cc_library(
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
"//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
@ -61,7 +61,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],

View File

@ -36,11 +36,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_classifier {
namespace {
@ -52,12 +53,10 @@ constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageClassifierGraph";
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::core::PacketMap;
using ImageClassifierOptionsProto =
image_classifier::proto::ImageClassifierOptions;
// Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() {
@ -70,17 +69,17 @@ NormalizedRect BuildFullImageNormRect() {
}
// Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the
// live stream mode, a "FlowLimiterCalculator" will be added to limit the number
// of frames in flight.
// type "ImageClassifierGraph". If the task is running in the live stream mode,
// a "FlowLimiterCalculator" will be added to limit the number of frames in
// flight.
CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageClassifierOptionsProto> options_proto,
std::unique_ptr<proto::ImageClassifierGraphOptions> options_proto,
bool enable_flow_limiting) {
api2::builder::Graph graph;
graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectName);
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageClassifierOptionsProto>().Swap(
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
options_proto.get());
task_subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >>
@ -98,10 +97,10 @@ CalculatorGraphConfig CreateGraphConfig(
}
// Converts the user-facing ImageClassifierOptions struct to the internal
// ImageClassifierOptions proto.
std::unique_ptr<ImageClassifierOptionsProto>
// ImageClassifierGraphOptions proto.
std::unique_ptr<proto::ImageClassifierGraphOptions>
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
auto options_proto = std::make_unique<ImageClassifierOptionsProto>();
auto options_proto = std::make_unique<proto::ImageClassifierGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get());
@ -145,7 +144,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
};
}
return core::VisionTaskApiFactory::Create<ImageClassifier,
ImageClassifierOptionsProto>(
proto::ImageClassifierGraphOptions>(
CreateGraphConfig(
std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM),
@ -214,6 +213,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
}
} // namespace image_classifier
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -32,6 +32,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_classifier {
// The options for configuring a MediaPipe image classifier task.
struct ImageClassifierOptions {
@ -161,6 +162,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); }
};
} // namespace image_classifier
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -29,11 +29,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_classifier {
namespace {
@ -42,8 +43,6 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source;
using ImageClassifierOptionsProto =
image_classifier::proto::ImageClassifierOptions;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@ -61,8 +60,7 @@ struct ImageClassifierOutputStreams {
} // namespace
// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image
// classification.
// An "ImageClassifierGraph" performs image classification.
// - Accepts CPU input images and outputs classifications on CPU.
//
// Inputs:
@ -80,12 +78,12 @@ struct ImageClassifierOutputStreams {
//
// Example:
// node {
// calculator: "mediapipe.tasks.vision.ImageClassifierGraph"
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
// input_stream: "IMAGE:image_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// output_stream: "IMAGE:image_out"
// options {
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext]
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
// {
// base_options {
// model_asset {
@ -104,13 +102,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources,
CreateModelResources<ImageClassifierOptionsProto>(sc));
ASSIGN_OR_RETURN(
const auto* model_resources,
CreateModelResources<proto::ImageClassifierGraphOptions>(sc));
Graph graph;
ASSIGN_OR_RETURN(
auto output_streams,
BuildImageClassificationTask(
sc->Options<ImageClassifierOptionsProto>(), *model_resources,
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.classification_result >>
@ -125,13 +124,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// (mediapipe::Image) as input and returns one classification result per input
// image.
//
// task_options: the mediapipe tasks ImageClassifierOptions.
// task_options: the mediapipe tasks ImageClassifierGraphOptions proto.
// model_resources: the ModelResources object initialized from an image
// classification model file with model metadata.
// image_in: (mediapipe::Image) stream to run classification on.
// graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
const ImageClassifierOptionsProto& task_options,
const proto::ImageClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Adds preprocessing calculators and connects them to the graph input image
@ -168,8 +167,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
/*image=*/preprocessing[Output<Image>(kImageTag)]};
}
};
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph);
REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph);
} // namespace image_classifier
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -44,6 +44,7 @@ limitations under the License.
namespace mediapipe {
namespace tasks {
namespace vision {
namespace image_classifier {
namespace {
using ::mediapipe::file::JoinPath;
@ -814,6 +815,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
}
} // namespace
} // namespace image_classifier
} // namespace vision
} // namespace tasks
} // namespace mediapipe

View File

@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
mediapipe_proto_library(
name = "image_classifier_options_proto",
srcs = ["image_classifier_options.proto"],
name = "image_classifier_graph_options_proto",
srcs = ["image_classifier_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",

View File

@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageClassifierOptions {
message ImageClassifierGraphOptions {
extend mediapipe.CalculatorOptions {
optional ImageClassifierOptions ext = 456383383;
optional ImageClassifierGraphOptions ext = 456383383;
}
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc.