Refactor classifiers public APIs for namespace and naming consistency.
PiperOrigin-RevId: 477705647
This commit is contained in:
parent
554e2a9d69
commit
a8ca669f05
|
@ -33,7 +33,7 @@ cc_library(
|
|||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/tasks/cc:common",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing",
|
||||
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
|
||||
|
@ -60,7 +60,7 @@ cc_library(
|
|||
":audio_classifier_graph",
|
||||
"//mediapipe/framework/api2:builder",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
|
||||
"//mediapipe/tasks/cc/audio/core:base_audio_task_api",
|
||||
"//mediapipe/tasks/cc/audio/core:running_mode",
|
||||
|
|
|
@ -22,7 +22,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/builder.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/components/classifier_options.h"
|
||||
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
|
||||
|
@ -33,6 +33,8 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kAudioStreamName[] = "audio_in";
|
||||
|
@ -42,16 +44,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
|
|||
constexpr char kSampleRateName[] = "sample_rate_in";
|
||||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.audio.AudioClassifierGraph";
|
||||
"mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using AudioClassifierOptionsProto =
|
||||
audio_classifier::proto::AudioClassifierOptions;
|
||||
|
||||
// Creates a MediaPipe graph config that only contains a single subgraph node of
|
||||
// "mediapipe.tasks.audio.AudioClassifierGraph".
|
||||
// type "AudioClassifierGraph".
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<AudioClassifierOptionsProto> options_proto) {
|
||||
std::unique_ptr<proto::AudioClassifierGraphOptions> options_proto) {
|
||||
api2::builder::Graph graph;
|
||||
auto& subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);
|
||||
|
@ -59,7 +58,8 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
graph.In(kSampleRateTag).SetName(kSampleRateName) >>
|
||||
subgraph.In(kSampleRateTag);
|
||||
}
|
||||
subgraph.GetOptions<AudioClassifierOptionsProto>().Swap(options_proto.get());
|
||||
subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
subgraph.Out(kClassificationResultTag)
|
||||
.SetName(kClassificationResultStreamName) >>
|
||||
graph.Out(kClassificationResultTag);
|
||||
|
@ -67,10 +67,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
}
|
||||
|
||||
// Converts the user-facing AudioClassifierOptions struct to the internal
|
||||
// AudioClassifierOptions proto.
|
||||
std::unique_ptr<AudioClassifierOptionsProto>
|
||||
// AudioClassifierGraphOptions proto.
|
||||
std::unique_ptr<proto::AudioClassifierGraphOptions>
|
||||
ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
|
||||
auto options_proto = std::make_unique<AudioClassifierOptionsProto>();
|
||||
auto options_proto = std::make_unique<proto::AudioClassifierGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
|
@ -119,7 +119,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
|
|||
};
|
||||
}
|
||||
return core::AudioTaskApiFactory::Create<AudioClassifier,
|
||||
AudioClassifierOptionsProto>(
|
||||
proto::AudioClassifierGraphOptions>(
|
||||
CreateGraphConfig(std::move(options_proto)),
|
||||
std::move(options->base_options.op_resolver), options->running_mode,
|
||||
std::move(packets_callback));
|
||||
|
@ -140,6 +140,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
|
|||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||
}
|
||||
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
// The options for configuring a mediapipe audio classifier task.
|
||||
struct AudioClassifierOptions {
|
||||
|
@ -162,6 +163,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
|
|||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -28,7 +28,7 @@ limitations under the License.
|
|||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/matrix.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
|
||||
#include "mediapipe/tasks/cc/common.h"
|
||||
#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
|
||||
|
@ -44,6 +44,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -60,10 +61,9 @@ constexpr char kPacketTag[] = "PACKET";
|
|||
constexpr char kSampleRateTag[] = "SAMPLE_RATE";
|
||||
constexpr char kTensorsTag[] = "TENSORS";
|
||||
constexpr char kTimestampsTag[] = "TIMESTAMPS";
|
||||
using AudioClassifierOptionsProto =
|
||||
audio_classifier::proto::AudioClassifierOptions;
|
||||
|
||||
absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) {
|
||||
absl::Status SanityCheckOptions(
|
||||
const proto::AudioClassifierGraphOptions& options) {
|
||||
if (options.base_options().use_stream_mode() &&
|
||||
!options.has_default_input_audio_sample_rate()) {
|
||||
return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,
|
||||
|
@ -111,7 +111,7 @@ void ConfigureAudioToTensorCalculator(
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification.
|
||||
// An "AudioClassifierGraph" performs audio classification.
|
||||
// - Accepts CPU audio buffer and outputs classification results on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -129,12 +129,12 @@ void ConfigureAudioToTensorCalculator(
|
|||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.audio.AudioClassifierGraph"
|
||||
// calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
|
||||
// input_stream: "AUDIO:audio_in"
|
||||
// input_stream: "SAMPLE_RATE:sample_rate_in"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
|
||||
// [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -152,16 +152,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<AudioClassifierOptionsProto>(sc));
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
|
||||
Graph graph;
|
||||
const bool use_stream_mode = sc->Options<AudioClassifierOptionsProto>()
|
||||
.base_options()
|
||||
.use_stream_mode();
|
||||
const bool use_stream_mode =
|
||||
sc->Options<proto::AudioClassifierGraphOptions>()
|
||||
.base_options()
|
||||
.use_stream_mode();
|
||||
ASSIGN_OR_RETURN(
|
||||
auto classification_result_out,
|
||||
BuildAudioClassificationTask(
|
||||
sc->Options<AudioClassifierOptionsProto>(), *model_resources,
|
||||
sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
|
||||
graph[Input<Matrix>(kAudioTag)],
|
||||
use_stream_mode
|
||||
? absl::nullopt
|
||||
|
@ -178,14 +180,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
// buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
|
||||
// the inputs and returns one classification result per input audio buffer.
|
||||
//
|
||||
// task_options: the mediapipe tasks AudioClassifierOptions proto.
|
||||
// task_options: the mediapipe tasks AudioClassifierGraphOptions proto.
|
||||
// model_resources: the ModelSources object initialized from an audio
|
||||
// classifier model file with model metadata.
|
||||
// audio_in: (mediapipe::Matrix) stream to run audio classification on.
|
||||
// sample_rate_in: (double) optional stream of the input audio sample rate.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
|
||||
const AudioClassifierOptionsProto& task_options,
|
||||
const proto::AudioClassifierGraphOptions& task_options,
|
||||
const core::ModelResources& model_resources, Source<Matrix> audio_in,
|
||||
absl::optional<Source<double>> sample_rate_in, Graph& graph) {
|
||||
MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));
|
||||
|
@ -257,8 +259,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph);
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph);
|
||||
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -44,6 +44,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace audio {
|
||||
namespace audio_classifier {
|
||||
namespace {
|
||||
|
||||
using ::absl::StatusOr;
|
||||
|
@ -557,6 +558,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace audio_classifier
|
||||
} // namespace audio
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "audio_classifier_options_proto",
|
||||
srcs = ["audio_classifier_options.proto"],
|
||||
name = "audio_classifier_graph_options_proto",
|
||||
srcs = ["audio_classifier_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
|
|
@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message AudioClassifierOptions {
|
||||
message AudioClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional AudioClassifierOptions ext = 451755788;
|
||||
optional AudioClassifierGraphOptions ext = 451755788;
|
||||
}
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
|
@ -33,7 +33,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
|
||||
"//mediapipe/tasks/cc/core:model_resources",
|
||||
"//mediapipe/tasks/cc/core:model_task_graph",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -61,7 +61,7 @@ cc_library(
|
|||
"//mediapipe/tasks/cc/vision/core:base_vision_task_api",
|
||||
"//mediapipe/tasks/cc/vision/core:running_mode",
|
||||
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto",
|
||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
|
|
|
@ -36,11 +36,12 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/core/utils.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/running_mode.h"
|
||||
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -52,12 +53,10 @@ constexpr char kImageTag[] = "IMAGE";
|
|||
constexpr char kNormRectName[] = "norm_rect_in";
|
||||
constexpr char kNormRectTag[] = "NORM_RECT";
|
||||
constexpr char kSubgraphTypeName[] =
|
||||
"mediapipe.tasks.vision.ImageClassifierGraph";
|
||||
"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
|
||||
constexpr int kMicroSecondsPerMilliSecond = 1000;
|
||||
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using ImageClassifierOptionsProto =
|
||||
image_classifier::proto::ImageClassifierOptions;
|
||||
|
||||
// Builds a NormalizedRect covering the entire image.
|
||||
NormalizedRect BuildFullImageNormRect() {
|
||||
|
@ -70,17 +69,17 @@ NormalizedRect BuildFullImageNormRect() {
|
|||
}
|
||||
|
||||
// Creates a MediaPipe graph config that contains a subgraph node of
|
||||
// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the
|
||||
// live stream mode, a "FlowLimiterCalculator" will be added to limit the number
|
||||
// of frames in flight.
|
||||
// type "ImageClassifierGraph". If the task is running in the live stream mode,
|
||||
// a "FlowLimiterCalculator" will be added to limit the number of frames in
|
||||
// flight.
|
||||
CalculatorGraphConfig CreateGraphConfig(
|
||||
std::unique_ptr<ImageClassifierOptionsProto> options_proto,
|
||||
std::unique_ptr<proto::ImageClassifierGraphOptions> options_proto,
|
||||
bool enable_flow_limiting) {
|
||||
api2::builder::Graph graph;
|
||||
graph.In(kImageTag).SetName(kImageInStreamName);
|
||||
graph.In(kNormRectTag).SetName(kNormRectName);
|
||||
auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
|
||||
task_subgraph.GetOptions<ImageClassifierOptionsProto>().Swap(
|
||||
task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
|
||||
options_proto.get());
|
||||
task_subgraph.Out(kClassificationResultTag)
|
||||
.SetName(kClassificationResultStreamName) >>
|
||||
|
@ -98,10 +97,10 @@ CalculatorGraphConfig CreateGraphConfig(
|
|||
}
|
||||
|
||||
// Converts the user-facing ImageClassifierOptions struct to the internal
|
||||
// ImageClassifierOptions proto.
|
||||
std::unique_ptr<ImageClassifierOptionsProto>
|
||||
// ImageClassifierGraphOptions proto.
|
||||
std::unique_ptr<proto::ImageClassifierGraphOptions>
|
||||
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
|
||||
auto options_proto = std::make_unique<ImageClassifierOptionsProto>();
|
||||
auto options_proto = std::make_unique<proto::ImageClassifierGraphOptions>();
|
||||
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
|
||||
tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
|
||||
options_proto->mutable_base_options()->Swap(base_options_proto.get());
|
||||
|
@ -145,7 +144,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
|
|||
};
|
||||
}
|
||||
return core::VisionTaskApiFactory::Create<ImageClassifier,
|
||||
ImageClassifierOptionsProto>(
|
||||
proto::ImageClassifierGraphOptions>(
|
||||
CreateGraphConfig(
|
||||
std::move(options_proto),
|
||||
options->running_mode == core::RunningMode::LIVE_STREAM),
|
||||
|
@ -214,6 +213,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
|
|||
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
|
||||
}
|
||||
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -32,6 +32,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
|
||||
// The options for configuring a Mediapipe image classifier task.
|
||||
struct ImageClassifierOptions {
|
||||
|
@ -161,6 +162,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
|
|||
absl::Status Close() { return runner_->Close(); }
|
||||
};
|
||||
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -29,11 +29,12 @@ limitations under the License.
|
|||
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/model_resources.h"
|
||||
#include "mediapipe/tasks/cc/core/model_task_graph.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h"
|
||||
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
|
||||
namespace {
|
||||
|
||||
|
@ -42,8 +43,6 @@ using ::mediapipe::api2::Output;
|
|||
using ::mediapipe::api2::builder::GenericNode;
|
||||
using ::mediapipe::api2::builder::Graph;
|
||||
using ::mediapipe::api2::builder::Source;
|
||||
using ImageClassifierOptionsProto =
|
||||
image_classifier::proto::ImageClassifierOptions;
|
||||
|
||||
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
|
||||
|
||||
|
@ -61,8 +60,7 @@ struct ImageClassifierOutputStreams {
|
|||
|
||||
} // namespace
|
||||
|
||||
// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image
|
||||
// classification.
|
||||
// An "ImageClassifierGraph" performs image classification.
|
||||
// - Accepts CPU input images and outputs classifications on CPU.
|
||||
//
|
||||
// Inputs:
|
||||
|
@ -80,12 +78,12 @@ struct ImageClassifierOutputStreams {
|
|||
//
|
||||
// Example:
|
||||
// node {
|
||||
// calculator: "mediapipe.tasks.vision.ImageClassifierGraph"
|
||||
// calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
|
||||
// input_stream: "IMAGE:image_in"
|
||||
// output_stream: "CLASSIFICATION_RESULT:classification_result_out"
|
||||
// output_stream: "IMAGE:image_out"
|
||||
// options {
|
||||
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext]
|
||||
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
|
||||
// {
|
||||
// base_options {
|
||||
// model_asset {
|
||||
|
@ -104,13 +102,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
public:
|
||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||
SubgraphContext* sc) override {
|
||||
ASSIGN_OR_RETURN(const auto* model_resources,
|
||||
CreateModelResources<ImageClassifierOptionsProto>(sc));
|
||||
ASSIGN_OR_RETURN(
|
||||
const auto* model_resources,
|
||||
CreateModelResources<proto::ImageClassifierGraphOptions>(sc));
|
||||
Graph graph;
|
||||
ASSIGN_OR_RETURN(
|
||||
auto output_streams,
|
||||
BuildImageClassificationTask(
|
||||
sc->Options<ImageClassifierOptionsProto>(), *model_resources,
|
||||
sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
|
||||
graph[Input<Image>(kImageTag)],
|
||||
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
|
||||
output_streams.classification_result >>
|
||||
|
@ -125,13 +124,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
// (mediapipe::Image) as input and returns one classification result per input
|
||||
// image.
|
||||
//
|
||||
// task_options: the mediapipe tasks ImageClassifierOptions.
|
||||
// task_options: the mediapipe tasks ImageClassifierGraphOptions.
|
||||
// model_resources: the ModelSources object initialized from an image
|
||||
// classification model file with model metadata.
|
||||
// image_in: (mediapipe::Image) stream to run classification on.
|
||||
// graph: the mediapipe builder::Graph instance to be updated.
|
||||
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
|
||||
const ImageClassifierOptionsProto& task_options,
|
||||
const proto::ImageClassifierGraphOptions& task_options,
|
||||
const core::ModelResources& model_resources, Source<Image> image_in,
|
||||
Source<NormalizedRect> norm_rect_in, Graph& graph) {
|
||||
// Adds preprocessing calculators and connects them to the graph input image
|
||||
|
@ -168,8 +167,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
|
|||
/*image=*/preprocessing[Output<Image>(kImageTag)]};
|
||||
}
|
||||
};
|
||||
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph);
|
||||
REGISTER_MEDIAPIPE_GRAPH(
|
||||
::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph);
|
||||
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -44,6 +44,7 @@ limitations under the License.
|
|||
namespace mediapipe {
|
||||
namespace tasks {
|
||||
namespace vision {
|
||||
namespace image_classifier {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::file::JoinPath;
|
||||
|
@ -814,6 +815,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
|
|||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace image_classifier
|
||||
} // namespace vision
|
||||
} // namespace tasks
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,8 +19,8 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
licenses(["notice"])
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "image_classifier_options_proto",
|
||||
srcs = ["image_classifier_options.proto"],
|
||||
name = "image_classifier_graph_options_proto",
|
||||
srcs = ["image_classifier_graph_options.proto"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
|
|
|
@ -21,9 +21,9 @@ import "mediapipe/framework/calculator.proto";
|
|||
import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
|
||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||
|
||||
message ImageClassifierOptions {
|
||||
message ImageClassifierGraphOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ImageClassifierOptions ext = 456383383;
|
||||
optional ImageClassifierGraphOptions ext = 456383383;
|
||||
}
|
||||
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite
|
||||
// model file with metadata, accelerator options, etc.
|
Loading…
Reference in New Issue
Block a user