Merge branch 'google:master' into image-classification-python-impl

Kinar R 2022-10-04 02:19:22 +05:30 committed by GitHub
commit aac7ff946f
116 changed files with 6077 additions and 500 deletions


@@ -143,9 +143,7 @@ mediapipe_proto_library(
 cc_library(
     name = "packet_frequency_calculator",
     srcs = ["packet_frequency_calculator.cc"],
-    visibility = [
-        "//visibility:public",
-    ],
+    visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/calculators/util:packet_frequency_calculator_cc_proto",
         "//mediapipe/calculators/util:packet_frequency_cc_proto",

@@ -190,9 +188,7 @@ cc_test(
 cc_library(
     name = "packet_latency_calculator",
     srcs = ["packet_latency_calculator.cc"],
-    visibility = [
-        "//visibility:public",
-    ],
+    visibility = ["//visibility:public"],
     deps = [
         "//mediapipe/calculators/util:latency_cc_proto",
         "//mediapipe/calculators/util:packet_latency_calculator_cc_proto",


@@ -184,6 +184,17 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) {
     text->set_left(label_left_px_);
     text->set_baseline(label_baseline_px + i * label_height_px_);
     text->set_font_face(options_.font_face());
+    if (options_.outline_thickness() > 0) {
+      text->set_outline_thickness(options_.outline_thickness());
+      if (options_.outline_color_size() > 0) {
+        *(text->mutable_outline_color()) =
+            options_.outline_color(i % options_.outline_color_size());
+      } else {
+        text->mutable_outline_color()->set_r(0);
+        text->mutable_outline_color()->set_g(0);
+        text->mutable_outline_color()->set_b(0);
+      }
+    }
   }
   cc->Outputs()
       .Tag(kRenderDataTag)


@@ -30,6 +30,13 @@ message LabelsToRenderDataCalculatorOptions {
   // Thickness for drawing the label(s).
   optional double thickness = 2 [default = 2];

+  // Color of outline around each character, if any. One per label, as with
+  // color attribute.
+  repeated Color outline_color = 12;
+
+  // Thickness of outline around each character.
+  optional double outline_thickness = 11;
+
   // The font height in absolute pixels.
   optional int32 font_height_px = 3 [default = 50];


@@ -185,7 +185,10 @@ void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) {
       << "Updated existing texture which had not been marked for reuse!";
   CHECK(prod_token);
   producer_sync_ = std::move(prod_token);
-  producer_context_ = producer_sync_->GetContext();
+  const auto& synced_context = producer_sync_->GetContext();
+  if (synced_context) {
+    producer_context_ = synced_context;
+  }
 }

 void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const {


@@ -34,6 +34,7 @@ android_library(
 android_library(
     name = "android_framework_no_mff",
     proguard_specs = [":proguard.pgcfg"],
+    visibility = ["//visibility:public"],
     exports = [
         ":android_framework_no_proguard",
     ],


@@ -48,6 +48,8 @@ pybind_extension(
         "//mediapipe/python/pybind:timestamp",
         "//mediapipe/python/pybind:validated_graph_config",
         "//mediapipe/tasks/python/core/pybind:task_runner",
+        "@com_google_absl//absl/strings:str_format",
+        "@stblib//:stb_image",
         # Type registration.
         "//mediapipe/framework:basic_types_registration",
         "//mediapipe/framework/formats:classification_registration",


@@ -15,6 +15,7 @@
 """Tests for mediapipe.python._framework_bindings.image."""

 import gc
+import os
 import random
 import sys

@@ -23,6 +24,7 @@ import cv2
 import numpy as np
 import PIL.Image

+# resources dependency
 from mediapipe.python._framework_bindings import image
 from mediapipe.python._framework_bindings import image_frame

@@ -185,6 +187,5 @@ class ImageTest(absltest.TestCase):
     gc.collect()
     self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)

 if __name__ == '__main__':
   absltest.main()
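The added `import os` and the `# resources dependency` hook indicate the test suite now reads image files from disk, presumably to exercise the new `Image.create_from_file` binding introduced below. A sketch of what such a test could look like, in the style of this file (the method name, testdata path, and expected channel count are illustrative assumptions, not taken from the diff):

# Hypothetical test sketch; relies on the imports already present in this file.
def test_create_image_from_file(self):
  image_path = os.path.join(os.path.dirname(__file__),
                            'solutions/testdata/hands.jpg')  # illustrative path
  loaded_image = image.Image.create_from_file(image_path)
  # An RGB JPEG is expected to decode to a 3-channel image.
  self.assertEqual(loaded_image.channels, 3)
  self.assertGreater(loaded_image.width, 0)
  self.assertGreater(loaded_image.height, 0)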


@@ -45,6 +45,8 @@ pybind_library(
         ":util",
         "//mediapipe/framework:type_map",
         "//mediapipe/framework/formats:image",
+        "@com_google_absl//absl/strings:str_format",
+        "@stblib//:stb_image",
     ],
 )


@@ -16,9 +16,11 @@
 #include <memory>

+#include "absl/strings/str_format.h"
 #include "mediapipe/python/pybind/image_frame_util.h"
 #include "mediapipe/python/pybind/util.h"
 #include "pybind11/stl.h"
+#include "stb_image.h"

 namespace mediapipe {
 namespace python {

@@ -225,6 +227,62 @@ void ImageSubmodule(pybind11::module* module) {
       image.is_aligned(16)
   )doc");

+  image.def_static(
+      "create_from_file",
+      [](const std::string& file_name) {
+        int width;
+        int height;
+        int channels;
+        auto* image_data =
+            stbi_load(file_name.c_str(), &width, &height, &channels,
+                      /*desired_channels=*/0);
+        if (image_data == nullptr) {
+          throw RaisePyError(PyExc_RuntimeError,
+                             absl::StrFormat("Image decoding failed (%s): %s",
+                                             stbi_failure_reason(), file_name)
+                                 .c_str());
+        }
+        ImageFrameSharedPtr image_frame;
+        switch (channels) {
+          case 1:
+            image_frame = std::make_shared<ImageFrame>(
+                ImageFormat::GRAY8, width, height, width, image_data,
+                stbi_image_free);
+            break;
+          case 3:
+            image_frame = std::make_shared<ImageFrame>(
+                ImageFormat::SRGB, width, height, 3 * width, image_data,
+                stbi_image_free);
+            break;
+          case 4:
+            image_frame = std::make_shared<ImageFrame>(
+                ImageFormat::SRGBA, width, height, 4 * width, image_data,
+                stbi_image_free);
+            break;
+          default:
+            throw RaisePyError(
+                PyExc_RuntimeError,
+                absl::StrFormat(
+                    "Expected image with 1 (grayscale), 3 (RGB) or 4 "
+                    "(RGBA) channels, found %d channels.",
+                    channels)
+                    .c_str());
+        }
+        return Image(std::move(image_frame));
+      },
+      R"doc(Creates `Image` object from the image file.
+
+  Args:
+    file_name: Image file name.
+
+  Returns:
+    `Image` object.
+
+  Raises:
+    RuntimeError if the image file can't be decoded.
+  )doc",
+      py::arg("file_name"));
+
   image.def_property_readonly("width", &Image::width)
       .def_property_readonly("height", &Image::height)
       .def_property_readonly("channels", &Image::channels)
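From Python, the new binding loads and decodes an image file in one call. A minimal usage sketch, assuming the top-level `mediapipe.Image` alias re-exports this class and that the file path exists (both are assumptions, not part of this change):

import mediapipe as mp

try:
  # Decodes the file via stb_image and wraps the pixels in an Image object.
  img = mp.Image.create_from_file('/path/to/photo.jpg')  # illustrative path
  print(img.width, img.height, img.channels)
except RuntimeError as err:
  # Raised when the file cannot be decoded, per the docstring above.
  print('decoding failed:', err)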


@@ -33,11 +33,12 @@ cc_library(
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/tasks/cc:common",
-        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
+        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
         "//mediapipe/tasks/cc/audio/utils:audio_tensor_specs",
-        "//mediapipe/tasks/cc/components:classification_postprocessing",
-        "//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
-        "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
+        "//mediapipe/tasks/cc/components/processors:classifier_options",
+        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
         "//mediapipe/tasks/cc/core:model_resources",
         "//mediapipe/tasks/cc/core:model_task_graph",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",

@@ -60,12 +61,13 @@ cc_library(
         ":audio_classifier_graph",
         "//mediapipe/framework/api2:builder",
         "//mediapipe/framework/formats:matrix",
-        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_options_cc_proto",
+        "//mediapipe/tasks/cc/audio/audio_classifier/proto:audio_classifier_graph_options_cc_proto",
         "//mediapipe/tasks/cc/audio/core:audio_task_api_factory",
         "//mediapipe/tasks/cc/audio/core:base_audio_task_api",
         "//mediapipe/tasks/cc/audio/core:running_mode",
-        "//mediapipe/tasks/cc/components:classifier_options",
-        "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/processors:classifier_options",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
         "//mediapipe/tasks/cc/core:base_options",
         "//mediapipe/tasks/cc/core:task_runner",
         "//mediapipe/tasks/cc/core/proto:inference_subgraph_cc_proto",


@@ -22,10 +22,11 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "mediapipe/framework/api2/builder.h"
 #include "mediapipe/framework/formats/matrix.h"
-#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
+#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
 #include "mediapipe/tasks/cc/audio/core/audio_task_api_factory.h"
-#include "mediapipe/tasks/cc/components/classifier_options.h"
-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
+#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
 #include "mediapipe/tasks/cc/core/task_runner.h"
 #include "tensorflow/lite/core/api/op_resolver.h"

@@ -33,8 +34,12 @@ limitations under the License.
 namespace mediapipe {
 namespace tasks {
 namespace audio {
+namespace audio_classifier {
 namespace {

+using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+
 constexpr char kAudioStreamName[] = "audio_in";
 constexpr char kAudioTag[] = "AUDIO";
 constexpr char kClassificationResultStreamName[] = "classification_result_out";

@@ -42,16 +47,13 @@ constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT";
 constexpr char kSampleRateName[] = "sample_rate_in";
 constexpr char kSampleRateTag[] = "SAMPLE_RATE";
 constexpr char kSubgraphTypeName[] =
-    "mediapipe.tasks.audio.AudioClassifierGraph";
+    "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph";
 constexpr int kMicroSecondsPerMilliSecond = 1000;

-using AudioClassifierOptionsProto =
-    audio_classifier::proto::AudioClassifierOptions;
-
 // Creates a MediaPipe graph config that only contains a single subgraph node of
-// "mediapipe.tasks.audio.AudioClassifierGraph".
+// type "AudioClassifierGraph".
 CalculatorGraphConfig CreateGraphConfig(
-    std::unique_ptr<AudioClassifierOptionsProto> options_proto) {
+    std::unique_ptr<proto::AudioClassifierGraphOptions> options_proto) {
   api2::builder::Graph graph;
   auto& subgraph = graph.AddNode(kSubgraphTypeName);
   graph.In(kAudioTag).SetName(kAudioStreamName) >> subgraph.In(kAudioTag);

@@ -59,7 +61,8 @@ CalculatorGraphConfig CreateGraphConfig(
     graph.In(kSampleRateTag).SetName(kSampleRateName) >>
         subgraph.In(kSampleRateTag);
   }
-  subgraph.GetOptions<AudioClassifierOptionsProto>().Swap(options_proto.get());
+  subgraph.GetOptions<proto::AudioClassifierGraphOptions>().Swap(
+      options_proto.get());
   subgraph.Out(kClassificationResultTag)
           .SetName(kClassificationResultStreamName) >>
       graph.Out(kClassificationResultTag);

@@ -67,18 +70,18 @@ CalculatorGraphConfig CreateGraphConfig(
 }

 // Converts the user-facing AudioClassifierOptions struct to the internal
-// AudioClassifierOptions proto.
-std::unique_ptr<AudioClassifierOptionsProto>
+// AudioClassifierGraphOptions proto.
+std::unique_ptr<proto::AudioClassifierGraphOptions>
 ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) {
-  auto options_proto = std::make_unique<AudioClassifierOptionsProto>();
+  auto options_proto = std::make_unique<proto::AudioClassifierGraphOptions>();
   auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
       tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
   options_proto->mutable_base_options()->Swap(base_options_proto.get());
   options_proto->mutable_base_options()->set_use_stream_mode(
       options->running_mode == core::RunningMode::AUDIO_STREAM);
   auto classifier_options_proto =
-      std::make_unique<tasks::components::proto::ClassifierOptions>(
-          components::ConvertClassifierOptionsToProto(
+      std::make_unique<components::processors::proto::ClassifierOptions>(
+          components::processors::ConvertClassifierOptionsToProto(
               &(options->classifier_options)));
   options_proto->mutable_classifier_options()->Swap(
       classifier_options_proto.get());

@@ -119,7 +122,7 @@ absl::StatusOr<std::unique_ptr<AudioClassifier>> AudioClassifier::Create(
     };
   }
   return core::AudioTaskApiFactory::Create<AudioClassifier,
-                                           AudioClassifierOptionsProto>(
+                                           proto::AudioClassifierGraphOptions>(
       CreateGraphConfig(std::move(options_proto)),
       std::move(options->base_options.op_resolver), options->running_mode,
      std::move(packets_callback));

@@ -140,6 +143,7 @@ absl::Status AudioClassifier::ClassifyAsync(Matrix audio_block,
           .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
 }

+}  // namespace audio_classifier
 }  // namespace audio
 }  // namespace tasks
 }  // namespace mediapipe


@@ -23,13 +23,14 @@ limitations under the License.
 #include "mediapipe/framework/formats/matrix.h"
 #include "mediapipe/tasks/cc/audio/core/base_audio_task_api.h"
 #include "mediapipe/tasks/cc/audio/core/running_mode.h"
-#include "mediapipe/tasks/cc/components/classifier_options.h"
-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/processors/classifier_options.h"
 #include "mediapipe/tasks/cc/core/base_options.h"

 namespace mediapipe {
 namespace tasks {
 namespace audio {
+namespace audio_classifier {

 // The options for configuring a mediapipe audio classifier task.
 struct AudioClassifierOptions {

@@ -39,7 +40,7 @@ struct AudioClassifierOptions {
   // Options for configuring the classifier behavior, such as score threshold,
   // number of results, etc.
-  components::ClassifierOptions classifier_options;
+  components::processors::ClassifierOptions classifier_options;

   // The running mode of the audio classifier. Default to the audio clips mode.
   // Audio classifier has two running modes:

@@ -58,8 +59,9 @@ struct AudioClassifierOptions {
   // The user-defined result callback for processing audio stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::AUDIO_STREAM.
-  std::function<void(absl::StatusOr<ClassificationResult>)> result_callback =
-      nullptr;
+  std::function<void(
+      absl::StatusOr<components::containers::proto::ClassificationResult>)>
+      result_callback = nullptr;
 };

 // Performs audio classification on audio clips or audio stream.

@@ -131,8 +133,8 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   // framed audio clip.
   // TODO: Use `sample_rate` in AudioClassifierOptions by default
   // and makes `audio_sample_rate` optional.
-  absl::StatusOr<ClassificationResult> Classify(mediapipe::Matrix audio_clip,
-                                                double audio_sample_rate);
+  absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
+      mediapipe::Matrix audio_clip, double audio_sample_rate);

   // Sends audio data (a block in a continuous audio stream) to perform audio
   // classification. Only use this method when the AudioClassifier is created

@@ -162,6 +164,7 @@ class AudioClassifier : tasks::audio::core::BaseAudioTaskApi {
   absl::Status Close() { return runner_->Close(); }
 };

+}  // namespace audio_classifier
 }  // namespace audio
 }  // namespace tasks
 }  // namespace mediapipe


@@ -28,12 +28,12 @@ limitations under the License.
 #include "mediapipe/framework/calculator.pb.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/matrix.h"
-#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_options.pb.h"
+#include "mediapipe/tasks/cc/audio/audio_classifier/proto/audio_classifier_graph_options.pb.h"
 #include "mediapipe/tasks/cc/audio/utils/audio_tensor_specs.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
-#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
+#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/core/model_task_graph.h"
 #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"

@@ -44,6 +44,7 @@ limitations under the License.
 namespace mediapipe {
 namespace tasks {
 namespace audio {
+namespace audio_classifier {

 namespace {

@@ -52,6 +53,7 @@ using ::mediapipe::api2::Output;
 using ::mediapipe::api2::builder::GenericNode;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
+using ::mediapipe::tasks::components::containers::proto::ClassificationResult;

 constexpr char kAtPrestreamTag[] = "AT_PRESTREAM";
 constexpr char kAudioTag[] = "AUDIO";

@@ -60,10 +62,9 @@ constexpr char kPacketTag[] = "PACKET";
 constexpr char kSampleRateTag[] = "SAMPLE_RATE";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kTimestampsTag[] = "TIMESTAMPS";

-using AudioClassifierOptionsProto =
-    audio_classifier::proto::AudioClassifierOptions;
-
-absl::Status SanityCheckOptions(const AudioClassifierOptionsProto& options) {
+absl::Status SanityCheckOptions(
+    const proto::AudioClassifierGraphOptions& options) {
   if (options.base_options().use_stream_mode() &&
       !options.has_default_input_audio_sample_rate()) {
     return CreateStatusWithPayload(absl::StatusCode::kInvalidArgument,

@@ -111,7 +112,7 @@ void ConfigureAudioToTensorCalculator(
 }  // namespace

-// A "mediapipe.tasks.audio.AudioClassifierGraph" performs audio classification.
+// An "AudioClassifierGraph" performs audio classification.
 // - Accepts CPU audio buffer and outputs classification results on CPU.
 //
 // Inputs:

@@ -129,12 +130,12 @@ void ConfigureAudioToTensorCalculator(
 //
 // Example:
 // node {
-//   calculator: "mediapipe.tasks.audio.AudioClassifierGraph"
+//   calculator: "mediapipe.tasks.audio.audio_classifier.AudioClassifierGraph"
 //   input_stream: "AUDIO:audio_in"
 //   input_stream: "SAMPLE_RATE:sample_rate_in"
 //   output_stream: "CLASSIFICATION_RESULT:classification_result_out"
 //   options {
-//     [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
+//     [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierGraphOptions.ext]
 //     {
 //       base_options {
 //         model_asset {

@@ -152,16 +153,18 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
  public:
   absl::StatusOr<CalculatorGraphConfig> GetConfig(
       SubgraphContext* sc) override {
-    ASSIGN_OR_RETURN(const auto* model_resources,
-                     CreateModelResources<AudioClassifierOptionsProto>(sc));
+    ASSIGN_OR_RETURN(
+        const auto* model_resources,
+        CreateModelResources<proto::AudioClassifierGraphOptions>(sc));
     Graph graph;
-    const bool use_stream_mode = sc->Options<AudioClassifierOptionsProto>()
-                                     .base_options()
-                                     .use_stream_mode();
+    const bool use_stream_mode =
+        sc->Options<proto::AudioClassifierGraphOptions>()
+            .base_options()
+            .use_stream_mode();
     ASSIGN_OR_RETURN(
         auto classification_result_out,
         BuildAudioClassificationTask(
-            sc->Options<AudioClassifierOptionsProto>(), *model_resources,
+            sc->Options<proto::AudioClassifierGraphOptions>(), *model_resources,
             graph[Input<Matrix>(kAudioTag)],
             use_stream_mode
                 ? absl::nullopt

@@ -178,14 +181,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
   // buffer (mediapipe::Matrix) and the corresponding sample rate (double) as
   // the inputs and returns one classification result per input audio buffer.
   //
-  // task_options: the mediapipe tasks AudioClassifierOptions proto.
+  // task_options: the mediapipe tasks AudioClassifierGraphOptions proto.
   // model_resources: the ModelSources object initialized from an audio
   //   classifier model file with model metadata.
   // audio_in: (mediapipe::Matrix) stream to run audio classification on.
   // sample_rate_in: (double) optional stream of the input audio sample rate.
   // graph: the mediapipe builder::Graph instance to be updated.
   absl::StatusOr<Source<ClassificationResult>> BuildAudioClassificationTask(
-      const AudioClassifierOptionsProto& task_options,
+      const proto::AudioClassifierGraphOptions& task_options,
       const core::ModelResources& model_resources, Source<Matrix> audio_in,
       absl::optional<Source<double>> sample_rate_in, Graph& graph) {
     MP_RETURN_IF_ERROR(SanityCheckOptions(task_options));

@@ -236,11 +239,14 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
     // Adds postprocessing calculators and connects them to the graph output.
     auto& postprocessing = graph.AddNode(
-        "mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
-    MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
-        model_resources, task_options.classifier_options(),
-        &postprocessing.GetOptions<
-            tasks::components::ClassificationPostprocessingOptions>()));
+        "mediapipe.tasks.components.processors."
+        "ClassificationPostprocessingGraph");
+    MP_RETURN_IF_ERROR(
+        components::processors::ConfigureClassificationPostprocessingGraph(
+            model_resources, task_options.classifier_options(),
+            &postprocessing
+                .GetOptions<components::processors::proto::
+                                ClassificationPostprocessingGraphOptions>()));
     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);

     // Time aggregation is only needed for performing audio classification on

@@ -257,8 +263,10 @@ class AudioClassifierGraph : public core::ModelTaskGraph {
   }
 };

-REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::audio::AudioClassifierGraph);
+REGISTER_MEDIAPIPE_GRAPH(
+    ::mediapipe::tasks::audio::audio_classifier::AudioClassifierGraph);

+}  // namespace audio_classifier
 }  // namespace audio
 }  // namespace tasks
 }  // namespace mediapipe


@@ -37,17 +37,19 @@ limitations under the License.
 #include "mediapipe/tasks/cc/audio/core/running_mode.h"
 #include "mediapipe/tasks/cc/audio/utils/test_utils.h"
 #include "mediapipe/tasks/cc/common.h"
-#include "mediapipe/tasks/cc/components/containers/category.pb.h"
-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
 #include "tensorflow/lite/core/shims/cc/shims_test_util.h"

 namespace mediapipe {
 namespace tasks {
 namespace audio {
+namespace audio_classifier {
 namespace {

 using ::absl::StatusOr;
 using ::mediapipe::file::JoinPath;
+using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 using ::testing::HasSubstr;
 using ::testing::Optional;

@@ -557,6 +559,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) {
 }

 }  // namespace
+}  // namespace audio_classifier
 }  // namespace audio
 }  // namespace tasks
 }  // namespace mediapipe


@@ -19,12 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 licenses(["notice"])

 mediapipe_proto_library(
-    name = "audio_classifier_options_proto",
-    srcs = ["audio_classifier_options.proto"],
+    name = "audio_classifier_graph_options_proto",
+    srcs = ["audio_classifier_graph_options.proto"],
     deps = [
         "//mediapipe/framework:calculator_options_proto",
         "//mediapipe/framework:calculator_proto",
-        "//mediapipe/tasks/cc/components/proto:classifier_options_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_proto",
     ],
 )


@@ -18,12 +18,12 @@ syntax = "proto2";
 package mediapipe.tasks.audio.audio_classifier.proto;

 import "mediapipe/framework/calculator.proto";
-import "mediapipe/tasks/cc/components/proto/classifier_options.proto";
+import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
 import "mediapipe/tasks/cc/core/proto/base_options.proto";

-message AudioClassifierOptions {
+message AudioClassifierGraphOptions {
   extend mediapipe.CalculatorOptions {
-    optional AudioClassifierOptions ext = 451755788;
+    optional AudioClassifierGraphOptions ext = 451755788;
   }
   // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
   // model file with metadata, accelerator options, etc.

@@ -31,7 +31,7 @@ message AudioClassifierOptions {
   // Options for configuring the classifier behavior, such as score threshold,
   // number of results, etc.
-  optional components.proto.ClassifierOptions classifier_options = 2;
+  optional components.processors.proto.ClassifierOptions classifier_options = 2;

   // The default sample rate of the input audio. Must be set when the
   // AudioClassifier is configured to process audio stream data.


@@ -58,65 +58,6 @@ cc_library(

 # TODO: Enable this test

-cc_library(
-    name = "classifier_options",
-    srcs = ["classifier_options.cc"],
-    hdrs = ["classifier_options.h"],
-    deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"],
-)
-
-mediapipe_proto_library(
-    name = "classification_postprocessing_options_proto",
-    srcs = ["classification_postprocessing_options.proto"],
-    deps = [
-        "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
-        "//mediapipe/framework:calculator_options_proto",
-        "//mediapipe/framework:calculator_proto",
-        "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
-        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
-    ],
-)
-
-cc_library(
-    name = "classification_postprocessing",
-    srcs = ["classification_postprocessing.cc"],
-    hdrs = ["classification_postprocessing.h"],
-    deps = [
-        ":classification_postprocessing_options_cc_proto",
-        "//mediapipe/calculators/core:split_vector_calculator",
-        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
-        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
-        "//mediapipe/calculators/tensor:tensors_to_classification_calculator",
-        "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
-        "//mediapipe/framework:calculator_framework",
-        "//mediapipe/framework/api2:builder",
-        "//mediapipe/framework/api2:packet",
-        "//mediapipe/framework/api2:port",
-        "//mediapipe/framework/formats:tensor",
-        "//mediapipe/tasks/cc:common",
-        "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
-        "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
-        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
-        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
-        "//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
-        "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
-        "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto",
-        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
-        "//mediapipe/tasks/cc/core:model_resources",
-        "//mediapipe/tasks/cc/metadata:metadata_extractor",
-        "//mediapipe/tasks/metadata:metadata_schema_cc",
-        "//mediapipe/util:label_map_cc_proto",
-        "//mediapipe/util:label_map_util",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/status",
-        "@com_google_absl//absl/status:statusor",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/strings:str_format",
-        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
-    ],
-    alwayslink = 1,
-)
-
 cc_library(
     name = "embedder_options",
     srcs = ["embedder_options.cc"],


@@ -37,8 +37,8 @@ cc_library(
         "//mediapipe/framework/api2:packet",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/formats:classification_cc_proto",
-        "//mediapipe/tasks/cc/components/containers:category_cc_proto",
-        "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:category_cc_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
         "@com_google_absl//absl/status",
     ],
     alwayslink = 1,

@@ -128,7 +128,7 @@ cc_library(
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
-        "//mediapipe/tasks/cc/components/containers:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
     ],
     alwayslink = 1,
 )


@@ -25,15 +25,15 @@
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/formats/classification.pb.h"
 #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
-#include "mediapipe/tasks/cc/components/containers/category.pb.h"
-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

 namespace mediapipe {
 namespace api2 {

 using ::mediapipe::tasks::ClassificationAggregationCalculatorOptions;
-using ::mediapipe::tasks::ClassificationResult;
-using ::mediapipe::tasks::Classifications;
+using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
+using ::mediapipe::tasks::components::containers::proto::Classifications;

 // Aggregates ClassificationLists into a single ClassificationResult that has
 // 3 dimensions: (classification head, classification timestamp, classification


@@ -17,12 +17,13 @@ limitations under the License.
 #include <vector>

-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"

 // Specialized EndLoopCalculator for Tasks specific types.
 namespace mediapipe::tasks {

-typedef EndLoopCalculator<std::vector<ClassificationResult>>
+typedef EndLoopCalculator<
+    std::vector<components::containers::proto::ClassificationResult>>
     EndLoopClassificationResultCalculator;
 REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator);


@@ -18,6 +18,24 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
 licenses(["notice"])

+mediapipe_proto_library(
+    name = "category_proto",
+    srcs = ["category.proto"],
+)
+
+mediapipe_proto_library(
+    name = "classifications_proto",
+    srcs = ["classifications.proto"],
+    deps = [
+        ":category_proto",
+    ],
+)
+
+mediapipe_proto_library(
+    name = "embeddings_proto",
+    srcs = ["embeddings.proto"],
+)
+
 mediapipe_proto_library(
     name = "landmarks_detection_result_proto",
     srcs = [

@@ -29,8 +47,3 @@ mediapipe_proto_library(
         "//mediapipe/framework/formats:rect_proto",
     ],
 )
-
-mediapipe_proto_library(
-    name = "embeddings_proto",
-    srcs = ["embeddings.proto"],
-)


@@ -15,7 +15,7 @@ limitations under the License.
 syntax = "proto2";

-package mediapipe.tasks;
+package mediapipe.tasks.components.containers.proto;

 // A single classification result.
 message Category {


@@ -15,9 +15,9 @@ limitations under the License.
 syntax = "proto2";

-package mediapipe.tasks;
+package mediapipe.tasks.components.containers.proto;

-import "mediapipe/tasks/cc/components/containers/category.proto";
+import "mediapipe/tasks/cc/components/containers/proto/category.proto";

 // List of predicted categories with an optional timestamp.
 message ClassificationEntry {


@@ -0,0 +1,64 @@
+# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//mediapipe/tasks:internal"])
+
+licenses(["notice"])
+
+cc_library(
+    name = "classifier_options",
+    srcs = ["classifier_options.cc"],
+    hdrs = ["classifier_options.h"],
+    deps = ["//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto"],
+)
+
+cc_library(
+    name = "classification_postprocessing_graph",
+    srcs = ["classification_postprocessing_graph.cc"],
+    hdrs = ["classification_postprocessing_graph.h"],
+    deps = [
+        "//mediapipe/calculators/core:split_vector_calculator",
+        "//mediapipe/calculators/core:split_vector_calculator_cc_proto",
+        "//mediapipe/calculators/tensor:tensors_dequantization_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_classification_calculator",
+        "//mediapipe/calculators/tensor:tensors_to_classification_calculator_cc_proto",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:builder",
+        "//mediapipe/framework/api2:packet",
+        "//mediapipe/framework/api2:port",
+        "//mediapipe/framework/formats:tensor",
+        "//mediapipe/tasks/cc:common",
+        "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator",
+        "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto",
+        "//mediapipe/tasks/cc/components/calculators:score_calibration_utils",
+        "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
+        "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
+        "//mediapipe/tasks/cc/components/utils:source_or_node_output",
+        "//mediapipe/tasks/cc/core:model_resources",
+        "//mediapipe/tasks/cc/metadata:metadata_extractor",
+        "//mediapipe/tasks/metadata:metadata_schema_cc",
+        "//mediapipe/util:label_map_cc_proto",
+        "//mediapipe/util:label_map_util",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/strings:str_format",
+        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
+    ],
+    alwayslink = 1,
+)


@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "mediapipe/tasks/cc/components/classification_postprocessing.h"
+#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"

 #include <stdint.h>

@@ -37,9 +37,9 @@ limitations under the License.
 #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
 #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
 #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h"
-#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
-#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
-#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h"
+#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
+#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
 #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h"
 #include "mediapipe/tasks/cc/core/model_resources.h"
 #include "mediapipe/tasks/cc/metadata/metadata_extractor.h"

@@ -51,6 +51,7 @@ limitations under the License.
 namespace mediapipe {
 namespace tasks {
 namespace components {
+namespace processors {

 namespace {

@@ -61,7 +62,7 @@ using ::mediapipe::api2::Timestamp;
 using ::mediapipe::api2::builder::GenericNode;
 using ::mediapipe::api2::builder::Graph;
 using ::mediapipe::api2::builder::Source;
-using ::mediapipe::tasks::components::proto::ClassifierOptions;
+using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
 using ::mediapipe::tasks::core::ModelResources;
 using ::mediapipe::tasks::metadata::ModelMetadataExtractor;
 using ::tflite::ProcessUnit;

@@ -79,7 +80,8 @@ constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kTimestampsTag[] = "TIMESTAMPS";

 // Performs sanity checks on provided ClassifierOptions.
-absl::Status SanityCheckClassifierOptions(const ClassifierOptions& options) {
+absl::Status SanityCheckClassifierOptions(
+    const proto::ClassifierOptions& options) {
   if (options.max_results() == 0) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInvalidArgument,

@@ -203,7 +205,7 @@ absl::StatusOr<float> GetScoreThreshold(
 // Gets the category allowlist or denylist (if any) as a set of indices.
 absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
-    const ClassifierOptions& options, const LabelItems& label_items) {
+    const proto::ClassifierOptions& options, const LabelItems& label_items) {
   absl::flat_hash_set<int> category_indices;
   // Exit early if no denylist/allowlist.
   if (options.category_denylist_size() == 0 &&

@@ -239,7 +241,7 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny(
 absl::Status ConfigureScoreCalibrationIfAny(
     const ModelMetadataExtractor& metadata_extractor, int tensor_index,
-    ClassificationPostprocessingOptions* options) {
+    proto::ClassificationPostprocessingGraphOptions* options) {
   const auto* tensor_metadata =
       metadata_extractor.GetOutputTensorMetadata(tensor_index);
   if (tensor_metadata == nullptr) {

@@ -283,7 +285,7 @@ absl::Status ConfigureScoreCalibrationIfAny(
 // Fills in the TensorsToClassificationCalculatorOptions based on the
 // classifier options and the (optional) output tensor metadata.
 absl::Status ConfigureTensorsToClassificationCalculator(
-    const ClassifierOptions& options,
+    const proto::ClassifierOptions& options,
     const ModelMetadataExtractor& metadata_extractor, int tensor_index,
     TensorsToClassificationCalculatorOptions* calculator_options) {
   const auto* tensor_metadata =

@@ -345,10 +347,10 @@ void ConfigureClassificationAggregationCalculator(
 }  // namespace

-absl::Status ConfigureClassificationPostprocessing(
+absl::Status ConfigureClassificationPostprocessingGraph(
     const ModelResources& model_resources,
-    const ClassifierOptions& classifier_options,
-    ClassificationPostprocessingOptions* options) {
+    const proto::ClassifierOptions& classifier_options,
+    proto::ClassificationPostprocessingGraphOptions* options) {
   MP_RETURN_IF_ERROR(SanityCheckClassifierOptions(classifier_options));
   ASSIGN_OR_RETURN(const auto heads_properties,
                    GetClassificationHeadsProperties(model_resources));

@@ -366,8 +368,8 @@ absl::Status ConfigureClassificationPostprocessing(
   return absl::OkStatus();
 }

-// A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
-// raw tensors into ClassificationResult objects.
+// A "ClassificationPostprocessingGraph" converts raw tensors into
+// ClassificationResult objects.
 // - Accepts CPU input tensors.
 //
 // Inputs:

@@ -381,10 +383,10 @@ absl::Status ConfigureClassificationPostprocessing(
 //   CLASSIFICATION_RESULT - ClassificationResult
 //     The output aggregated classification results.
 //
-// The recommended way of using this subgraph is through the GraphBuilder API
-// using the 'ConfigureClassificationPostprocessing()' function. See header file
-// for more details.
-class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
+// The recommended way of using this graph is through the GraphBuilder API
+// using the 'ConfigureClassificationPostprocessingGraph()' function. See header
+// file for more details.
+class ClassificationPostprocessingGraph : public mediapipe::Subgraph {
  public:
   absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
       mediapipe::SubgraphContext* sc) override {

@@ -392,7 +394,7 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
     ASSIGN_OR_RETURN(
         auto classification_result_out,
         BuildClassificationPostprocessing(
-            sc->Options<ClassificationPostprocessingOptions>(),
+            sc->Options<proto::ClassificationPostprocessingGraphOptions>(),
             graph[Input<std::vector<Tensor>>(kTensorsTag)],
             graph[Input<std::vector<Timestamp>>(kTimestampsTag)], graph));
     classification_result_out >>

@@ -401,19 +403,19 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
   }

  private:
-  // Adds an on-device classification postprocessing subgraph into the provided
-  // builder::Graph instance. The classification postprocessing subgraph takes
+  // Adds an on-device classification postprocessing graph into the provided
+  // builder::Graph instance. The classification postprocessing graph takes
   // tensors (std::vector<mediapipe::Tensor>) as input and returns one output
   // stream containing the output classification results (ClassificationResult).
   //
-  // options: the on-device ClassificationPostprocessingOptions.
+  // options: the on-device ClassificationPostprocessingGraphOptions.
   // tensors_in: (std::vector<mediapipe::Tensor>>) tensors to postprocess.
   // timestamps_in: (std::vector<mediapipe::Timestamp>) optional collection of
   //   timestamps that a single ClassificationResult should aggregate.
   // graph: the mediapipe builder::Graph instance to be updated.
   absl::StatusOr<Source<ClassificationResult>>
   BuildClassificationPostprocessing(
-      const ClassificationPostprocessingOptions& options,
+      const proto::ClassificationPostprocessingGraphOptions& options,
       Source<std::vector<Tensor>> tensors_in,
       Source<std::vector<Timestamp>> timestamps_in, Graph& graph) {
     const int num_heads = options.tensors_to_classifications_options_size();

@@ -504,9 +506,11 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph {
                                  kClassificationResultTag)];
   }
 };

-REGISTER_MEDIAPIPE_GRAPH(
-    ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph);
+REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::components::processors::
+                             ClassificationPostprocessingGraph);  // NOLINT

+}  // namespace processors
 }  // namespace components
 }  // namespace tasks
 }  // namespace mediapipe


@ -13,32 +13,33 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Configures a ClassificationPostprocessing subgraph using the provided model // Configures a ClassificationPostprocessingGraph using the provided model
// resources and ClassifierOptions. // resources and ClassifierOptions.
// - Accepts CPU input tensors. // - Accepts CPU input tensors.
// //
// Example usage: // Example usage:
// //
// auto& postprocessing = // auto& postprocessing =
// graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); // graph.AddNode("mediapipe.tasks.components.processors.ClassificationPostprocessingGraph");
// MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( // MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
// model_resources, // model_resources,
// classifier_options, // classifier_options,
// &postprocessing.GetOptions<ClassificationPostprocessingOptions>())); // &postprocessing.GetOptions<ClassificationPostprocessingGraphOptions>()));

// //
// The resulting ClassificationPostprocessing subgraph has the following I/O: // The resulting ClassificationPostprocessingGraph has the following I/O:
// Inputs: // Inputs:
// TENSORS - std::vector<Tensor> // TENSORS - std::vector<Tensor>
// The output tensors of an InferenceCalculator. // The output tensors of an InferenceCalculator.
@ -49,13 +50,14 @@ namespace components {
// Outputs: // Outputs:
// CLASSIFICATION_RESULT - ClassificationResult // CLASSIFICATION_RESULT - ClassificationResult
// The output aggregated classification results. // The output aggregated classification results.
absl::Status ConfigureClassificationPostprocessing( absl::Status ConfigureClassificationPostprocessingGraph(
const tasks::core::ModelResources& model_resources, const tasks::core::ModelResources& model_resources,
const tasks::components::proto::ClassifierOptions& classifier_options, const proto::ClassifierOptions& classifier_options,
ClassificationPostprocessingOptions* options); proto::ClassificationPostprocessingGraphOptions* options);
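//
// A minimal end-to-end sketch, assuming `model_resources` and
// `classifier_options` already exist; stream wiring mirrors the I/O tags
// listed above and the names are illustrative:
//
//   Graph graph;
//   auto& postprocessing = graph.AddNode(
//       "mediapipe.tasks.components.processors."
//       "ClassificationPostprocessingGraph");
//   MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
//       model_resources, classifier_options,
//       &postprocessing
//            .GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
//   graph.In("TENSORS") >> postprocessing.In("TENSORS");
//   postprocessing.Out("CLASSIFICATION_RESULT") >>
//       graph.Out("CLASSIFICATION_RESULT");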
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFICATION_POSTPROCESSING_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFICATION_POSTPROCESSING_GRAPH_H_
View File
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include <map> #include <map>
#include <memory> #include <memory>
@ -42,9 +42,9 @@ limitations under the License.
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h"
#include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
@ -53,6 +53,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
namespace { namespace {
using ::mediapipe::api2::Input; using ::mediapipe::api2::Input;
@ -60,7 +61,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::proto::ClassifierOptions; using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::ModelResources; using ::mediapipe::tasks::core::ModelResources;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::proto::Approximately; using ::testing::proto::Approximately;
@ -101,12 +102,12 @@ TEST_F(ConfigureTest, FailsWithInvalidMaxResults) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.set_max_results(0); options_in.set_max_results(0);
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources, auto status = ConfigureClassificationPostprocessingGraph(
options_in, &options_out); *model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option")); EXPECT_THAT(status.message(), HasSubstr("Invalid `max_results` option"));
@ -116,13 +117,13 @@ TEST_F(ConfigureTest, FailsWithBothAllowlistAndDenylist) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_allowlist("foo"); options_in.add_category_allowlist("foo");
options_in.add_category_denylist("bar"); options_in.add_category_denylist("bar");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources, auto status = ConfigureClassificationPostprocessingGraph(
options_in, &options_out); *model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options")); EXPECT_THAT(status.message(), HasSubstr("mutually exclusive options"));
@ -132,12 +133,12 @@ TEST_F(ConfigureTest, FailsWithAllowlistAndNoMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_allowlist("foo"); options_in.add_category_allowlist("foo");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
auto status = ConfigureClassificationPostprocessing(*model_resources, auto status = ConfigureClassificationPostprocessingGraph(
options_in, &options_out); *model_resources, options_in, &options_out);
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
EXPECT_THAT( EXPECT_THAT(
@ -149,11 +150,11 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: [] R"pb(score_calibration_options: []
@ -171,12 +172,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.set_max_results(3); options_in.set_max_results(3);
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: [] R"pb(score_calibration_options: []
@ -194,12 +195,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithoutMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.set_score_threshold(0.5); options_in.set_score_threshold(0.5);
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
EXPECT_THAT(options_out, Approximately(EqualsProto( EXPECT_THAT(options_out, Approximately(EqualsProto(
R"pb(score_calibration_options: [] R"pb(score_calibration_options: []
@ -217,11 +218,11 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Check label map size and the first two elements. // Check label map size and the first two elements.
EXPECT_EQ( EXPECT_EQ(
@ -254,12 +255,12 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_allowlist("tench"); options_in.add_category_allowlist("tench");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Clear label map and compare the rest of the options. // Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
@ -283,12 +284,12 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata)); CreateModelResourcesForModel(kQuantizedImageClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
options_in.add_category_denylist("background"); options_in.add_category_denylist("background");
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Clear label map and compare the rest of the options. // Clear label map and compare the rest of the options.
options_out.mutable_tensors_to_classifications_options(0) options_out.mutable_tensors_to_classifications_options(0)
@ -313,11 +314,11 @@ TEST_F(ConfigureTest, SucceedsWithScoreCalibration) {
auto model_resources, auto model_resources,
CreateModelResourcesForModel( CreateModelResourcesForModel(
kQuantizedImageClassifierWithDummyScoreCalibration)); kQuantizedImageClassifierWithDummyScoreCalibration));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Check label map size and the first two elements. // Check label map size and the first two elements.
EXPECT_EQ( EXPECT_EQ(
@ -362,11 +363,11 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto model_resources, auto model_resources,
CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata)); CreateModelResourcesForModel(kFloatTwoHeadsAudioClassifierWithMetadata));
ClassifierOptions options_in; proto::ClassifierOptions options_in;
ClassificationPostprocessingOptions options_out; proto::ClassificationPostprocessingGraphOptions options_out;
MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, MP_ASSERT_OK(ConfigureClassificationPostprocessingGraph(
options_in, &options_out)); *model_resources, options_in, &options_out));
// Check label map sizes and the first two elements. // Check label map sizes and the first two elements.
EXPECT_EQ( EXPECT_EQ(
options_out.tensors_to_classifications_options(0).label_items_size(), options_out.tensors_to_classifications_options(0).label_items_size(),
@ -414,17 +415,19 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) {
class PostprocessingTest : public tflite_shims::testing::Test { class PostprocessingTest : public tflite_shims::testing::Test {
protected: protected:
absl::StatusOr<OutputStreamPoller> BuildGraph( absl::StatusOr<OutputStreamPoller> BuildGraph(
absl::string_view model_name, const ClassifierOptions& options, absl::string_view model_name, const proto::ClassifierOptions& options,
bool connect_timestamps = false) { bool connect_timestamps = false) {
ASSIGN_OR_RETURN(auto model_resources, ASSIGN_OR_RETURN(auto model_resources,
CreateModelResourcesForModel(model_name)); CreateModelResourcesForModel(model_name));
Graph graph; Graph graph;
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessingGraph(
*model_resources, options, *model_resources, options,
&postprocessing.GetOptions<ClassificationPostprocessingOptions>())); &postprocessing
.GetOptions<proto::ClassificationPostprocessingGraphOptions>()));
graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >> graph[Input<std::vector<Tensor>>(kTensorsTag)].SetName(kTensorsName) >>
postprocessing.In(kTensorsTag); postprocessing.In(kTensorsTag);
if (connect_timestamps) { if (connect_timestamps) {
@ -495,7 +498,7 @@ class PostprocessingTest : public tflite_shims::testing::Test {
TEST_F(PostprocessingTest, SucceedsWithoutMetadata) { TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(3); options.set_max_results(3);
options.set_score_threshold(0.5); options.set_score_threshold(0.5);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
@ -524,7 +527,7 @@ TEST_F(PostprocessingTest, SucceedsWithoutMetadata) {
TEST_F(PostprocessingTest, SucceedsWithMetadata) { TEST_F(PostprocessingTest, SucceedsWithMetadata) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(3); options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options)); auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options));
@ -567,7 +570,7 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) {
TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(3); options.set_max_results(3);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, auto poller,
@ -613,7 +616,7 @@ TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) {
TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(2); options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, auto poller,
@ -673,7 +676,7 @@ TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) {
TEST_F(PostprocessingTest, SucceedsWithTimestamps) { TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
// Build graph. // Build graph.
ClassifierOptions options; proto::ClassifierOptions options;
options.set_max_results(2); options.set_max_results(2);
MP_ASSERT_OK_AND_ASSIGN( MP_ASSERT_OK_AND_ASSIGN(
auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options, auto poller, BuildGraph(kQuantizedImageClassifierWithMetadata, options,
@ -729,6 +732,7 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) {
} }
} // namespace } // namespace
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
View File
@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( proto::ClassifierOptions ConvertClassifierOptionsToProto(
ClassifierOptions* options) { ClassifierOptions* options) {
tasks::components::proto::ClassifierOptions options_proto; proto::ClassifierOptions options_proto;
options_proto.set_display_names_locale(options->display_names_locale); options_proto.set_display_names_locale(options->display_names_locale);
options_proto.set_max_results(options->max_results); options_proto.set_max_results(options->max_results);
options_proto.set_score_threshold(options->score_threshold); options_proto.set_score_threshold(options->score_threshold);
@ -36,6 +37,7 @@ tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto(
return options_proto; return options_proto;
} }
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
View File
@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #define MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace components { namespace components {
namespace processors {
// Classifier options for MediaPipe C++ classification Tasks. // Classifier options for MediaPipe C++ classification Tasks.
struct ClassifierOptions { struct ClassifierOptions {
@ -49,11 +50,12 @@ struct ClassifierOptions {
}; };
// Converts a ClassifierOptions to a ClassifierOptionsProto. // Converts a ClassifierOptions to a ClassifierOptionsProto.
tasks::components::proto::ClassifierOptions ConvertClassifierOptionsToProto( proto::ClassifierOptions ConvertClassifierOptionsToProto(
ClassifierOptions* classifier_options); ClassifierOptions* classifier_options);
} // namespace processors
} // namespace components } // namespace components
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_CLASSIFIER_OPTIONS_H_ #endif // MEDIAPIPE_TASKS_CC_COMPONENTS_PROCESSORS_CLASSIFIER_OPTIONS_H_
View File
@ -19,14 +19,18 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library( mediapipe_proto_library(
name = "category_proto", name = "classifier_options_proto",
srcs = ["category.proto"], srcs = ["classifier_options.proto"],
) )
mediapipe_proto_library( mediapipe_proto_library(
name = "classifications_proto", name = "classification_postprocessing_graph_options_proto",
srcs = ["classifications.proto"], srcs = ["classification_postprocessing_graph_options.proto"],
deps = [ deps = [
":category_proto", "//mediapipe/calculators/tensor:tensors_to_classification_calculator_proto",
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto",
"//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto",
], ],
) )
View File
@ -15,16 +15,16 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components; package mediapipe.tasks.components.processors.proto;
import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto";
import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto";
message ClassificationPostprocessingOptions { message ClassificationPostprocessingGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ClassificationPostprocessingOptions ext = 460416950; optional ClassificationPostprocessingGraphOptions ext = 460416950;
} }
// Optional mapping between output tensor index and corresponding score // Optional mapping between output tensor index and corresponding score
View File
@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto2"; syntax = "proto2";
package mediapipe.tasks.components.proto; package mediapipe.tasks.components.processors.proto;
// Shared options used by all classification tasks. // Shared options used by all classification tasks.
message ClassifierOptions { message ClassifierOptions {
View File
@ -23,11 +23,6 @@ mediapipe_proto_library(
srcs = ["segmenter_options.proto"], srcs = ["segmenter_options.proto"],
) )
mediapipe_proto_library(
name = "classifier_options_proto",
srcs = ["classifier_options.proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "embedder_options_proto", name = "embedder_options_proto",
srcs = ["embedder_options.proto"], srcs = ["embedder_options.proto"],
View File
@ -42,3 +42,16 @@ cc_test(
"//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto",
], ],
) )
cc_library(
name = "gate",
hdrs = ["gate.h"],
deps = [
"//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:gate_calculator_cc_proto",
"//mediapipe/framework/api2:builder",
"//mediapipe/framework/api2:port",
],
)
# TODO: Enable this test
View File
@ -0,0 +1,160 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_
#define MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_
#include <utility>
#include "mediapipe/calculators/core/gate_calculator.pb.h"
#include "mediapipe/framework/api2/builder.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace utils {
// Utility class that simplifies allowing (gating) multiple streams.
class AllowGate {
public:
AllowGate(api2::builder::Source<bool> allow, api2::builder::Graph& graph)
: node_(AddSourceGate(allow, graph)) {}
AllowGate(api2::builder::SideSource<bool> allow, api2::builder::Graph& graph)
: node_(AddSideSourceGate(allow, graph)) {}
// Move-only
AllowGate(AllowGate&& allow_gate) = default;
AllowGate& operator=(AllowGate&& allow_gate) = default;
template <typename T>
api2::builder::Source<T> Allow(api2::builder::Source<T> source) {
source >> node_.In(index_);
return node_.Out(index_++).Cast<T>();
}
private:
template <typename T>
static api2::builder::GenericNode& AddSourceGate(
T allow, api2::builder::Graph& graph) {
auto& gate_node = graph.AddNode("GateCalculator");
allow >> gate_node.In("ALLOW");
return gate_node;
}
template <typename T>
static api2::builder::GenericNode& AddSideSourceGate(
T allow, api2::builder::Graph& graph) {
auto& gate_node = graph.AddNode("GateCalculator");
allow >> gate_node.SideIn("ALLOW");
return gate_node;
}
api2::builder::GenericNode& node_;
int index_ = 0;
};
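// Example usage of AllowGate (a sketch; `image_stream` and `landmarks_stream`
// are illustrative Source<> streams obtained elsewhere from the same graph):
//
//   auto allow = graph.In("ALLOW").Cast<bool>();
//   AllowGate gate(allow, graph);
//   auto gated_image = gate.Allow(image_stream);
//   auto gated_landmarks = gate.Allow(landmarks_stream);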
// Utility class that simplifies disallowing (gating) multiple streams.
class DisallowGate {
public:
DisallowGate(api2::builder::Source<bool> disallow,
api2::builder::Graph& graph)
: node_(AddSourceGate(disallow, graph)) {}
DisallowGate(api2::builder::SideSource<bool> disallow,
api2::builder::Graph& graph)
: node_(AddSideSourceGate(disallow, graph)) {}
// Move-only
DisallowGate(DisallowGate&& disallow_gate) = default;
DisallowGate& operator=(DisallowGate&& disallow_gate) = default;
template <typename T>
api2::builder::Source<T> Disallow(api2::builder::Source<T> source) {
source >> node_.In(index_);
return node_.Out(index_++).Cast<T>();
}
private:
template <typename T>
static api2::builder::GenericNode& AddSourceGate(
T disallow, api2::builder::Graph& graph) {
auto& gate_node = graph.AddNode("GateCalculator");
auto& gate_node_opts =
gate_node.GetOptions<mediapipe::GateCalculatorOptions>();
// Supposedly the most popular configuration for MediaPipe Tasks team graphs,
// so this is intentionally hard-coded to catch and verify any other use case
// (which should help to work out a common approach and a recommended way of
// blocking streams).
gate_node_opts.set_empty_packets_as_allow(true);
disallow >> gate_node.In("DISALLOW");
return gate_node;
}
template <typename T>
static api2::builder::GenericNode& AddSideSourceGate(
T disallow, api2::builder::Graph& graph) {
auto& gate_node = graph.AddNode("GateCalculator");
auto& gate_node_opts =
gate_node.GetOptions<mediapipe::GateCalculatorOptions>();
gate_node_opts.set_empty_packets_as_allow(true);
disallow >> gate_node.SideIn("DISALLOW");
return gate_node;
}
api2::builder::GenericNode& node_;
int index_ = 0;
};
// Updates graph to drop @value stream packet if corresponding @condition stream
// packet holds true.
template <class T>
api2::builder::Source<T> DisallowIf(api2::builder::Source<T> value,
api2::builder::Source<bool> condition,
api2::builder::Graph& graph) {
return DisallowGate(condition, graph).Disallow(value);
}
// Updates graph to drop @value stream packet if the corresponding @condition
// side packet holds true.
template <class T>
api2::builder::Source<T> DisallowIf(api2::builder::Source<T> value,
api2::builder::SideSource<bool> condition,
api2::builder::Graph& graph) {
return DisallowGate(condition, graph).Disallow(value);
}
// Updates graph to pass through @value stream packet if corresponding
// @allow stream packet holds true.
template <class T>
api2::builder::Source<T> AllowIf(api2::builder::Source<T> value,
api2::builder::Source<bool> allow,
api2::builder::Graph& graph) {
return AllowGate(allow, graph).Allow(value);
}
// Updates graph to pass through @value stream packet if corresponding
// @allow side stream packet holds true.
template <class T>
api2::builder::Source<T> AllowIf(api2::builder::Source<T> value,
api2::builder::SideSource<bool> allow,
api2::builder::Graph& graph) {
return AllowGate(allow, graph).Allow(value);
}
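// Example usage of DisallowIf (a sketch; tag names are illustrative):
//
//   api2::builder::Graph graph;
//   auto value = graph.In("VALUE").Cast<int>();
//   auto condition = graph.In("CONDITION").Cast<bool>();
//   // Drop `value` packets whenever the corresponding `condition` packet
//   // holds true.
//   auto gated_value = DisallowIf(value, condition, graph);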
} // namespace utils
} // namespace components
} // namespace tasks
} // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_COMPONENTS_UTILS_GATE_H_
View File
@ -0,0 +1,229 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/components/utils/gate.h"
#include "mediapipe/framework/api2/builder.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_graph.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
namespace mediapipe {
namespace tasks {
namespace components {
namespace utils {
namespace {
using ::mediapipe::api2::builder::SideSource;
using ::mediapipe::api2::builder::Source;
TEST(DisallowGate, VerifyConfig) {
mediapipe::api2::builder::Graph graph;
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
Source<int> value1 = graph.In("VALUE_1").Cast<int>();
Source<int> value2 = graph.In("VALUE_2").Cast<int>();
Source<int> value3 = graph.In("VALUE_3").Cast<int>();
DisallowGate gate(condition, graph);
gate.Disallow(value1).SetName("gated_stream1");
gate.Disallow(value2).SetName("gated_stream2");
gate.Disallow(value3).SetName("gated_stream3");
EXPECT_THAT(graph.GetConfig(),
testing::EqualsProto(
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "GateCalculator"
input_stream: "__stream_1"
input_stream: "__stream_2"
input_stream: "__stream_3"
input_stream: "DISALLOW:__stream_0"
output_stream: "gated_stream1"
output_stream: "gated_stream2"
output_stream: "gated_stream3"
options {
[mediapipe.GateCalculatorOptions.ext] {
empty_packets_as_allow: true
}
}
}
input_stream: "CONDITION:__stream_0"
input_stream: "VALUE_1:__stream_1"
input_stream: "VALUE_2:__stream_2"
input_stream: "VALUE_3:__stream_3"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(DisallowIf, VerifyConfig) {
mediapipe::api2::builder::Graph graph;
Source<int> value = graph.In("VALUE").Cast<int>();
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
auto gated_stream = DisallowIf(value, condition, graph);
gated_stream.SetName("gated_stream");
EXPECT_THAT(graph.GetConfig(),
testing::EqualsProto(
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "GateCalculator"
input_stream: "__stream_1"
input_stream: "DISALLOW:__stream_0"
output_stream: "gated_stream"
options {
[mediapipe.GateCalculatorOptions.ext] {
empty_packets_as_allow: true
}
}
}
input_stream: "CONDITION:__stream_0"
input_stream: "VALUE:__stream_1"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(DisallowIf, VerifyConfigWithSideCondition) {
mediapipe::api2::builder::Graph graph;
Source<int> value = graph.In("VALUE").Cast<int>();
SideSource<bool> condition = graph.SideIn("CONDITION").Cast<bool>();
auto gated_stream = DisallowIf(value, condition, graph);
gated_stream.SetName("gated_stream");
EXPECT_THAT(graph.GetConfig(),
testing::EqualsProto(
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "GateCalculator"
input_stream: "__stream_0"
output_stream: "gated_stream"
input_side_packet: "DISALLOW:__side_packet_1"
options {
[mediapipe.GateCalculatorOptions.ext] {
empty_packets_as_allow: true
}
}
}
input_stream: "VALUE:__stream_0"
input_side_packet: "CONDITION:__side_packet_1"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(AllowGate, VerifyConfig) {
mediapipe::api2::builder::Graph graph;
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
Source<int> value1 = graph.In("VALUE_1").Cast<int>();
Source<int> value2 = graph.In("VALUE_2").Cast<int>();
Source<int> value3 = graph.In("VALUE_3").Cast<int>();
AllowGate gate(condition, graph);
gate.Allow(value1).SetName("gated_stream1");
gate.Allow(value2).SetName("gated_stream2");
gate.Allow(value3).SetName("gated_stream3");
EXPECT_THAT(graph.GetConfig(),
testing::EqualsProto(
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "GateCalculator"
input_stream: "__stream_1"
input_stream: "__stream_2"
input_stream: "__stream_3"
input_stream: "ALLOW:__stream_0"
output_stream: "gated_stream1"
output_stream: "gated_stream2"
output_stream: "gated_stream3"
}
input_stream: "CONDITION:__stream_0"
input_stream: "VALUE_1:__stream_1"
input_stream: "VALUE_2:__stream_2"
input_stream: "VALUE_3:__stream_3"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(AllowIf, VerifyConfig) {
mediapipe::api2::builder::Graph graph;
Source<int> value = graph.In("VALUE").Cast<int>();
Source<bool> condition = graph.In("CONDITION").Cast<bool>();
auto gated_stream = AllowIf(value, condition, graph);
gated_stream.SetName("gated_stream");
EXPECT_THAT(graph.GetConfig(),
testing::EqualsProto(
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "GateCalculator"
input_stream: "__stream_1"
input_stream: "ALLOW:__stream_0"
output_stream: "gated_stream"
}
input_stream: "CONDITION:__stream_0"
input_stream: "VALUE:__stream_1"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
TEST(AllowIf, VerifyConfigWithSideCondition) {
mediapipe::api2::builder::Graph graph;
Source<int> value = graph.In("VALUE").Cast<int>();
SideSource<bool> condition = graph.SideIn("CONDITION").Cast<bool>();
auto gated_stream = AllowIf(value, condition, graph);
gated_stream.SetName("gated_stream");
EXPECT_THAT(graph.GetConfig(),
testing::EqualsProto(
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "GateCalculator"
input_stream: "__stream_0"
output_stream: "gated_stream"
input_side_packet: "ALLOW:__side_packet_1"
}
input_stream: "VALUE:__stream_0"
input_side_packet: "CONDITION:__side_packet_1"
)pb")));
CalculatorGraph calculator_graph;
MP_EXPECT_OK(calculator_graph.Initialize(graph.GetConfig()));
}
} // namespace
} // namespace utils
} // namespace components
} // namespace tasks
} // namespace mediapipe
View File
@ -23,6 +23,7 @@ cc_library(
srcs = ["base_options.cc"], srcs = ["base_options.cc"],
hdrs = ["base_options.h"], hdrs = ["base_options.h"],
deps = [ deps = [
":mediapipe_builtin_op_resolver",
"//mediapipe/calculators/tensor:inference_calculator_cc_proto", "//mediapipe/calculators/tensor:inference_calculator_cc_proto",
"//mediapipe/tasks/cc/core/proto:acceleration_cc_proto", "//mediapipe/tasks/cc/core/proto:acceleration_cc_proto",
"//mediapipe/tasks/cc/core/proto:base_options_cc_proto", "//mediapipe/tasks/cc/core/proto:base_options_cc_proto",
@ -50,6 +51,21 @@ cc_library(
], ],
) )
cc_library(
name = "mediapipe_builtin_op_resolver",
srcs = ["mediapipe_builtin_op_resolver.cc"],
hdrs = ["mediapipe_builtin_op_resolver.h"],
deps = [
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
"//mediapipe/util/tflite/operations:max_pool_argmax",
"//mediapipe/util/tflite/operations:max_unpooling",
"//mediapipe/util/tflite/operations:transform_landmarks",
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
"//mediapipe/util/tflite/operations:transpose_conv_bias",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
# TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator # TODO: Switch to use cc_library_with_tflite after the MediaPipe InferenceCalculator
# supports TFLite-in-GMSCore. # supports TFLite-in-GMSCore.
cc_library( cc_library(
View File
@ -20,6 +20,7 @@ limitations under the License.
#include <string> #include <string>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
@ -63,7 +64,7 @@ struct BaseOptions {
// A non-default OpResolver to support custom Ops or specify a subset of // A non-default OpResolver to support custom Ops or specify a subset of
// built-in Ops. // built-in Ops.
std::unique_ptr<tflite::OpResolver> op_resolver = std::unique_ptr<tflite::OpResolver> op_resolver =
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(); absl::make_unique<MediaPipeBuiltinOpResolver>();
}; };
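// For instance, a client that needs additional custom ops could swap in its
// own resolver (a sketch; `MyCustomOpResolver` is a hypothetical
// tflite::OpResolver subclass):
//
//   BaseOptions options;
//   options.op_resolver = absl::make_unique<MyCustomOpResolver>();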
// Converts a BaseOptions to a BaseOptionsProto. // Converts a BaseOptions to a BaseOptionsProto.
View File
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h" #include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h" #include "mediapipe/util/tflite/operations/landmarks_to_transform_matrix.h"
#include "mediapipe/util/tflite/operations/max_pool_argmax.h" #include "mediapipe/util/tflite/operations/max_pool_argmax.h"
@ -21,14 +21,11 @@ limitations under the License.
#include "mediapipe/util/tflite/operations/transform_landmarks.h" #include "mediapipe/util/tflite/operations/transform_landmarks.h"
#include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h" #include "mediapipe/util/tflite/operations/transform_tensor_bilinear.h"
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h" #include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
#include "tensorflow/lite/kernels/register.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace core {
MediaPipeBuiltinOpResolver::MediaPipeBuiltinOpResolver() {
SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
: BuiltinOpResolver() {
AddCustom("MaxPoolingWithArgmax2D", AddCustom("MaxPoolingWithArgmax2D",
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D()); mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D", AddCustom("MaxUnpooling2D",
@ -46,7 +43,6 @@ SelfieSegmentationModelOpResolver::SelfieSegmentationModelOpResolver()
mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(), mediapipe::tflite_operations::RegisterLandmarksToTransformMatrixV2(),
/*version=*/2); /*version=*/2);
} }
} // namespace core
} // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
View File
@ -13,25 +13,23 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ #ifndef MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ #define MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace core {
class MediaPipeBuiltinOpResolver
class SelfieSegmentationModelOpResolver
: public tflite::ops::builtin::BuiltinOpResolver { : public tflite::ops::builtin::BuiltinOpResolver {
public: public:
SelfieSegmentationModelOpResolver(); MediaPipeBuiltinOpResolver();
SelfieSegmentationModelOpResolver( MediaPipeBuiltinOpResolver(const MediaPipeBuiltinOpResolver& r) = delete;
const SelfieSegmentationModelOpResolver& r) = delete;
}; };
} // namespace vision } // namespace core
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
#endif // MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_IMAGE_SEGMENTER_OP_RESOLVERS_H_ #endif // MEDIAPIPE_TASKS_CC_CORE_MEDIAPIPE_BUILTIN_OP_RESOLVER_H_
View File
@ -18,18 +18,6 @@ package(default_visibility = [
licenses(["notice"]) licenses(["notice"])
cc_library(
name = "hand_detector_op_resolver",
srcs = ["hand_detector_op_resolver.cc"],
hdrs = ["hand_detector_op_resolver.h"],
deps = [
"//mediapipe/util/tflite/operations:max_pool_argmax",
"//mediapipe/util/tflite/operations:max_unpooling",
"//mediapipe/util/tflite/operations:transpose_conv_bias",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
cc_library( cc_library(
name = "hand_detector_graph", name = "hand_detector_graph",
srcs = ["hand_detector_graph.cc"], srcs = ["hand_detector_graph.cc"],
View File
@ -35,11 +35,11 @@ limitations under the License.
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_options.pb.h"
#include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h" #include "mediapipe/tasks/cc/vision/hand_detector/proto/hand_detector_result.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
@ -121,8 +121,8 @@ absl::StatusOr<std::unique_ptr<TaskRunner>> CreateTaskRunner(
hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >> hand_detection.Out(kHandNormRectsTag).SetName(kHandNormRectsName) >>
graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)]; graph[Output<std::vector<NormalizedRect>>(kHandNormRectsTag)];
return TaskRunner::Create(graph.GetConfig(), return TaskRunner::Create(
absl::make_unique<HandDetectorOpResolver>()); graph.GetConfig(), std::make_unique<core::MediaPipeBuiltinOpResolver>());
} }
HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) { HandDetectorResult GetExpectedHandDetectorResult(absl::string_view file_name) {
View File
@ -1,35 +0,0 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/vision/hand_detector/hand_detector_op_resolver.h"
#include "mediapipe/util/tflite/operations/max_pool_argmax.h"
#include "mediapipe/util/tflite/operations/max_unpooling.h"
#include "mediapipe/util/tflite/operations/transpose_conv_bias.h"
namespace mediapipe {
namespace tasks {
namespace vision {
HandDetectorOpResolver::HandDetectorOpResolver() {
AddCustom("MaxPoolingWithArgmax2D",
mediapipe::tflite_operations::RegisterMaxPoolingWithArgmax2D());
AddCustom("MaxUnpooling2D",
mediapipe::tflite_operations::RegisterMaxUnpooling2D());
AddCustom("Convolution2DTransposeBias",
mediapipe::tflite_operations::RegisterConvolution2DTransposeBias());
}
} // namespace vision
} // namespace tasks
} // namespace mediapipe
View File
@ -54,10 +54,10 @@ cc_library(
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/tasks/cc:common", "//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
View File
@ -27,9 +27,9 @@ limitations under the License.
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -49,6 +49,7 @@ using ::mediapipe::api2::Input;
using ::mediapipe::api2::Output; using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto:: using ::mediapipe::tasks::vision::hand_gesture_recognizer::proto::
HandGestureRecognizerSubgraphOptions; HandGestureRecognizerSubgraphOptions;
using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions; using ::mediapipe::tasks::vision::proto::LandmarksToMatrixCalculatorOptions;
@ -218,11 +219,14 @@ class SingleHandGestureRecognizerSubgraph : public core::ModelTaskGraph {
auto inference_output_tensors = inference.Out(kTensorsTag); auto inference_output_tensors = inference.Out(kTensorsTag);
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, graph_options.classifier_options(), model_resources, graph_options.classifier_options(),
&postprocessing.GetOptions< &postprocessing
tasks::components::ClassificationPostprocessingOptions>())); .GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference_output_tensors >> postprocessing.In(kTensorsTag); inference_output_tensors >> postprocessing.In(kTensorsTag);
auto classification_result = auto classification_result =
postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")]; postprocessing[Output<ClassificationResult>("CLASSIFICATION_RESULT")];
View File
@ -26,7 +26,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )
@ -37,7 +37,5 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )
View File
@ -18,7 +18,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.hand_gesture_recognizer.proto; package mediapipe.tasks.vision.hand_gesture_recognizer.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message HandGestureRecognizerSubgraphOptions { message HandGestureRecognizerSubgraphOptions {
@ -31,7 +31,7 @@ message HandGestureRecognizerSubgraphOptions {
// Options for configuring the gesture classifier behavior, such as score // Options for configuring the gesture classifier behavior, such as score
// threshold, number of results, etc. // threshold, number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2; optional components.processors.proto.ClassifierOptions classifier_options = 2;
// Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be // Minimum confidence value ([0.0, 1.0]) for the hand landmarks to be
// considered tracked successfully // considered tracked successfully
View File
@ -0,0 +1,49 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
package(default_visibility = [
"//mediapipe/app/xeno:__subpackages__",
"//mediapipe/tasks:internal",
])
licenses(["notice"])
mediapipe_proto_library(
name = "hand_association_calculator_proto",
srcs = ["hand_association_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "hand_association_calculator",
srcs = ["hand_association_calculator.cc"],
deps = [
":hand_association_calculator_cc_proto",
"//mediapipe/calculators/util:association_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:rectangle",
"//mediapipe/framework/port:status",
"//mediapipe/util:rectangle_util",
],
alwayslink = 1,
)
# TODO: Enable this test

View File

@ -0,0 +1,125 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <utility>
#include <vector>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.pb.h"
#include "mediapipe/util/rectangle_util.h"
namespace mediapipe::api2 {
// HandAssociationCalculator accepts multiple inputs of vectors of
// NormalizedRect. The output is a vector of NormalizedRect that contains
// rects from the input vectors that don't overlap with each other. When two
// rects overlap, the rect that comes in from an earlier input stream is
// kept in the output. If a rect has no ID (i.e., it comes from a detection
// stream), then a unique rect ID is assigned to it.
// The rects in multiple input streams are effectively flattened to a single
// list. For example:
// Stream1 : rect 1, rect 2
// Stream2: rect 3, rect 4
// Stream3: rect 5, rect 6
// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
// In the flattened list, if a rect with a higher index overlaps with a rect
// with a lower index beyond a specified IOU threshold, the rect with the
// lower index will be kept in the output, and the rect with the higher index
// will be discarded.
// TODO: Upgrade this to latest API for calculators
class HandAssociationCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
// Initialize input and output streams.
for (auto& input_stream : cc->Inputs()) {
input_stream.Set<std::vector<NormalizedRect>>();
}
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<HandAssociationCalculatorOptions>();
CHECK_GT(options_.min_similarity_threshold(), 0.0);
CHECK_LE(options_.min_similarity_threshold(), 1.0);
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
ASSIGN_OR_RETURN(auto result, GetNonOverlappingElements(cc));
auto output =
std::make_unique<std::vector<NormalizedRect>>(std::move(result));
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return absl::OkStatus();
}
private:
HandAssociationCalculatorOptions options_;
// Return a list of non-overlapping elements from all input streams, with
// decreasing order of priority based on input stream index and indices
// within an input stream.
absl::StatusOr<std::vector<NormalizedRect>> GetNonOverlappingElements(
CalculatorContext* cc) {
std::vector<NormalizedRect> result;
for (const auto& input_stream : cc->Inputs()) {
if (input_stream.IsEmpty()) {
continue;
}
for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
ASSIGN_OR_RETURN(
bool is_overlapping,
mediapipe::DoesRectOverlap(rect, result,
options_.min_similarity_threshold()));
if (!is_overlapping) {
if (!rect.has_rect_id()) {
rect.set_rect_id(GetNextRectId());
}
result.push_back(rect);
}
}
}
return result;
}
private:
// Each NormalizedRect processed by the calculator will be assigned
// a unique ID, if it does not already have one. The starting ID will be 1.
// Note: This rect_id_ is local to an instance of this calculator, and it is
// expected that the hand tracking graph has only one instance of
// this association calculator.
int64 rect_id_ = 1;
inline int GetNextRectId() { return rect_id_++; }
};
MEDIAPIPE_REGISTER_NODE(HandAssociationCalculator);
} // namespace mediapipe::api2
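The association logic documented above reduces to a single greedy pass: walk the flattened, priority-ordered list of rects and keep a rect only if its IoU with every rect kept so far does not exceed the threshold, assigning a fresh ID to any kept rect that has none. A minimal standalone sketch of that idea follows; Rect, ComputeIou, and AssociateRects are illustrative names, not MediaPipe APIs.

// Standalone sketch of the greedy, priority-ordered association described
// above. All names here are illustrative only.
#include <algorithm>
#include <iostream>
#include <vector>

struct Rect {
  float x_center, y_center, width, height;
  int id = 0;  // 0 means "no ID assigned yet".
};

// Intersection-over-union of two axis-aligned, center-based rectangles.
float ComputeIou(const Rect& a, const Rect& b) {
  const float ax0 = a.x_center - a.width / 2, ax1 = a.x_center + a.width / 2;
  const float ay0 = a.y_center - a.height / 2, ay1 = a.y_center + a.height / 2;
  const float bx0 = b.x_center - b.width / 2, bx1 = b.x_center + b.width / 2;
  const float by0 = b.y_center - b.height / 2, by1 = b.y_center + b.height / 2;
  const float iw = std::max(0.f, std::min(ax1, bx1) - std::max(ax0, bx0));
  const float ih = std::max(0.f, std::min(ay1, by1) - std::max(ay0, by0));
  const float inter = iw * ih;
  const float uni = a.width * a.height + b.width * b.height - inter;
  return uni > 0.f ? inter / uni : 0.f;
}

// Keeps a rect only if it overlaps no previously kept rect beyond the
// threshold; earlier rects (earlier streams, earlier indices) win.
std::vector<Rect> AssociateRects(const std::vector<Rect>& flattened,
                                 float min_similarity_threshold) {
  std::vector<Rect> result;
  int next_id = 1;
  for (Rect rect : flattened) {
    bool overlaps = false;
    for (const Rect& kept : result) {
      if (ComputeIou(rect, kept) > min_similarity_threshold) {
        overlaps = true;
        break;
      }
    }
    if (!overlaps) {
      if (rect.id == 0) rect.id = next_id++;  // Assign an ID if missing.
      result.push_back(rect);
    }
  }
  return result;
}

int main() {
  // Two overlapping rects and one separate rect: the second one is dropped.
  std::vector<Rect> flattened = {{0.2f, 0.2f, 0.2f, 0.2f},
                                 {0.25f, 0.2f, 0.2f, 0.2f},
                                 {0.8f, 0.8f, 0.2f, 0.2f}};
  for (const Rect& r :
       AssociateRects(flattened, /*min_similarity_threshold=*/0.1f)) {
    std::cout << "kept rect id=" << r.id << " x_center=" << r.x_center << "\n";
  }
}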

View File

@ -13,22 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ syntax = "proto2";
#define MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_
#include "tensorflow/lite/kernels/register.h" package mediapipe;
namespace mediapipe { import "mediapipe/framework/calculator.proto";
namespace tasks {
namespace vision {
class HandDetectorOpResolver : public tflite::ops::builtin::BuiltinOpResolver {
public:
HandDetectorOpResolver();
HandDetectorOpResolver(const HandDetectorOpResolver& r) = delete;
};
} // namespace vision message HandAssociationCalculatorOptions {
} // namespace tasks extend mediapipe.CalculatorOptions {
} // namespace mediapipe optional HandAssociationCalculatorOptions ext = 408244367;
}
#endif // MEDIAPIPE_TASKS_CC_VISION_HAND_DETECTOR_HAND_DETECTOR_OP_RESOLVER_H_ optional float min_similarity_threshold = 1 [default = 1.0];
}

View File

@ -0,0 +1,302 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
namespace mediapipe {
namespace {
class HandAssociationCalculatorTest : public testing::Test {
protected:
HandAssociationCalculatorTest() {
// 0.4 ================
// | | | |
// 0.3 ===================== | NR2 | |
// | | | NR1 | | | NR4 |
// 0.2 | NR0 | =========== ================
// | | | | | |
// 0.1 =====|=============== |
// | NR3 | | |
// 0.0 ================ |
// | NR5 |
// -0.1 ===========
// 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2
// NormalizedRect nr_0.
nr_0_.set_x_center(0.2);
nr_0_.set_y_center(0.2);
nr_0_.set_width(0.2);
nr_0_.set_height(0.2);
// NormalizedRect nr_1.
nr_1_.set_x_center(0.4);
nr_1_.set_y_center(0.2);
nr_1_.set_width(0.2);
nr_1_.set_height(0.2);
// NormalizedRect nr_2.
nr_2_.set_x_center(1.0);
nr_2_.set_y_center(0.3);
nr_2_.set_width(0.2);
nr_2_.set_height(0.2);
// NormalizedRect nr_3.
nr_3_.set_x_center(0.35);
nr_3_.set_y_center(0.15);
nr_3_.set_width(0.3);
nr_3_.set_height(0.3);
// NormalizedRect nr_4.
nr_4_.set_x_center(1.1);
nr_4_.set_y_center(0.3);
nr_4_.set_width(0.2);
nr_4_.set_height(0.2);
// NormalizedRect nr_5.
nr_5_.set_x_center(0.5);
nr_5_.set_y_center(0.05);
nr_5_.set_width(0.3);
nr_5_.set_height(0.3);
}
NormalizedRect nr_0_, nr_1_, nr_2_, nr_3_, nr_4_, nr_5_;
};
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Input Stream 0: nr_0, nr_1, nr_2.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_0->push_back(nr_0_);
input_vec_0->push_back(nr_1_);
input_vec_0->push_back(nr_2_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_3, nr_4.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_3_);
input_vec_1->push_back(nr_4_);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_5.
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_2->push_back(nr_5_);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_0 is added 1st.
// nr_1 is added because it does not overlap with nr_0.
// nr_2 is added because it does not overlap with nr_0 or nr_1.
// nr_3 is NOT added because it overlaps with nr_0.
// nr_4 is NOT added because it overlaps with nr_2.
// nr_5 is NOT added because it overlaps with nr_1.
EXPECT_EQ(3, assoc_rects.size());
// Check that IDs are filled in and contents match.
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
assoc_rects[1].clear_rect_id();
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Input Stream 0: nr_0, nr_1. Tracked hands.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
// Setting the ID to a negative number for test only, since IDs newly
// generated by HandAssociationCalculator are positive numbers.
nr_0_.set_rect_id(-2);
input_vec_0->push_back(nr_0_);
nr_1_.set_rect_id(-1);
input_vec_0->push_back(nr_1_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_2, nr_3. Newly detected palms.
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_2_);
input_vec_1->push_back(nr_3_);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_0 is added 1st.
// nr_1 is added because it does not overlap with nr_0.
// nr_2 is added because it does not overlap with nr_0 or nr_1.
// nr_3 is NOT added because it overlaps with nr_0.
EXPECT_EQ(3, assoc_rects.size());
// Check that IDs are filled in and contents match.
EXPECT_EQ(assoc_rects[0].rect_id(), -2);
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
EXPECT_EQ(assoc_rects[1].rect_id(), -1);
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
EXPECT_EQ(assoc_rects[2].rect_id(), 1);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec_0"
input_stream: "input_vec_1"
input_stream: "input_vec_2"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Input Stream 0: nr_5.
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_0->push_back(nr_5_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec_0.release()).At(Timestamp(1)));
// Input Stream 1: nr_4, nr_3
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_1->push_back(nr_4_);
input_vec_1->push_back(nr_3_);
runner.MutableInputs()->Index(1).packets.push_back(
Adopt(input_vec_1.release()).At(Timestamp(1)));
// Input Stream 2: nr_2, nr_1, nr_0.
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
input_vec_2->push_back(nr_2_);
input_vec_2->push_back(nr_1_);
input_vec_2->push_back(nr_0_);
runner.MutableInputs()->Index(2).packets.push_back(
Adopt(input_vec_2.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_5 is added 1st.
// nr_4 is added because it does not overlap with nr_5.
// nr_3 is NOT added because it overlaps with nr_5.
// nr_2 is NOT added because it overlaps with nr_4.
// nr_1 is NOT added because it overlaps with nr_5.
// nr_0 is added because it does not overlap with nr_5 or nr_4.
EXPECT_EQ(3, assoc_rects.size());
// Outputs are in the same order as inputs, and IDs are filled in.
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
assoc_rects[1].clear_rect_id();
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
assoc_rects[2].clear_rect_id();
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
}
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "HandAssociationCalculator"
input_stream: "input_vec"
output_stream: "output_vec"
options {
[mediapipe.HandAssociationCalculatorOptions.ext] {
min_similarity_threshold: 0.1
}
}
)pb"));
// Just one input stream : nr_3, nr_5.
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
input_vec->push_back(nr_3_);
input_vec->push_back(nr_5_);
runner.MutableInputs()->Index(0).packets.push_back(
Adopt(input_vec.release()).At(Timestamp(1)));
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
EXPECT_EQ(1, output.size());
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
// Rectangles are added in the following sequence:
// nr_3 is added 1st.
// nr_5 is NOT added because it overlaps with nr_3.
EXPECT_EQ(1, assoc_rects.size());
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
assoc_rects[0].clear_rect_id();
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
}
} // namespace
} // namespace mediapipe

View File

@ -26,14 +26,14 @@ cc_library(
"//mediapipe/framework/api2:port", "//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:classification_postprocessing",
"//mediapipe/tasks/cc/components:classification_postprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components:image_preprocessing", "//mediapipe/tasks/cc/components:image_preprocessing",
"//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto", "//mediapipe/tasks/cc/components:image_preprocessing_options_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/processors:classification_postprocessing_graph",
"//mediapipe/tasks/cc/components/processors/proto:classification_postprocessing_graph_options_cc_proto",
"//mediapipe/tasks/cc/core:model_resources", "//mediapipe/tasks/cc/core:model_resources",
"//mediapipe/tasks/cc/core:model_task_graph", "//mediapipe/tasks/cc/core:model_task_graph",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],
alwayslink = 1, alwayslink = 1,
@ -50,9 +50,9 @@ cc_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/tasks/cc/components:classifier_options", "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto",
"//mediapipe/tasks/cc/components/containers:classifications_cc_proto", "//mediapipe/tasks/cc/components/processors:classifier_options",
"//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto",
"//mediapipe/tasks/cc/core:base_options", "//mediapipe/tasks/cc/core:base_options",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"//mediapipe/tasks/cc/core:utils", "//mediapipe/tasks/cc/core:utils",
@ -61,7 +61,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/core:base_vision_task_api", "//mediapipe/tasks/cc/vision/core:base_vision_task_api",
"//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:running_mode",
"//mediapipe/tasks/cc/vision/core:vision_task_api_factory", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory",
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_options_cc_proto", "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],

View File

@ -26,9 +26,9 @@ limitations under the License.
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" #include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h" #include "mediapipe/tasks/cc/core/proto/inference_subgraph.pb.h"
@ -36,11 +36,12 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/core/utils.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
namespace { namespace {
@ -52,12 +53,11 @@ constexpr char kImageTag[] = "IMAGE";
constexpr char kNormRectName[] = "norm_rect_in"; constexpr char kNormRectName[] = "norm_rect_in";
constexpr char kNormRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kSubgraphTypeName[] = constexpr char kSubgraphTypeName[] =
"mediapipe.tasks.vision.ImageClassifierGraph"; "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph";
constexpr int kMicroSecondsPerMilliSecond = 1000; constexpr int kMicroSecondsPerMilliSecond = 1000;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketMap;
using ImageClassifierOptionsProto =
image_classifier::proto::ImageClassifierOptions;
// Builds a NormalizedRect covering the entire image. // Builds a NormalizedRect covering the entire image.
NormalizedRect BuildFullImageNormRect() { NormalizedRect BuildFullImageNormRect() {
@ -70,17 +70,17 @@ NormalizedRect BuildFullImageNormRect() {
} }
// Creates a MediaPipe graph config that contains a subgraph node of // Creates a MediaPipe graph config that contains a subgraph node of
// "mediapipe.tasks.vision.ImageClassifierGraph". If the task is running in the // type "ImageClassifierGraph". If the task is running in the live stream mode,
// live stream mode, a "FlowLimiterCalculator" will be added to limit the number // a "FlowLimiterCalculator" will be added to limit the number of frames in
// of frames in flight. // flight.
CalculatorGraphConfig CreateGraphConfig( CalculatorGraphConfig CreateGraphConfig(
std::unique_ptr<ImageClassifierOptionsProto> options_proto, std::unique_ptr<proto::ImageClassifierGraphOptions> options_proto,
bool enable_flow_limiting) { bool enable_flow_limiting) {
api2::builder::Graph graph; api2::builder::Graph graph;
graph.In(kImageTag).SetName(kImageInStreamName); graph.In(kImageTag).SetName(kImageInStreamName);
graph.In(kNormRectTag).SetName(kNormRectName); graph.In(kNormRectTag).SetName(kNormRectName);
auto& task_subgraph = graph.AddNode(kSubgraphTypeName); auto& task_subgraph = graph.AddNode(kSubgraphTypeName);
task_subgraph.GetOptions<ImageClassifierOptionsProto>().Swap( task_subgraph.GetOptions<proto::ImageClassifierGraphOptions>().Swap(
options_proto.get()); options_proto.get());
task_subgraph.Out(kClassificationResultTag) task_subgraph.Out(kClassificationResultTag)
.SetName(kClassificationResultStreamName) >> .SetName(kClassificationResultStreamName) >>
@ -98,18 +98,18 @@ CalculatorGraphConfig CreateGraphConfig(
} }
// Converts the user-facing ImageClassifierOptions struct to the internal // Converts the user-facing ImageClassifierOptions struct to the internal
// ImageClassifierOptions proto. // ImageClassifierGraphOptions proto.
std::unique_ptr<ImageClassifierOptionsProto> std::unique_ptr<proto::ImageClassifierGraphOptions>
ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) { ConvertImageClassifierOptionsToProto(ImageClassifierOptions* options) {
auto options_proto = std::make_unique<ImageClassifierOptionsProto>(); auto options_proto = std::make_unique<proto::ImageClassifierGraphOptions>();
auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>( auto base_options_proto = std::make_unique<tasks::core::proto::BaseOptions>(
tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); tasks::core::ConvertBaseOptionsToProto(&(options->base_options)));
options_proto->mutable_base_options()->Swap(base_options_proto.get()); options_proto->mutable_base_options()->Swap(base_options_proto.get());
options_proto->mutable_base_options()->set_use_stream_mode( options_proto->mutable_base_options()->set_use_stream_mode(
options->running_mode != core::RunningMode::IMAGE); options->running_mode != core::RunningMode::IMAGE);
auto classifier_options_proto = auto classifier_options_proto =
std::make_unique<tasks::components::proto::ClassifierOptions>( std::make_unique<components::processors::proto::ClassifierOptions>(
components::ConvertClassifierOptionsToProto( components::processors::ConvertClassifierOptionsToProto(
&(options->classifier_options))); &(options->classifier_options)));
options_proto->mutable_classifier_options()->Swap( options_proto->mutable_classifier_options()->Swap(
classifier_options_proto.get()); classifier_options_proto.get());
@ -145,7 +145,7 @@ absl::StatusOr<std::unique_ptr<ImageClassifier>> ImageClassifier::Create(
}; };
} }
return core::VisionTaskApiFactory::Create<ImageClassifier, return core::VisionTaskApiFactory::Create<ImageClassifier,
ImageClassifierOptionsProto>( proto::ImageClassifierGraphOptions>(
CreateGraphConfig( CreateGraphConfig(
std::move(options_proto), std::move(options_proto),
options->running_mode == core::RunningMode::LIVE_STREAM), options->running_mode == core::RunningMode::LIVE_STREAM),
@ -214,6 +214,7 @@ absl::Status ImageClassifier::ClassifyAsync(Image image, int64 timestamp_ms,
.At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}});
} }
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -23,8 +23,8 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/classifier_options.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/processors/classifier_options.h"
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
@ -32,6 +32,7 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
// The options for configuring a Mediapipe image classifier task. // The options for configuring a Mediapipe image classifier task.
struct ImageClassifierOptions { struct ImageClassifierOptions {
@ -50,12 +51,14 @@ struct ImageClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
components::ClassifierOptions classifier_options; components::processors::ClassifierOptions classifier_options;
// The user-defined result callback for processing live stream data. // The user-defined result callback for processing live stream data.
// The result callback should only be specified when the running mode is set // The result callback should only be specified when the running mode is set
// to RunningMode::LIVE_STREAM. // to RunningMode::LIVE_STREAM.
std::function<void(absl::StatusOr<ClassificationResult>, const Image&, int64)> std::function<void(
absl::StatusOr<components::containers::proto::ClassificationResult>,
const Image&, int64)>
result_callback = nullptr; result_callback = nullptr;
}; };
@ -112,7 +115,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. // The image can be of any size with format RGB or RGBA.
// TODO: describe exact preprocessing steps once // TODO: describe exact preprocessing steps once
// YUVToImageCalculator is integrated. // YUVToImageCalculator is integrated.
absl::StatusOr<ClassificationResult> Classify( absl::StatusOr<components::containers::proto::ClassificationResult> Classify(
mediapipe::Image image, mediapipe::Image image,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
@ -126,8 +129,8 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
// The image can be of any size with format RGB or RGBA. It's required to // The image can be of any size with format RGB or RGBA. It's required to
// provide the video frame's timestamp (in milliseconds). The input timestamps // provide the video frame's timestamp (in milliseconds). The input timestamps
// must be monotonically increasing. // must be monotonically increasing.
absl::StatusOr<ClassificationResult> ClassifyForVideo( absl::StatusOr<components::containers::proto::ClassificationResult>
mediapipe::Image image, int64 timestamp_ms, ClassifyForVideo(mediapipe::Image image, int64 timestamp_ms,
std::optional<mediapipe::NormalizedRect> roi = std::nullopt); std::optional<mediapipe::NormalizedRect> roi = std::nullopt);
// Sends live image data to image classification, and the results will be // Sends live image data to image classification, and the results will be
@ -161,6 +164,7 @@ class ImageClassifier : tasks::vision::core::BaseVisionTaskApi {
absl::Status Close() { return runner_->Close(); } absl::Status Close() { return runner_->Close(); }
}; };
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe
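A minimal usage sketch of the C++ API declared above, in the default IMAGE running mode. The include paths, DecodeImageFromFile, and the max_results field on ClassifierOptions are assumptions based on the surrounding targets and tests; the model and image paths are placeholders.

// Hedged usage sketch for the ImageClassifier C++ API declared above.
#include <iostream>
#include <memory>
#include <utility>

#include "mediapipe/tasks/cc/vision/image_classifier/image_classifier.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h"

int main() {
  using ::mediapipe::tasks::vision::image_classifier::ImageClassifier;
  using ::mediapipe::tasks::vision::image_classifier::ImageClassifierOptions;

  auto options = std::make_unique<ImageClassifierOptions>();
  // Placeholder model path; BaseOptions usage mirrors the segmenter tests above.
  options->base_options.model_asset_path = "/path/to/classifier_with_metadata.tflite";
  // Assumed ClassifierOptions field limiting the number of returned categories.
  options->classifier_options.max_results = 3;

  auto classifier_or = ImageClassifier::Create(std::move(options));
  if (!classifier_or.ok()) {
    std::cerr << "Create failed: " << classifier_or.status() << "\n";
    return 1;
  }
  std::unique_ptr<ImageClassifier> classifier = std::move(*classifier_or);

  // DecodeImageFromFile is assumed from image_utils.h, as used by the tests.
  auto image_or =
      ::mediapipe::tasks::vision::DecodeImageFromFile("/path/to/image.jpg");
  if (!image_or.ok()) {
    std::cerr << "Decode failed: " << image_or.status() << "\n";
    return 1;
  }

  // Returns a components::containers::proto::ClassificationResult on success.
  auto result_or = classifier->Classify(*image_or);
  if (!result_or.ok()) {
    std::cerr << "Classify failed: " << result_or.status() << "\n";
    return 1;
  }
  std::cout << "Classification succeeded.\n";
  return classifier->Close().ok() ? 0 : 1;
}

For the VIDEO and LIVE_STREAM running modes, ClassifyForVideo and ClassifyAsync declared above additionally take the frame timestamp in milliseconds, which must be monotonically increasing.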

View File

@ -22,18 +22,19 @@ limitations under the License.
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h"
#include "mediapipe/tasks/cc/components/image_preprocessing.h" #include "mediapipe/tasks/cc/components/image_preprocessing.h"
#include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h" #include "mediapipe/tasks/cc/components/image_preprocessing_options.pb.h"
#include "mediapipe/tasks/cc/components/processors/classification_postprocessing_graph.h"
#include "mediapipe/tasks/cc/components/processors/proto/classification_postprocessing_graph_options.pb.h"
#include "mediapipe/tasks/cc/core/model_resources.h" #include "mediapipe/tasks/cc/core/model_resources.h"
#include "mediapipe/tasks/cc/core/model_task_graph.h" #include "mediapipe/tasks/cc/core/model_task_graph.h"
#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_options.pb.h" #include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
namespace { namespace {
@ -42,8 +43,7 @@ using ::mediapipe::api2::Output;
using ::mediapipe::api2::builder::GenericNode; using ::mediapipe::api2::builder::GenericNode;
using ::mediapipe::api2::builder::Graph; using ::mediapipe::api2::builder::Graph;
using ::mediapipe::api2::builder::Source; using ::mediapipe::api2::builder::Source;
using ImageClassifierOptionsProto = using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
image_classifier::proto::ImageClassifierOptions;
constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest();
@ -61,8 +61,7 @@ struct ImageClassifierOutputStreams {
} // namespace } // namespace
// A "mediapipe.tasks.vision.ImageClassifierGraph" performs image // An "ImageClassifierGraph" performs image classification.
// classification.
// - Accepts CPU input images and outputs classifications on CPU. // - Accepts CPU input images and outputs classifications on CPU.
// //
// Inputs: // Inputs:
@ -80,12 +79,12 @@ struct ImageClassifierOutputStreams {
// //
// Example: // Example:
// node { // node {
// calculator: "mediapipe.tasks.vision.ImageClassifierGraph" // calculator: "mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"
// input_stream: "IMAGE:image_in" // input_stream: "IMAGE:image_in"
// output_stream: "CLASSIFICATION_RESULT:classification_result_out" // output_stream: "CLASSIFICATION_RESULT:classification_result_out"
// output_stream: "IMAGE:image_out" // output_stream: "IMAGE:image_out"
// options { // options {
// [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierOptions.ext] // [mediapipe.tasks.vision.image_classifier.proto.ImageClassifierGraphOptions.ext]
// { // {
// base_options { // base_options {
// model_asset { // model_asset {
@ -104,13 +103,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
public: public:
absl::StatusOr<CalculatorGraphConfig> GetConfig( absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override { SubgraphContext* sc) override {
ASSIGN_OR_RETURN(const auto* model_resources, ASSIGN_OR_RETURN(
CreateModelResources<ImageClassifierOptionsProto>(sc)); const auto* model_resources,
CreateModelResources<proto::ImageClassifierGraphOptions>(sc));
Graph graph; Graph graph;
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
auto output_streams, auto output_streams,
BuildImageClassificationTask( BuildImageClassificationTask(
sc->Options<ImageClassifierOptionsProto>(), *model_resources, sc->Options<proto::ImageClassifierGraphOptions>(), *model_resources,
graph[Input<Image>(kImageTag)], graph[Input<Image>(kImageTag)],
graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph)); graph[Input<NormalizedRect>::Optional(kNormRectTag)], graph));
output_streams.classification_result >> output_streams.classification_result >>
@ -125,13 +125,13 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// (mediapipe::Image) as input and returns one classification result per input // (mediapipe::Image) as input and returns one classification result per input
// image. // image.
// //
// task_options: the mediapipe tasks ImageClassifierOptions. // task_options: the mediapipe tasks ImageClassifierGraphOptions.
// model_resources: the ModelResources object initialized from an image // model_resources: the ModelResources object initialized from an image
// classification model file with model metadata. // classification model file with model metadata.
// image_in: (mediapipe::Image) stream to run classification on. // image_in: (mediapipe::Image) stream to run classification on.
// graph: the mediapipe builder::Graph instance to be updated. // graph: the mediapipe builder::Graph instance to be updated.
absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask( absl::StatusOr<ImageClassifierOutputStreams> BuildImageClassificationTask(
const ImageClassifierOptionsProto& task_options, const proto::ImageClassifierGraphOptions& task_options,
const core::ModelResources& model_resources, Source<Image> image_in, const core::ModelResources& model_resources, Source<Image> image_in,
Source<NormalizedRect> norm_rect_in, Graph& graph) { Source<NormalizedRect> norm_rect_in, Graph& graph) {
// Adds preprocessing calculators and connects them to the graph input image // Adds preprocessing calculators and connects them to the graph input image
@ -153,11 +153,14 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
// Adds postprocessing calculators and connects them to the graph output. // Adds postprocessing calculators and connects them to the graph output.
auto& postprocessing = graph.AddNode( auto& postprocessing = graph.AddNode(
"mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); "mediapipe.tasks.components.processors."
MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( "ClassificationPostprocessingGraph");
MP_RETURN_IF_ERROR(
components::processors::ConfigureClassificationPostprocessingGraph(
model_resources, task_options.classifier_options(), model_resources, task_options.classifier_options(),
&postprocessing.GetOptions< &postprocessing
tasks::components::ClassificationPostprocessingOptions>())); .GetOptions<components::processors::proto::
ClassificationPostprocessingGraphOptions>()));
inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag);
// Outputs the aggregated classification result as the subgraph output // Outputs the aggregated classification result as the subgraph output
@ -168,8 +171,10 @@ class ImageClassifierGraph : public core::ModelTaskGraph {
/*image=*/preprocessing[Output<Image>(kImageTag)]}; /*image=*/preprocessing[Output<Image>(kImageTag)]};
} }
}; };
REGISTER_MEDIAPIPE_GRAPH(::mediapipe::tasks::vision::ImageClassifierGraph); REGISTER_MEDIAPIPE_GRAPH(
::mediapipe::tasks::vision::image_classifier::ImageClassifierGraph);
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -32,8 +32,8 @@ limitations under the License.
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/tasks/cc/common.h" #include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/components/containers/category.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/category.pb.h"
#include "mediapipe/tasks/cc/components/containers/classifications.pb.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h"
#include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -44,9 +44,13 @@ limitations under the License.
namespace mediapipe { namespace mediapipe {
namespace tasks { namespace tasks {
namespace vision { namespace vision {
namespace image_classifier {
namespace { namespace {
using ::mediapipe::file::JoinPath; using ::mediapipe::file::JoinPath;
using ::mediapipe::tasks::components::containers::proto::ClassificationEntry;
using ::mediapipe::tasks::components::containers::proto::ClassificationResult;
using ::mediapipe::tasks::components::containers::proto::Classifications;
using ::testing::HasSubstr; using ::testing::HasSubstr;
using ::testing::Optional; using ::testing::Optional;
@ -814,6 +818,7 @@ TEST_F(LiveStreamModeTest, SucceedsWithRegionOfInterest) {
} }
} // namespace } // namespace
} // namespace image_classifier
} // namespace vision } // namespace vision
} // namespace tasks } // namespace tasks
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,12 +19,12 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"]) licenses(["notice"])
mediapipe_proto_library( mediapipe_proto_library(
name = "image_classifier_options_proto", name = "image_classifier_graph_options_proto",
srcs = ["image_classifier_options.proto"], srcs = ["image_classifier_graph_options.proto"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/tasks/cc/components/proto:classifier_options_proto", "//mediapipe/tasks/cc/components/processors/proto:classifier_options_proto",
"//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto",
], ],
) )

View File

@ -18,12 +18,12 @@ syntax = "proto2";
package mediapipe.tasks.vision.image_classifier.proto; package mediapipe.tasks.vision.image_classifier.proto;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; import "mediapipe/tasks/cc/components/processors/proto/classifier_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto";
message ImageClassifierOptions { message ImageClassifierGraphOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ImageClassifierOptions ext = 456383383; optional ImageClassifierGraphOptions ext = 456383383;
} }
// Base options for configuring MediaPipe Tasks, such as specifying the TfLite // Base options for configuring MediaPipe Tasks, such as specifying the TfLite
// model file with metadata, accelerator options, etc. // model file with metadata, accelerator options, etc.
@ -31,5 +31,5 @@ message ImageClassifierOptions {
// Options for configuring the classifier behavior, such as score threshold, // Options for configuring the classifier behavior, such as score threshold,
// number of results, etc. // number of results, etc.
optional components.proto.ClassifierOptions classifier_options = 2; optional components.processors.proto.ClassifierOptions classifier_options = 2;
} }

View File

@ -33,7 +33,6 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_options_cc_proto",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
], ],
) )
@ -73,19 +72,4 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "image_segmenter_op_resolvers",
srcs = ["image_segmenter_op_resolvers.cc"],
hdrs = ["image_segmenter_op_resolvers.h"],
deps = [
"//mediapipe/util/tflite/operations:landmarks_to_transform_matrix",
"//mediapipe/util/tflite/operations:max_pool_argmax",
"//mediapipe/util/tflite/operations:max_unpooling",
"//mediapipe/util/tflite/operations:transform_landmarks",
"//mediapipe/util/tflite/operations:transform_tensor_bilinear",
"//mediapipe/util/tflite/operations:transpose_conv_bias",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
)
# TODO: This test fails in OSS # TODO: This test fails in OSS

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/core/base_options.h" #include "mediapipe/tasks/cc/core/base_options.h"
#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" #include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
namespace mediapipe { namespace mediapipe {

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" #include "mediapipe/tasks/cc/core/proto/base_options.pb.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_op_resolvers.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_options.pb.h"
#include "mediapipe/tasks/cc/vision/utils/image_utils.h" #include "mediapipe/tasks/cc/vision/utils/image_utils.h"
#include "tensorflow/lite/core/shims/cc/shims_test_util.h" #include "tensorflow/lite/core/shims/cc/shims_test_util.h"
@ -260,8 +259,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata);
options->base_options.op_resolver =
absl::make_unique<SelfieSegmentationModelOpResolver>();
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::SOFTMAX; options->activation = ImageSegmenterOptions::Activation::SOFTMAX;
@ -290,8 +287,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) {
auto options = std::make_unique<ImageSegmenterOptions>(); auto options = std::make_unique<ImageSegmenterOptions>();
options->base_options.model_asset_path = options->base_options.model_asset_path =
JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata);
options->base_options.op_resolver =
absl::make_unique<SelfieSegmentationModelOpResolver>();
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK;
options->activation = ImageSegmenterOptions::Activation::NONE; options->activation = ImageSegmenterOptions::Activation::NONE;
MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter, MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<ImageSegmenter> segmenter,

View File

@ -11,3 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "category",
srcs = ["Category.java"],
deps = [
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)
android_library(
name = "detection",
srcs = ["Detection.java"],
deps = [
":category",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,86 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import com.google.auto.value.AutoValue;
import java.util.Objects;
/**
* Category is a util class that contains a category name, its display name, a float score, and
* the index of the label in the corresponding label file. Typically it is used as a result of
* classification or detection tasks.
*/
@AutoValue
public abstract class Category {
private static final float TOLERANCE = 1e-6f;
/**
* Creates a {@link Category} instance.
*
* @param score the probability score of this label category.
* @param index the index of the label in the corresponding label file.
* @param categoryName the label of this category object.
* @param displayName the display name of the label.
*/
public static Category create(float score, int index, String categoryName, String displayName) {
return new AutoValue_Category(score, index, categoryName, displayName);
}
/** The probability score of this label category. */
public abstract float score();
/** The index of the label in the corresponding label file. Returns -1 if the index is not set. */
public abstract int index();
/** The label of this category object. */
public abstract String categoryName();
/**
* The display name of the label, which may be translated for different locales. For example, a
* label, "apple", may be translated into Spanish for display purpose, so that the display name is
* "manzana".
*/
public abstract String displayName();
@Override
public final boolean equals(Object o) {
if (!(o instanceof Category)) {
return false;
}
Category other = (Category) o;
return Math.abs(other.score() - this.score()) < TOLERANCE
&& other.index() == this.index()
&& other.categoryName().equals(this.categoryName())
&& other.displayName().equals(this.displayName());
}
@Override
public final int hashCode() {
return Objects.hash(categoryName(), displayName(), score(), index());
}
@Override
public final String toString() {
return "<Category \""
+ categoryName()
+ "\" (displayName="
+ displayName()
+ " score="
+ score()
+ " index="
+ index()
+ ")>";
}
}

View File

@ -0,0 +1,50 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.components.containers;
import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import java.util.Collections;
import java.util.List;
/**
* Represents one detected object in the results of {@link
* com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector}.
*/
@AutoValue
public abstract class Detection {
/**
* Creates a {@link Detection} instance from a list of {@link Category} and a bounding box.
*
* @param categories a list of {@link Category} objects that contain category name, display name,
* score, and the label index.
* @param boundingBox a {@link RectF} object to represent the bounding box.
*/
public static Detection create(List<Category> categories, RectF boundingBox) {
// As an open source project, we've been trying to avoid depending on common Java libraries,
// such as Guava, because they may introduce conflicts with clients who also happen to use
// those libraries. Therefore, instead of using ImmutableList here, we convert the List into an
// unmodifiable list.
return new AutoValue_Detection(Collections.unmodifiableList(categories), boundingBox);
}
/** A list of {@link Category} objects. */
public abstract List<Category> categories();
/** A {@link RectF} object to represent the bounding box of the detected object. */
public abstract RectF boundingBox();
}

View File

@ -11,3 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["*.java"]),
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
"//mediapipe/framework:calculator_java_proto_lite",
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/tasks/cc/core/proto:acceleration_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/core/proto:external_file_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
"//third_party:autovalue",
"@com_google_protobuf//:protobuf_javalite",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,96 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import com.google.auto.value.AutoValue;
import java.nio.ByteBuffer;
import java.nio.MappedByteBuffer;
import java.util.Optional;
/** Options to configure MediaPipe Tasks in general. */
@AutoValue
public abstract class BaseOptions {
/** Builder for {@link BaseOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/**
* Sets the model path to a tflite model with metadata in the assets.
*
* <p>Note: when model path is set, both model file descriptor and model buffer should be empty.
*/
public abstract Builder setModelAssetPath(String value);
/**
* Sets the native file descriptor (an int) of a tflite model with metadata.
*
* <p>Note: when model file descriptor is set, both model path and model buffer should be empty.
*/
public abstract Builder setModelAssetFileDescriptor(Integer value);
/**
* Sets either the direct {@link ByteBuffer} or the {@link MappedByteBuffer} of a tflite model
* with metadata.
*
* <p>Note: when model buffer is set, both model path and model file descriptor should be empty.
*/
public abstract Builder setModelAssetBuffer(ByteBuffer value);
/**
* Sets device Delegate to run the MediaPipe pipeline. If the delegate is not set, default
* delegate CPU is used.
*/
public abstract Builder setDelegate(Delegate delegate);
abstract BaseOptions autoBuild();
/**
* Validates and builds the {@link BaseOptions} instance.
*
* @throws IllegalArgumentException if {@link BaseOptions} is invalid, or the provided model
* buffer is not a direct {@link ByteBuffer} or a {@link MappedByteBuffer}.
*/
public final BaseOptions build() {
BaseOptions options = autoBuild();
int modelAssetPathPresent = options.modelAssetPath().isPresent() ? 1 : 0;
int modelAssetFileDescriptorPresent = options.modelAssetFileDescriptor().isPresent() ? 1 : 0;
int modelAssetBufferPresent = options.modelAssetBuffer().isPresent() ? 1 : 0;
if (modelAssetPathPresent + modelAssetFileDescriptorPresent + modelAssetBufferPresent != 1) {
throw new IllegalArgumentException(
"Please specify only one of the model asset path, the model asset file descriptor, and"
+ " the model asset buffer.");
}
if (options.modelAssetBuffer().isPresent()
&& !(options.modelAssetBuffer().get().isDirect()
|| options.modelAssetBuffer().get() instanceof MappedByteBuffer)) {
throw new IllegalArgumentException(
"The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
}
return options;
}
}
abstract Optional<String> modelAssetPath();
abstract Optional<Integer> modelAssetFileDescriptor();
abstract Optional<ByteBuffer> modelAssetBuffer();
abstract Delegate delegate();
public static Builder builder() {
return new AutoValue_BaseOptions.Builder().setDelegate(Delegate.CPU);
}
}
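A minimal usage sketch (not part of this change): driving the builder above from an Android asset path. The asset name "model.tflite" is a placeholder, and setting the delegate is optional since CPU is already the default.

import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.Delegate;

final class BaseOptionsExample {
  static BaseOptions fromAssetPath() {
    // build() enforces that exactly one of model asset path, file descriptor, or buffer is set.
    return BaseOptions.builder()
        .setModelAssetPath("model.tflite") // placeholder asset name
        .setDelegate(Delegate.CPU) // CPU is also the default when no delegate is set
        .build();
  }
}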

View File

@ -0,0 +1,22 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
/** MediaPipe Tasks delegate. */
// TODO implement advanced delegate setting.
public enum Delegate {
CPU,
GPU,
}

View File

@ -0,0 +1,20 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
/** Interface for the customizable MediaPipe task error listener. */
public interface ErrorListener {
void onError(RuntimeException e);
}
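Since ErrorListener declares a single method, a lambda is enough. A hypothetical logging listener might look like this sketch; real applications may instead surface the error to the UI.

import com.google.mediapipe.tasks.core.ErrorListener;

final class LoggingErrorListenerExample {
  // Hypothetical listener that only logs the task error.
  static final ErrorListener LOGGING_LISTENER =
      e -> System.err.println("MediaPipe task error: " + e.getMessage());
}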

View File

@ -0,0 +1,49 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import java.util.concurrent.atomic.AtomicBoolean;
/** Facilitates creation and destruction of the native ModelResourcesCache. */
class ModelResourcesCache {
private final long nativeHandle;
private final AtomicBoolean isHandleValid;
public ModelResourcesCache() {
nativeHandle = nativeCreateModelResourcesCache();
isHandleValid = new AtomicBoolean(true);
}
public boolean isHandleValid() {
return isHandleValid.get();
}
public long getNativeHandle() {
if (isHandleValid.get()) {
return nativeHandle;
}
return 0;
}
public void release() {
if (isHandleValid.compareAndSet(true, false)) {
nativeReleaseModelResourcesCache(nativeHandle);
}
}
private native long nativeCreateModelResourcesCache();
private native void nativeReleaseModelResourcesCache(long nativeHandle);
}

View File

@ -0,0 +1,29 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import com.google.mediapipe.framework.GraphService;
/** Java wrapper for the native ModelResourcesCacheService graph service. */
class ModelResourcesCacheService implements GraphService<ModelResourcesCache> {
public ModelResourcesCacheService() {}
@Override
public void installServiceObject(long context, ModelResourcesCache object) {
nativeInstallServiceObject(context, object.getNativeHandle());
}
public native void nativeInstallServiceObject(long context, long object);
}

View File

@ -0,0 +1,130 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import android.util.Log;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import java.util.List;
/** Base class for handling MediaPipe task graph outputs. */
public class OutputHandler<OutputT extends TaskResult, InputT> {
/**
* Interface for converting MediaPipe graph output {@link Packet}s to a task result object and a
* task input object.
*/
public interface OutputPacketConverter<OutputT extends TaskResult, InputT> {
OutputT convertToTaskResult(List<Packet> packets);
InputT convertToTaskInput(List<Packet> packets);
}
/** Interface for the customizable MediaPipe task result listener. */
public interface ResultListener<OutputT extends TaskResult, InputT> {
void run(OutputT result, InputT input);
}
private static final String TAG = "OutputHandler";
// A task-specific graph output packet converter that should be implemented per task.
private OutputPacketConverter<OutputT, InputT> outputPacketConverter;
// The user-defined task result listener.
private ResultListener<OutputT, InputT> resultListener;
// The user-defined error listener.
protected ErrorListener errorListener;
// The cached task result for non-latency-sensitive use cases.
protected OutputT cachedTaskResult;
// Whether the output handler should react to timestamp-bound changes by outputting empty packets.
private boolean handleTimestampBoundChanges = false;
/**
* Sets a callback to be invoked to convert a {@link Packet} list to a task result object and a
* task input object.
*
* @param converter the task-specific {@link OutputPacketConverter} callback.
*/
public void setOutputPacketConverter(OutputPacketConverter<OutputT, InputT> converter) {
this.outputPacketConverter = converter;
}
/**
* Sets a callback to be invoked when task result objects become available.
*
* @param listener the user-defined {@link ResultListener} callback.
*/
public void setResultListener(ResultListener<OutputT, InputT> listener) {
this.resultListener = listener;
}
/**
* Sets a callback to be invoked when exceptions are thrown from the task graph.
*
* @param listener The user-defined {@link ErrorListener} callback.
*/
public void setErrorListener(ErrorListener listener) {
this.errorListener = listener;
}
/**
* Sets whether the output handler should react to the timestamp bound changes that are represented
* as empty output {@link Packet}s.
*
* @param handleTimestampBoundChanges A boolean value.
*/
public void setHandleTimestampBoundChanges(boolean handleTimestampBoundChanges) {
this.handleTimestampBoundChanges = handleTimestampBoundChanges;
}
/** Returns true if the task graph is set to handle timestamp bound changes. */
boolean handleTimestampBoundChanges() {
return handleTimestampBoundChanges;
}
/* Returns the cached task result object. */
public OutputT retrieveCachedTaskResult() {
OutputT taskResult = cachedTaskResult;
cachedTaskResult = null;
return taskResult;
}
/**
* Handles a list of output {@link Packet}s. Invoked when a packet list becomes available.
*
* @param packets A list of output {@link Packet}s.
*/
void run(List<Packet> packets) {
OutputT taskResult = null;
try {
taskResult = outputPacketConverter.convertToTaskResult(packets);
if (resultListener == null) {
cachedTaskResult = taskResult;
} else {
InputT taskInput = outputPacketConverter.convertToTaskInput(packets);
resultListener.run(taskResult, taskInput);
}
} catch (MediaPipeException e) {
if (errorListener != null) {
errorListener.onError(e);
} else {
Log.e(TAG, "Error occurred when getting the MediaPipe task result: " + e);
}
} finally {
for (Packet packet : packets) {
if (packet != null) {
packet.release();
}
}
}
}
}
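A sketch of wiring a converter and a result listener into an OutputHandler. The single string output stream and the TextResult type are hypothetical; only PacketGetter.getString and Packet.getTimestamp from the framework are assumed here.

import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskResult;
import java.util.List;

final class OutputHandlerExample {
  // Toy TaskResult carrying a single string; purely illustrative.
  static final class TextResult implements TaskResult {
    private final String text;
    private final long timestampMs;

    TextResult(String text, long timestampMs) {
      this.text = text;
      this.timestampMs = timestampMs;
    }

    String text() {
      return text;
    }

    @Override
    public long timestampMs() {
      return timestampMs;
    }
  }

  static OutputHandler<TextResult, Void> createHandler() {
    OutputHandler<TextResult, Void> handler = new OutputHandler<>();
    handler.setOutputPacketConverter(
        new OutputHandler.OutputPacketConverter<TextResult, Void>() {
          @Override
          public TextResult convertToTaskResult(List<Packet> packets) {
            // Assumes the graph's first output stream carries a string packet.
            Packet packet = packets.get(0);
            return new TextResult(PacketGetter.getString(packet), packet.getTimestamp());
          }

          @Override
          public Void convertToTaskInput(List<Packet> packets) {
            return null; // No task input is echoed back in this sketch.
          }
        });
    // With a result listener set, results are delivered here instead of being cached.
    handler.setResultListener((result, unused) -> System.out.println(result.text()));
    return handler;
  }
}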

View File

@ -0,0 +1,156 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig;
import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig.Node;
import com.google.mediapipe.proto.CalculatorProto.InputStreamInfo;
import com.google.mediapipe.calculator.proto.FlowLimiterCalculatorProto.FlowLimiterCalculatorOptions;
import java.util.ArrayList;
import java.util.List;
/**
* {@link TaskInfo} contains all needed information to initialize a MediaPipe Task {@link
* com.google.mediapipe.framework.Graph}.
*/
@AutoValue
public abstract class TaskInfo<T extends TaskOptions> {
/** Builder for {@link TaskInfo}. */
@AutoValue.Builder
public abstract static class Builder<T extends TaskOptions> {
/** Sets the MediaPipe task graph name. */
public abstract Builder<T> setTaskGraphName(String value);
/** Sets a list of task graph input stream info {@link String}s in the form TAG:name. */
public abstract Builder<T> setInputStreams(List<String> value);
/** Sets a list of task graph output stream info {@link String}s in the form TAG:name. */
public abstract Builder<T> setOutputStreams(List<String> value);
/** Sets whether the task requires a flow limiter. */
public abstract Builder<T> setEnableFlowLimiting(Boolean value);
/**
* Sets a task-specific options instance.
*
* @param value a task-specific options that is derived from {@link TaskOptions}.
*/
public abstract Builder<T> setTaskOptions(T value);
public abstract TaskInfo<T> autoBuild();
/**
* Validates and builds the {@link TaskInfo} instance.
*
* @throws IllegalArgumentException if the required information such as task graph name, graph
* input streams, and the graph output streams are empty.
*/
public final TaskInfo<T> build() {
TaskInfo<T> taskInfo = autoBuild();
if (taskInfo.taskGraphName().isEmpty()
|| taskInfo.inputStreams().isEmpty()
|| taskInfo.outputStreams().isEmpty()) {
throw new IllegalArgumentException(
"Task graph's name, input streams, and output streams should be non-empty.");
}
return taskInfo;
}
}
abstract String taskGraphName();
abstract T taskOptions();
abstract List<String> inputStreams();
abstract List<String> outputStreams();
abstract Boolean enableFlowLimiting();
public static <T extends TaskOptions> Builder<T> builder() {
return new AutoValue_TaskInfo.Builder<T>();
}
/* Returns a list of the output stream names without the stream tags. */
List<String> outputStreamNames() {
List<String> streamNames = new ArrayList<>(outputStreams().size());
for (String stream : outputStreams()) {
streamNames.add(stream.substring(stream.lastIndexOf(':') + 1));
}
return streamNames;
}
/**
* Creates a MediaPipe Task {@link CalculatorGraphConfig} protobuf message from the {@link
* TaskInfo} instance.
*/
CalculatorGraphConfig generateGraphConfig() {
CalculatorGraphConfig.Builder graphBuilder = CalculatorGraphConfig.newBuilder();
Node.Builder taskSubgraphBuilder =
Node.newBuilder()
.setCalculator(taskGraphName())
.setOptions(taskOptions().convertToCalculatorOptionsProto());
for (String outputStream : outputStreams()) {
taskSubgraphBuilder.addOutputStream(outputStream);
graphBuilder.addOutputStream(outputStream);
}
if (!enableFlowLimiting()) {
for (String inputStream : inputStreams()) {
taskSubgraphBuilder.addInputStream(inputStream);
graphBuilder.addInputStream(inputStream);
}
graphBuilder.addNode(taskSubgraphBuilder.build());
return graphBuilder.build();
}
Node.Builder flowLimiterCalculatorBuilder =
Node.newBuilder()
.setCalculator("FlowLimiterCalculator")
.addInputStreamInfo(
InputStreamInfo.newBuilder().setTagIndex("FINISHED").setBackEdge(true).build())
.setOptions(
CalculatorOptions.newBuilder()
.setExtension(
FlowLimiterCalculatorOptions.ext,
FlowLimiterCalculatorOptions.newBuilder()
.setMaxInFlight(1)
.setMaxInQueue(1)
.build())
.build());
for (String inputStream : inputStreams()) {
graphBuilder.addInputStream(inputStream);
flowLimiterCalculatorBuilder.addInputStream(stripTagIndex(inputStream));
String taskInputStream = addStreamNamePrefix(inputStream);
flowLimiterCalculatorBuilder.addOutputStream(stripTagIndex(taskInputStream));
taskSubgraphBuilder.addInputStream(taskInputStream);
}
flowLimiterCalculatorBuilder.addInputStream(
"FINISHED:" + stripTagIndex(outputStreams().get(0)));
graphBuilder.addNode(flowLimiterCalculatorBuilder.build());
graphBuilder.addNode(taskSubgraphBuilder.build());
return graphBuilder.build();
}
private String stripTagIndex(String tagIndexName) {
return tagIndexName.substring(tagIndexName.lastIndexOf(':') + 1);
}
private String addStreamNamePrefix(String tagIndexName) {
return tagIndexName.substring(0, tagIndexName.lastIndexOf(':') + 1)
+ "throttled_"
+ stripTagIndex(tagIndexName);
}
}
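A hedged sketch of assembling a TaskInfo. The graph name, stream names, and the no-op options class are placeholders rather than any shipped task; a real task supplies its own TaskOptions subclass.

import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import java.util.Arrays;

final class TaskInfoExample {
  // Do-nothing options class so the sketch compiles; a real task populates graph options here.
  static final class NoopTaskOptions extends TaskOptions {
    @Override
    public CalculatorOptions convertToCalculatorOptionsProto() {
      return CalculatorOptions.getDefaultInstance();
    }
  }

  static TaskInfo<NoopTaskOptions> createTaskInfo() {
    return TaskInfo.<NoopTaskOptions>builder()
        .setTaskGraphName("mediapipe.tasks.ExampleTaskGraph") // placeholder graph name
        .setInputStreams(Arrays.asList("IMAGE:image_in")) // placeholder stream names
        .setOutputStreams(Arrays.asList("DETECTIONS:detections_out"))
        .setTaskOptions(new NoopTaskOptions())
        .setEnableFlowLimiting(false) // true only for live-stream style tasks
        .build();
  }
}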

View File

@ -0,0 +1,75 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package com.google.mediapipe.tasks.core;
import com.google.mediapipe.calculator.proto.InferenceCalculatorProto;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.tasks.core.proto.AccelerationProto;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.core.proto.ExternalFileProto;
import com.google.protobuf.ByteString;
/**
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
* {@link TaskOptions}.
*/
public abstract class TaskOptions {
/**
* Converts a MediaPipe Tasks task-specific options to a {@link CalculatorOptions} protobuf
* message.
*/
public abstract CalculatorOptions convertToCalculatorOptionsProto();
/**
* Converts a {@link BaseOptions} instance to a {@link BaseOptionsProto.BaseOptions} protobuf
* message.
*/
protected BaseOptionsProto.BaseOptions convertBaseOptionsToProto(BaseOptions options) {
ExternalFileProto.ExternalFile.Builder externalFileBuilder =
ExternalFileProto.ExternalFile.newBuilder();
options.modelAssetPath().ifPresent(externalFileBuilder::setFileName);
options
.modelAssetFileDescriptor()
.ifPresent(
fd ->
externalFileBuilder.setFileDescriptorMeta(
ExternalFileProto.FileDescriptorMeta.newBuilder().setFd(fd).build()));
options
.modelAssetBuffer()
.ifPresent(
modelBuffer -> {
modelBuffer.rewind();
externalFileBuilder.setFileContent(ByteString.copyFrom(modelBuffer));
});
AccelerationProto.Acceleration.Builder accelerationBuilder =
AccelerationProto.Acceleration.newBuilder();
switch (options.delegate()) {
case CPU:
accelerationBuilder.setXnnpack(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Xnnpack
.getDefaultInstance());
break;
case GPU:
accelerationBuilder.setGpu(
InferenceCalculatorProto.InferenceCalculatorOptions.Delegate.Gpu.getDefaultInstance());
break;
}
return BaseOptionsProto.BaseOptions.newBuilder()
.setModelAsset(externalFileBuilder.build())
.setAcceleration(accelerationBuilder.build())
.build();
}
}
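A minimal subclass sketch. ExampleTaskOptions is hypothetical; a real task would embed the converted base options inside its own graph-options proto extension instead of returning a default CalculatorOptions.

import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;

final class ExampleTaskOptions extends TaskOptions {
  private final BaseOptions baseOptions;

  ExampleTaskOptions(BaseOptions baseOptions) {
    this.baseOptions = baseOptions;
  }

  // Exposes the shared base-options conversion; a real task sets this message on its
  // task-specific options proto.
  BaseOptionsProto.BaseOptions baseOptionsProto() {
    return convertBaseOptionsToProto(baseOptions);
  }

  @Override
  public CalculatorOptions convertToCalculatorOptionsProto() {
    // The default instance keeps the sketch compilable without a task-specific options proto;
    // it does not configure a working graph.
    return CalculatorOptions.getDefaultInstance();
  }
}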

View File

@ -0,0 +1,24 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
/**
* Interface for the MediaPipe Task result. Any MediaPipe task-specific result class should
* implement {@link TaskResult}.
*/
public interface TaskResult {
/** Returns the timestamp that is associated with the task result object. */
long timestampMs();
}

View File

@ -0,0 +1,265 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import android.content.Context;
import android.util.Log;
import com.google.mediapipe.framework.AndroidAssetUtil;
import com.google.mediapipe.framework.AndroidPacketCreator;
import com.google.mediapipe.framework.Graph;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
/** The runner of MediaPipe task graphs. */
public class TaskRunner implements AutoCloseable {
private static final String TAG = TaskRunner.class.getSimpleName();
private static final long TIMESTAMP_UNITS_PER_SECOND = 1000000;
private final OutputHandler<? extends TaskResult, ?> outputHandler;
private final AtomicBoolean graphStarted = new AtomicBoolean(false);
private final Graph graph;
private final ModelResourcesCache modelResourcesCache;
private final AndroidPacketCreator packetCreator;
private long lastSeenTimestamp = Long.MIN_VALUE;
private ErrorListener errorListener;
/**
* Creates a {@link TaskRunner} instance.
*
* @param context an Android {@link Context}.
* @param taskInfo a {@link TaskInfo} instance that contains the task graph name, task options,
* and graph input and output stream names.
* @param outputHandler an {@link OutputHandler} instance that handles the task result objects and
* runtime exceptions.
* @throws MediaPipeException for any error during {@link TaskRunner} creation.
*/
public static TaskRunner create(
Context context,
TaskInfo<? extends TaskOptions> taskInfo,
OutputHandler<? extends TaskResult, ?> outputHandler) {
AndroidAssetUtil.initializeNativeAssetManager(context);
Graph mediapipeGraph = new Graph();
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
ModelResourcesCache graphModelResourcesCache = new ModelResourcesCache();
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
mediapipeGraph.addMultiStreamCallback(
taskInfo.outputStreamNames(),
outputHandler::run,
/*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
mediapipeGraph.startRunningGraph();
// Waits until all calculators are opened and the graph is fully started.
mediapipeGraph.waitUntilGraphIdle();
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
}
/**
* Sets a callback to be invoked when exceptions are thrown by the {@link TaskRunner} instance.
*
* @param listener an {@link ErrorListener} callback.
*/
public void setErrorListener(ErrorListener listener) {
this.errorListener = listener;
}
/** Returns the {@link AndroidPacketCreator} associated with the {@link TaskRunner} instance. */
public AndroidPacketCreator getPacketCreator() {
return packetCreator;
}
/**
* A synchronous method for processing batch data.
*
* <p>Note: This method is designed for processing batch data such as unrelated images and texts.
* The call blocks the current thread until a failure status or a successful result is returned.
* An internal timestamp will be assigned per invocation. This method is thread-safe and allows
* clients to call it from different threads.
*
* @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs.
*/
public synchronized TaskResult process(Map<String, Packet> inputs) {
addPackets(inputs, generateSyntheticTimestamp());
graph.waitUntilGraphIdle();
return outputHandler.retrieveCachedTaskResult();
}
/**
* A synchronous method for processing offline streaming data.
*
* <p>Note: This method is designed for processing offline streaming data such as the decoded
* frames from a video file and an audio file. The call blocks the current thread until a failure
* status or a successful result is returned. The caller must ensure that the input timestamp is
* greater than the timestamps of previous invocations. This method is thread-unsafe and it is the
* caller's responsibility to synchronize access to this method across multiple threads and to
* ensure that the input packet timestamps are in order.
*
* @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs.
* @param inputTimestamp the timestamp of the input packets.
*/
public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimestamp(inputTimestamp);
addPackets(inputs, inputTimestamp);
graph.waitUntilGraphIdle();
return outputHandler.retrieveCachedTaskResult();
}
/**
* An asynchronous method for handling live streaming data.
*
* <p>Note: This method is designed for handling live streaming data such as live camera and
* microphone data. A user-defined result listener must be provided when creating the {@link
* TaskRunner} to receive the output packets. The caller must ensure that the input packet
* timestamps are
* monotonically increasing. This method is thread-unsafe and it is the caller's responsibility to
* synchronize access to this method across multiple threads and to ensure that the input packet
* timestamps are in order.
*
* @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs.
* @param inputTimestamp the timestamp of the input packets.
*/
public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
validateInputTimestamp(inputTimestamp);
addPackets(inputs, inputTimestamp);
}
/**
* Resets and restarts the {@link TaskRunner} instance. This can be useful for resetting a
* stateful task graph to process new data.
*/
public void restart() {
if (graphStarted.get()) {
try {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
} catch (MediaPipeException e) {
reportError(e);
}
}
try {
graph.startRunningGraph();
// Waits until all calculators are opened and the graph is fully restarted.
graph.waitUntilGraphIdle();
graphStarted.set(true);
} catch (MediaPipeException e) {
reportError(e);
}
}
/** Closes and cleans up the {@link TaskRunner} instance. */
@Override
public void close() {
if (!graphStarted.get()) {
return;
}
try {
graphStarted.set(false);
graph.closeAllPacketSources();
graph.waitUntilGraphDone();
if (modelResourcesCache != null) {
modelResourcesCache.release();
}
} catch (MediaPipeException e) {
// Note: errors during Process are reported at the earliest opportunity,
// which may be addPacket or waitUntilDone, depending on timing. For consistency,
// we want to always report them using the same async handler if installed.
reportError(e);
}
try {
graph.tearDown();
} catch (MediaPipeException e) {
reportError(e);
}
}
private synchronized void addPackets(Map<String, Packet> inputs, long inputTimestamp) {
if (!graphStarted.get()) {
reportError(
new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"The task graph hasn't been successfully started or an error occurred during graph"
+ " initialization."));
}
try {
for (Map.Entry<String, Packet> entry : inputs.entrySet()) {
// addConsumablePacketToInputStream allows the graph to take exclusive ownership of the
// packet, which may allow for more memory optimizations.
graph.addConsumablePacketToInputStream(entry.getKey(), entry.getValue(), inputTimestamp);
// If addConsumablePacket succeeded, we don't need to release the packet ourselves.
entry.setValue(null);
}
} catch (MediaPipeException e) {
// TODO: do not suppress exceptions here!
if (errorListener == null) {
Log.e(TAG, "Mediapipe error: ", e);
} else {
throw e;
}
} finally {
for (Packet packet : inputs.values()) {
// In case of error, addConsumablePacketToInputStream will not release the packet, so we
// have to release it ourselves.
if (packet != null) {
packet.release();
}
}
}
}
/**
* Checks if the input timestamp is strictly greater than the last timestamp that has been
* processed.
*
* @param inputTimestamp the input timestamp.
*/
private void validateInputTimestamp(long inputTimestamp) {
if (lastSeenTimestamp >= inputTimestamp) {
reportError(
new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"The received packets have a timestamp smaller than the last processed timestamp."));
}
lastSeenTimestamp = inputTimestamp;
}
/** Generates a synthetic input timestamp in the batch processing mode. */
private long generateSyntheticTimestamp() {
long timestamp =
lastSeenTimestamp == Long.MIN_VALUE ? 0 : lastSeenTimestamp + TIMESTAMP_UNITS_PER_SECOND;
lastSeenTimestamp = timestamp;
return timestamp;
}
/** Private constructor. */
private TaskRunner(
Graph graph,
ModelResourcesCache modelResourcesCache,
OutputHandler<? extends TaskResult, ?> outputHandler) {
this.outputHandler = outputHandler;
this.graph = graph;
this.modelResourcesCache = modelResourcesCache;
this.packetCreator = new AndroidPacketCreator(graph);
graphStarted.set(true);
}
/** Reports error. */
private void reportError(MediaPipeException e) {
if (errorListener != null) {
errorListener.onError(e);
} else {
throw e;
}
}
}
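An end-to-end sketch of the batch-processing path, assuming a TaskInfo and an OutputHandler built as in the earlier sketches; the "image_in" stream name is a placeholder.

import android.content.Context;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;

final class TaskRunnerExample {
  static TaskResult runOnce(
      Context context,
      TaskInfo<? extends TaskOptions> taskInfo,
      OutputHandler<? extends TaskResult, ?> outputHandler,
      Packet imagePacket) {
    // TaskRunner implements AutoCloseable, so try-with-resources tears the graph down.
    try (TaskRunner runner = TaskRunner.create(context, taskInfo, outputHandler)) {
      Map<String, Packet> inputs = new HashMap<>();
      inputs.put("image_in", imagePacket); // placeholder input stream name
      // Blocks until the graph is idle and the cached result can be retrieved.
      return runner.process(inputs);
    }
  }
}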

View File

@ -0,0 +1,41 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library_with_tflite(
name = "model_resources_cache_jni",
srcs = [
"model_resources_cache_jni.cc",
],
hdrs = [
"model_resources_cache_jni.h",
],
tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources_cache",
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
] + select({
"//conditions:default": ["//third_party/java/jdk:jni"],
"//mediapipe:android": [],
}),
alwayslink = 1,
)

View File

@ -0,0 +1,72 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
package(default_visibility = ["//mediapipe/tasks:internal"])
cc_library_with_tflite(
name = "model_resources_cache_jni",
srcs = [
"model_resources_cache_jni.cc",
],
hdrs = [
"model_resources_cache_jni.h",
] + select({
# The Android toolchain makes "jni.h" available in the include path.
# For non-Android toolchains, generate jni.h and jni_md.h.
"//mediapipe:android": [],
"//conditions:default": [
":jni.h",
":jni_md.h",
],
}),
tflite_deps = [
"//mediapipe/tasks/cc/core:model_resources_cache",
],
deps = [
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
] + select({
"//conditions:default": [],
"//mediapipe:android": [],
}),
alwayslink = 1,
)
# Silly rules to make
# #include <jni.h>
# in the source headers work
# (in combination with the "hdrs" select of the cc_library_with_tflite rule
# above. Not needed when using the Android toolchain).
#
# Inspired from:
# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD
# but hopefully there is a simpler alternative to this.
genrule(
name = "copy_jni_h",
srcs = ["@bazel_tools//tools/jdk:jni_header"],
outs = ["jni.h"],
cmd = "cp -f $< $@",
)
genrule(
name = "copy_jni_md_h",
srcs = select({
"//mediapipe:macos": ["@bazel_tools//tools/jdk:jni_md_header-darwin"],
"//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"],
}),
outs = ["jni_md.h"],
cmd = "cp -f $< $@",
)

View File

@ -0,0 +1,51 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni/model_resources_cache_jni.h"
#include <utility>
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_service_jni.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/model_resources_cache.h"
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
namespace {
using ::mediapipe::tasks::core::kModelResourcesCacheService;
using ::mediapipe::tasks::core::MediaPipeBuiltinOpResolver;
using ::mediapipe::tasks::core::ModelResourcesCache;
using HandleType = std::shared_ptr<ModelResourcesCache>*;
} // namespace
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz) {
auto ptr = std::make_shared<ModelResourcesCache>(
absl::make_unique<MediaPipeBuiltinOpResolver>());
HandleType handle = new std::shared_ptr<ModelResourcesCache>(std::move(ptr));
return reinterpret_cast<jlong>(handle);
}
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
jlong nativeHandle) {
delete reinterpret_cast<HandleType>(nativeHandle);
}
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
jlong objectHandle) {
mediapipe::android::GraphServiceHelper::SetServiceObject(
contextHandle, kModelResourcesCacheService,
*reinterpret_cast<HandleType>(objectHandle));
}

View File

@ -0,0 +1,45 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
#define JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
#define MODEL_RESOURCES_CACHE_METHOD(METHOD_NAME) \
Java_com_google_mediapipe_tasks_core_ModelResourcesCache_##METHOD_NAME
#define MODEL_RESOURCES_CACHE_SERVICE_METHOD(METHOD_NAME) \
Java_com_google_mediapipe_tasks_core_ModelResourcesCacheService_##METHOD_NAME
JNIEXPORT jlong JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeCreateModelResourcesCache)(JNIEnv* env, jobject thiz);
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_METHOD(
nativeReleaseModelResourcesCache)(JNIEnv* env, jobject thiz,
jlong nativeHandle);
JNIEXPORT void JNICALL MODEL_RESOURCES_CACHE_SERVICE_METHOD(
nativeInstallServiceObject)(JNIEnv* env, jobject thiz, jlong contextHandle,
jlong objectHandle);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // JAVA_COM_GOOGLE_MEDIAPIPE_TASKS_CORE_JNI_MODEL_RESOURCES_CACHE_JNI_H_

View File

@ -11,3 +11,38 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/tasks:internal"])
android_library(
name = "core",
srcs = glob(["*.java"]),
deps = [
":libmediapipe_tasks_vision_jni_lib",
"//mediapipe/java/com/google/mediapipe/framework:android_framework_no_mff",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"@maven//:com_google_guava_guava",
],
)
# The native library of all MediaPipe vision tasks.
cc_binary(
name = "libmediapipe_tasks_vision_jni.so",
linkshared = 1,
linkstatic = 1,
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni",
],
)
cc_library(
name = "libmediapipe_tasks_vision_jni_lib",
srcs = [":libmediapipe_tasks_vision_jni.so"],
alwayslink = 1,
)

View File

@ -0,0 +1,114 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import java.util.HashMap;
import java.util.Map;
/** The base class of MediaPipe vision tasks. */
public class BaseVisionTaskApi implements AutoCloseable {
private static final long MICROSECONDS_PER_MILLISECOND = 1000;
private final TaskRunner runner;
private final RunningMode runningMode;
static {
System.loadLibrary("mediapipe_tasks_vision_jni");
}
/**
* Constructor to initialize a {@link BaseVisionTaskApi} from a {@link TaskRunner} and a vision
* task {@link RunningMode}.
*
* @param runner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
public BaseVisionTaskApi(TaskRunner runner, RunningMode runningMode) {
this.runner = runner;
this.runningMode = runningMode;
}
/**
* A synchronous method to process single image inputs. The call blocks the current thread until a
* failure status or a successful result is returned.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if the task is not in the image mode.
*/
protected TaskResult processImageData(String imageStreamName, Image image) {
if (runningMode != RunningMode.IMAGE) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the image mode. Current running mode: "
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets);
}
/**
* A synchronous method to process continuous video frames. The call blocks the current thread
* until a failure status or a successful result is returned.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the video mode.
*/
protected TaskResult processVideoData(String imageStreamName, Image image, long timestampMs) {
if (runningMode != RunningMode.VIDEO) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the video mode. Current running mode: "
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
return runner.process(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/**
* An asynchronous method to send live stream data to the {@link TaskRunner}. The results will be
* available in the user-defined result listener.
*
* @param imageStreamName the image input stream name.
* @param image a MediaPipe {@link Image} object for processing.
* @param timestampMs the corresponding timestamp of the input image in milliseconds.
* @throws MediaPipeException if the task is not in the live stream mode.
*/
protected void sendLiveStreamData(String imageStreamName, Image image, long timestampMs) {
if (runningMode != RunningMode.LIVE_STREAM) {
throw new MediaPipeException(
MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(),
"Task is not initialized with the live stream mode. Current running mode: "
+ runningMode.name());
}
Map<String, Packet> inputPackets = new HashMap<>();
inputPackets.put(imageStreamName, runner.getPacketCreator().createImage(image));
runner.send(inputPackets, timestampMs * MICROSECONDS_PER_MILLISECOND);
}
/** Closes and cleans up the MediaPipe vision task. */
@Override
public void close() {
runner.close();
}
}
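A hedged sketch of a concrete task built on this base class; ExampleVisionTask, its "image_in" stream name, and the way its TaskRunner is constructed are all hypothetical.

import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;

final class ExampleVisionTask extends BaseVisionTaskApi {
  // Placeholder stream name; a real task uses the name declared in its graph.
  private static final String IMAGE_STREAM = "image_in";

  ExampleVisionTask(TaskRunner runner) {
    super(runner, RunningMode.IMAGE);
  }

  TaskResult detectOn(Image image) {
    // Delegates to the protected helper, which checks the running mode and blocks for the result.
    return processImageData(IMAGE_STREAM, image);
  }
}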

View File

@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.core;
/**
* MediaPipe vision task running mode. A MediaPipe vision task can be run with three different
* modes:
*
* <ul>
* <li>IMAGE: The mode for running a mediapipe vision task on single image inputs.
* <li>VIDEO: The mode for running a mediapipe vision task on the decoded frames of a video.
* <li>LIVE_STREAM: The mode for running a mediapipe vision task on a live stream of input data,
* such as from camera.
* </ul>
*/
public enum RunningMode {
IMAGE,
VIDEO,
LIVE_STREAM
}

View File

@ -11,3 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "objectdetector",
srcs = [
"ObjectDetectionResult.java",
"ObjectDetector.java",
],
javacopts = [
"-Xep:AndroidJdkLibsChecker:OFF",
],
manifest = ":AndroidManifest.xml",
deps = [
"//mediapipe/framework:calculator_options_java_proto_lite",
"//mediapipe/framework/formats:detection_java_proto_lite",
"//mediapipe/framework/formats:location_data_java_proto_lite",
"//mediapipe/java/com/google/mediapipe/framework:android_framework",
"//mediapipe/java/com/google/mediapipe/framework/image",
"//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite",
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:category",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/core",
"//mediapipe/tasks/java/com/google/mediapipe/tasks/vision/core",
"//third_party:autovalue",
"@maven//:com_google_guava_guava",
],
)

View File

@ -0,0 +1,75 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import android.graphics.RectF;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.core.TaskResult;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.formats.proto.LocationDataProto.LocationData.BoundingBox;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/** Represents the detection results generated by {@link ObjectDetector}. */
@AutoValue
public abstract class ObjectDetectionResult implements TaskResult {
private static final int DEFAULT_CATEGORY_INDEX = -1;
@Override
public abstract long timestampMs();
public abstract List<com.google.mediapipe.tasks.components.containers.Detection> detections();
/**
* Creates an {@link ObjectDetectionResult} instance from a list of {@link Detection} protobuf
* messages.
*
* @param detectionList a list of {@link Detection} protobuf messages.
*/
static ObjectDetectionResult create(List<Detection> detectionList, long timestampMs) {
List<com.google.mediapipe.tasks.components.containers.Detection> detections = new ArrayList<>();
for (Detection detectionProto : detectionList) {
List<Category> categories = new ArrayList<>();
for (int idx = 0; idx < detectionProto.getScoreCount(); ++idx) {
categories.add(
Category.create(
detectionProto.getScore(idx),
detectionProto.getLabelIdCount() > idx
? detectionProto.getLabelId(idx)
: DEFAULT_CATEGORY_INDEX,
detectionProto.getLabelCount() > idx ? detectionProto.getLabel(idx) : "",
detectionProto.getDisplayNameCount() > idx
? detectionProto.getDisplayName(idx)
: ""));
}
RectF boundingBox = new RectF();
if (detectionProto.getLocationData().hasBoundingBox()) {
BoundingBox boundingBoxProto = detectionProto.getLocationData().getBoundingBox();
boundingBox.set(
/*left=*/ boundingBoxProto.getXmin(),
/*top=*/ boundingBoxProto.getYmin(),
/*right=*/ boundingBoxProto.getXmin() + boundingBoxProto.getWidth(),
/*bottom=*/ boundingBoxProto.getYmin() + boundingBoxProto.getHeight());
}
detections.add(
com.google.mediapipe.tasks.components.containers.Detection.create(
categories, boundingBox));
}
return new AutoValue_ObjectDetectionResult(
timestampMs, Collections.unmodifiableList(detections));
}
}
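A consumption sketch. It assumes the container types expose accessors mirroring the create() factories used above (the detection's categories() and boundingBox(), the category's score() and categoryName()), and that each detection carries at least one category.

import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetectionResult;

final class ObjectDetectionResultExample {
  // Prints the top category and bounding box of every detection in the result.
  static void printResult(ObjectDetectionResult result) {
    for (Detection detection : result.detections()) {
      Category topCategory = detection.categories().get(0);
      System.out.printf(
          "%s (score=%.3f) at %s%n",
          topCategory.categoryName(), topCategory.score(), detection.boundingBox());
    }
  }
}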

View File

@ -0,0 +1,418 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import android.content.Context;
import android.os.ParcelFileDescriptor;
import com.google.auto.value.AutoValue;
import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions;
import com.google.mediapipe.framework.AndroidPacketGetter;
import com.google.mediapipe.framework.Packet;
import com.google.mediapipe.framework.PacketGetter;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.ErrorListener;
import com.google.mediapipe.tasks.core.OutputHandler;
import com.google.mediapipe.tasks.core.OutputHandler.ResultListener;
import com.google.mediapipe.tasks.core.TaskInfo;
import com.google.mediapipe.tasks.core.TaskOptions;
import com.google.mediapipe.tasks.core.TaskRunner;
import com.google.mediapipe.tasks.core.proto.BaseOptionsProto;
import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
* Performs object detection on images.
*
* <p>The API expects a TFLite model with <a
* href="https://www.tensorflow.org/lite/convert/metadata">TFLite Model Metadata.</a>.
*
* <p>The API supports models with one image input tensor and four output tensors. To be more
* specific, here are the requirements.
*
* <ul>
* <li>Input image tensor ({@code kTfLiteUInt8}/{@code kTfLiteFloat32})
* <ul>
* <li>image input of size {@code [batch x height x width x channels]}.
* <li>batch inference is not supported ({@code batch} is required to be 1).
* <li>only RGB inputs are supported ({@code channels} is required to be 3).
* <li>if type is {@code kTfLiteFloat32}, NormalizationOptions are required to be attached
* to the metadata for input normalization.
* </ul>
* <li>Output tensors must be the 4 outputs of a {@code DetectionPostProcess} op, i.e:
* <ul>
* <li>Location tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results x 4]}, the inner array representing
* bounding boxes in the form [top, left, right, bottom].
* <li>{@code BoundingBoxProperties} are required to be attached to the metadata and
* must specify {@code type=BOUNDARIES} and {@code coordinate_type=RATIO}.
* </ul>
* <li>Classes tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results]}, each value representing the integer
* index of a class.
* <li>if label maps are attached to the metadata as {@code TENSOR_VALUE_LABELS}
* associated files, they are used to convert the tensor values into labels.
* </ul>
* <li>scores tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>tensor of size {@code [1 x num_results]}, each value representing the score of
* the detected object.
* </ul>
* <li>Number of detection tensor ({@code kTfLiteFloat32}):
* <ul>
* <li>integer num_results as a tensor of size {@code [1]}.
* </ul>
* </ul>
* </ul>
*
* <p>An example of such model can be found on <a
* href="https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1">TensorFlow
* Hub.</a>.
*/
public final class ObjectDetector extends BaseVisionTaskApi {
private static final String TAG = ObjectDetector.class.getSimpleName();
private static final String IMAGE_IN_STREAM_NAME = "image_in";
private static final List<String> INPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME));
private static final List<String> OUTPUT_STREAMS =
Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out"));
private static final int DETECTIONS_OUT_STREAM_INDEX = 0;
private static final int IMAGE_OUT_STREAM_INDEX = 1;
private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.ObjectDetectorGraph";
/**
* Creates an {@link ObjectDetector} instance from a model file and the default {@link
* ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param modelPath path to the detection model with metadata in the assets.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromFile(Context context, String modelPath) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build();
return createFromOptions(
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link ObjectDetector} instance from a model file and the default {@link
* ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param modelFile the detection model {@link File} instance.
* @throws IOException if an I/O error occurs when opening the tflite model file.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromFile(Context context, File modelFile) throws IOException {
try (ParcelFileDescriptor descriptor =
ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
BaseOptions baseOptions =
BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build();
return createFromOptions(
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
}
}
/**
* Creates an {@link ObjectDetector} instance from a model buffer and the default {@link
* ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection
* model.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) {
BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build();
return createFromOptions(
context, ObjectDetectorOptions.builder().setBaseOptions(baseOptions).build());
}
/**
* Creates an {@link ObjectDetector} instance from an {@link ObjectDetectorOptions}.
*
* @param context an Android {@link Context}.
* @param detectorOptions a {@link ObjectDetectorOptions} instance.
* @throws MediaPipeException if there is an error during {@link ObjectDetector} creation.
*/
public static ObjectDetector createFromOptions(
Context context, ObjectDetectorOptions detectorOptions) {
// TODO: Consolidate OutputHandler and TaskRunner.
OutputHandler<ObjectDetectionResult, Image> handler = new OutputHandler<>();
handler.setOutputPacketConverter(
new OutputHandler.OutputPacketConverter<ObjectDetectionResult, Image>() {
@Override
public ObjectDetectionResult convertToTaskResult(List<Packet> packets) {
return ObjectDetectionResult.create(
PacketGetter.getProtoVector(
packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()),
packets.get(DETECTIONS_OUT_STREAM_INDEX).getTimestamp());
}
@Override
public Image convertToTaskInput(List<Packet> packets) {
return new BitmapImageBuilder(
AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX)))
.build();
}
});
detectorOptions.resultListener().ifPresent(handler::setResultListener);
detectorOptions.errorListener().ifPresent(handler::setErrorListener);
TaskRunner runner =
TaskRunner.create(
context,
TaskInfo.<ObjectDetectorOptions>builder()
.setTaskGraphName(TASK_GRAPH_NAME)
.setInputStreams(INPUT_STREAMS)
.setOutputStreams(OUTPUT_STREAMS)
.setTaskOptions(detectorOptions)
.setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM)
.build(),
handler);
detectorOptions.errorListener().ifPresent(runner::setErrorListener);
return new ObjectDetector(runner, detectorOptions.runningMode());
}
/**
* Constructor to initialize an {@link ObjectDetector} from a {@link TaskRunner} and a {@link
* RunningMode}.
*
* @param taskRunner a {@link TaskRunner}.
* @param runningMode a mediapipe vision task {@link RunningMode}.
*/
private ObjectDetector(TaskRunner taskRunner, RunningMode runningMode) {
super(taskRunner, runningMode);
}
/**
* Performs object detection on the provided single image. Only use this method when the {@link
* ObjectDetector} is created with {@link RunningMode.IMAGE}.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detect(Image inputImage) {
return (ObjectDetectionResult) processImageData(IMAGE_IN_STREAM_NAME, inputImage);
}
/**
* Performs object detection on the provided video frame. Only use this method when the {@link
* ObjectDetector} is created with {@link RunningMode.VIDEO}.
*
* <p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps
* must be monotonically increasing.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public ObjectDetectionResult detectForVideo(Image inputImage, long inputTimestampMs) {
return (ObjectDetectionResult)
processVideoData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
}
/**
* Sends live image data to perform object detection, and the results will be available via the
* {@link ResultListener} provided in the {@link ObjectDetectorOptions}. Only use this method when
* the {@link ObjectDetector} is created with {@link RunningMode.LIVE_STREAM}.
*
* <p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is
* sent to the object detector. The input timestamps must be monotonically increasing.
*
* <p>{@link ObjectDetector} supports the following color space types:
*
* <ul>
* <li>{@link Bitmap.Config.ARGB_8888}
* </ul>
*
* @param inputImage a MediaPipe {@link Image} object for processing.
* @param inputTimestampMs the input timestamp (in milliseconds).
* @throws MediaPipeException if there is an internal error.
*/
public void detectAsync(Image inputImage, long inputTimestampMs) {
sendLiveStreamData(IMAGE_IN_STREAM_NAME, inputImage, inputTimestampMs);
}
/** Options for setting up an {@link ObjectDetector}. */
@AutoValue
public abstract static class ObjectDetectorOptions extends TaskOptions {
/** Builder for {@link ObjectDetectorOptions}. */
@AutoValue.Builder
public abstract static class Builder {
/** Sets the base options for the object detector task. */
public abstract Builder setBaseOptions(BaseOptions value);
/**
       * Sets the running mode for the object detector task. Defaults to the image mode. Object
* detector has three modes:
*
* <ul>
* <li>IMAGE: The mode for detecting objects on single image inputs.
* <li>VIDEO: The mode for detecting objects on the decoded frames of a video.
       *   <li>LIVE_STREAM: The mode for detecting objects on a live stream of input data, such
       *       as from a camera. In this mode, {@code setResultListener} must be called to set up a
* listener to receive the detection results asynchronously.
* </ul>
*/
public abstract Builder setRunningMode(RunningMode value);
/**
* Sets the locale to use for display names specified through the TFLite Model Metadata, if
* any. Defaults to English.
*/
public abstract Builder setDisplayNamesLocale(String value);
/**
* Sets the optional maximum number of top-scored detection results to return.
*
       * <p>Overrides the value provided in the model metadata, if any.
*/
public abstract Builder setMaxResults(Integer value);
/**
* Sets the optional score threshold that overrides the one provided in the model metadata (if
* any). Results below this value are rejected.
*/
public abstract Builder setScoreThreshold(Float value);
/**
* Sets the optional allowlist of category names.
*
* <p>If non-empty, detection results whose category name is not in this set will be filtered
* out. Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryDenylist}.
*/
public abstract Builder setCategoryAllowlist(List<String> value);
/**
* Sets the optional denylist of category names.
*
* <p>If non-empty, detection results whose category name is in this set will be filtered out.
* Duplicate or unknown category names are ignored. Mutually exclusive with {@code
* categoryAllowlist}.
*/
public abstract Builder setCategoryDenylist(List<String> value);
/**
* Sets the result listener to receive the detection results asynchronously when the object
* detector is in the live stream mode.
*/
public abstract Builder setResultListener(ResultListener<ObjectDetectionResult, Image> value);
/** Sets an optional error listener. */
public abstract Builder setErrorListener(ErrorListener value);
abstract ObjectDetectorOptions autoBuild();
/**
* Validates and builds the {@link ObjectDetectorOptions} instance.
*
* @throws IllegalArgumentException if the result listener and the running mode are not
* properly configured. The result listener should only be set when the object detector is
* in the live stream mode.
*/
public final ObjectDetectorOptions build() {
ObjectDetectorOptions options = autoBuild();
if (options.runningMode() == RunningMode.LIVE_STREAM) {
if (!options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The object detector is in the live stream mode, a user-defined result listener"
+ " must be provided in ObjectDetectorOptions.");
}
} else if (options.resultListener().isPresent()) {
throw new IllegalArgumentException(
"The object detector is in the image or the video mode, a user-defined result"
+ " listener shouldn't be provided in ObjectDetectorOptions.");
}
return options;
}
}
abstract BaseOptions baseOptions();
abstract RunningMode runningMode();
abstract Optional<String> displayNamesLocale();
abstract Optional<Integer> maxResults();
abstract Optional<Float> scoreThreshold();
abstract List<String> categoryAllowlist();
abstract List<String> categoryDenylist();
abstract Optional<ResultListener<ObjectDetectionResult, Image>> resultListener();
abstract Optional<ErrorListener> errorListener();
public static Builder builder() {
return new AutoValue_ObjectDetector_ObjectDetectorOptions.Builder()
.setRunningMode(RunningMode.IMAGE)
.setCategoryAllowlist(Collections.emptyList())
.setCategoryDenylist(Collections.emptyList());
}
    /** Converts an {@link ObjectDetectorOptions} to a {@link CalculatorOptions} protobuf message. */
@Override
public CalculatorOptions convertToCalculatorOptionsProto() {
BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder =
BaseOptionsProto.BaseOptions.newBuilder();
baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE);
baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions()));
ObjectDetectorOptionsProto.ObjectDetectorOptions.Builder taskOptionsBuilder =
ObjectDetectorOptionsProto.ObjectDetectorOptions.newBuilder()
.setBaseOptions(baseOptionsBuilder);
displayNamesLocale().ifPresent(taskOptionsBuilder::setDisplayNamesLocale);
maxResults().ifPresent(taskOptionsBuilder::setMaxResults);
scoreThreshold().ifPresent(taskOptionsBuilder::setScoreThreshold);
if (!categoryAllowlist().isEmpty()) {
taskOptionsBuilder.addAllCategoryAllowlist(categoryAllowlist());
}
if (!categoryDenylist().isEmpty()) {
taskOptionsBuilder.addAllCategoryDenylist(categoryDenylist());
}
return CalculatorOptions.newBuilder()
.setExtension(
ObjectDetectorOptionsProto.ObjectDetectorOptions.ext, taskOptionsBuilder.build())
.build();
}
}
}
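For reference, the options and detection APIs above reduce to a small amount of caller code. The following is a minimal usage sketch, not part of this commit: it assumes the app bundles a detection model as an asset named "model.tflite", already holds an Android `Context` in `context`, and has wrapped the input frame in a MediaPipe `Image` (`mpImage`), e.g. via `BitmapImageBuilder`.

```java
// Minimal sketch; "model.tflite", context, and mpImage are assumptions, not part of the diff.
ObjectDetectorOptions imageOptions =
    ObjectDetectorOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
        .setRunningMode(RunningMode.IMAGE) // default mode, shown here for clarity
        .setMaxResults(5)
        .setScoreThreshold(0.5f)
        .build();
ObjectDetector detector = ObjectDetector.createFromOptions(context, imageOptions);
// Image mode is synchronous: the result is returned directly.
ObjectDetectionResult result = detector.detect(mpImage);

// Live-stream mode instead requires a result listener and detectAsync() with a
// monotonically increasing timestamp in milliseconds.
ObjectDetectorOptions streamOptions =
    ObjectDetectorOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
        .setRunningMode(RunningMode.LIVE_STREAM)
        .setResultListener((detectionResult, inputImage) -> { /* consume results here */ })
        .build();
ObjectDetector streamDetector = ObjectDetector.createFromOptions(context, streamOptions);
streamDetector.detectAsync(mpImage, /* inputTimestampMs= */ 0);
```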

View File

@ -11,3 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
android_library(
name = "test_utils",
srcs = ["TestUtils.java"],
deps = [
"//third_party/java/android_libs/guava_jdk5:io",
],
)

View File

@ -0,0 +1,83 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.core;
import android.content.Context;
import android.content.res.AssetManager;
import com.google.common.io.ByteStreams;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
/** Helper class for the Java tests in MediaPipe Tasks. */
public final class TestUtils {
/**
   * Loads a file from the asset directory and creates a {@link File} object for it.
* Simulates downloading or reading a file that's not precompiled with the app.
*
* @return a {@link File} object for the model.
*/
public static File loadFile(Context context, String fileName) {
File target = new File(context.getFilesDir(), fileName);
try (InputStream is = context.getAssets().open(fileName);
FileOutputStream os = new FileOutputStream(target)) {
ByteStreams.copy(is, os);
} catch (IOException e) {
throw new AssertionError("Failed to load model file at " + fileName, e);
}
return target;
}
/**
* Reads a file into a direct {@link ByteBuffer} object from the asset directory.
*
* @return a {@link ByteBuffer} object for the file.
*/
public static ByteBuffer loadToDirectByteBuffer(Context context, String fileName)
throws IOException {
AssetManager assetManager = context.getAssets();
InputStream inputStream = assetManager.open(fileName);
byte[] bytes = ByteStreams.toByteArray(inputStream);
ByteBuffer buffer = ByteBuffer.allocateDirect(bytes.length).order(ByteOrder.nativeOrder());
buffer.put(bytes);
return buffer;
}
/**
* Reads a file into a non-direct {@link ByteBuffer} object from the asset directory.
*
* @return a {@link ByteBuffer} object for the file.
*/
public static ByteBuffer loadToNonDirectByteBuffer(Context context, String fileName)
throws IOException {
AssetManager assetManager = context.getAssets();
InputStream inputStream = assetManager.open(fileName);
byte[] bytes = ByteStreams.toByteArray(inputStream);
return ByteBuffer.wrap(bytes);
}
public enum ByteBufferType {
DIRECT,
BACK_UP_ARRAY,
OTHER // Non-direct ByteBuffer without a back-up array.
}
private TestUtils() {}
}
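These helpers back the File- and ByteBuffer-based factory methods exercised by the object detector test later in this commit. A short sketch of the intended use, assuming an AndroidJUnit4 test where MODEL_FILE names whatever .tflite asset the test ships:

```java
// Sketch only; MODEL_FILE is an assumed asset name, and the calls mirror the test below.
Context context = ApplicationProvider.getApplicationContext();

// Copy the asset into app-local storage and build a detector from a java.io.File.
File modelFile = TestUtils.loadFile(context, MODEL_FILE);
ObjectDetector fromFile = ObjectDetector.createFromFile(context, modelFile);

// Read the asset into a direct ByteBuffer and build a detector from the buffer;
// a non-direct buffer (loadToNonDirectByteBuffer) is rejected by createFromBuffer.
ByteBuffer modelBuffer = TestUtils.loadToDirectByteBuffer(context, MODEL_FILE);
ObjectDetector fromBuffer = ObjectDetector.createFromBuffer(context, modelBuffer);
```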

View File

@ -12,4 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
# TODO: Enable this in OSS

View File

@ -0,0 +1,456 @@
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.mediapipe.tasks.vision.objectdetector;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
import android.content.res.AssetManager;
import android.graphics.BitmapFactory;
import android.graphics.RectF;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.mediapipe.framework.MediaPipeException;
import com.google.mediapipe.framework.image.BitmapImageBuilder;
import com.google.mediapipe.framework.image.Image;
import com.google.mediapipe.tasks.components.containers.Category;
import com.google.mediapipe.tasks.components.containers.Detection;
import com.google.mediapipe.tasks.core.BaseOptions;
import com.google.mediapipe.tasks.core.TestUtils;
import com.google.mediapipe.tasks.vision.core.RunningMode;
import com.google.mediapipe.tasks.vision.objectdetector.ObjectDetector.ObjectDetectorOptions;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.junit.runners.Suite.SuiteClasses;
/** Test for {@link ObjectDetector}. */
@RunWith(Suite.class)
@SuiteClasses({ObjectDetectorTest.General.class, ObjectDetectorTest.RunningModeTest.class})
public class ObjectDetectorTest {
private static final String MODEL_FILE = "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite";
private static final String CAT_AND_DOG_IMAGE = "cats_and_dogs.jpg";
private static final int IMAGE_WIDTH = 1200;
private static final int IMAGE_HEIGHT = 600;
private static final float CAT_SCORE = 0.69f;
private static final RectF catBoundingBox = new RectF(611, 164, 986, 596);
// TODO: Figure out why android_x86 and android_arm tests have slightly different
// scores (0.6875 vs 0.69921875).
private static final float SCORE_DIFF_TOLERANCE = 0.01f;
private static final float PIXEL_DIFF_TOLERANCE = 5.0f;
@RunWith(AndroidJUnit4.class)
public static final class General extends ObjectDetectorTest {
@Test
public void detect_successWithValidModels() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void detect_successWithNoOptions() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(ApplicationProvider.getApplicationContext(), MODEL_FILE);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      // Check that the object with the highest score is a cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@Test
public void detect_succeedsWithMaxResultsOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setMaxResults(8)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
// results should have 8 detected objects because maxResults was set to 8.
assertThat(results.detections()).hasSize(8);
}
@Test
public void detect_succeedsWithScoreThresholdOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setScoreThreshold(0.68f)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      // The score threshold should block all other objects except the cat.
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void detect_succeedsWithAllowListOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("cat"))
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      // Because of the allowlist, results should only contain the cat category, and there are 5
      // detected bounding boxes of cats in CAT_AND_DOG_IMAGE.
assertThat(results.detections()).hasSize(5);
}
@Test
public void detect_succeedsWithDenyListOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setCategoryDenylist(Arrays.asList("cat"))
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      // Because of the denylist, the top-scoring result is no longer a cat.
assertThat(results.detections().get(0).categories().get(0).categoryName())
.isNotEqualTo("cat");
}
@Test
public void detect_succeedsWithModelFileObject() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(),
TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      // Check that the object with the highest score is a cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@Test
public void detect_succeedsWithModelBuffer() throws Exception {
ObjectDetector objectDetector =
ObjectDetector.createFromBuffer(
ApplicationProvider.getApplicationContext(),
TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE));
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
      // Check that the object with the highest score is a cat.
assertIsCat(results.detections().get(0).categories().get(0), CAT_SCORE);
}
@Test
public void detect_succeedsWithModelBufferAndOptions() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(
BaseOptions.builder()
.setModelAssetBuffer(
TestUtils.loadToDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE))
.build())
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void create_failsWithMissingModel() throws Exception {
String nonexistentFile = "/path/to/non/existent/file";
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
ObjectDetector.createFromFile(
ApplicationProvider.getApplicationContext(), nonexistentFile));
assertThat(exception).hasMessageThat().contains(nonexistentFile);
}
@Test
public void create_failsWithInvalidModelBuffer() throws Exception {
// Create a non-direct model ByteBuffer.
ByteBuffer modelBuffer =
TestUtils.loadToNonDirectByteBuffer(
ApplicationProvider.getApplicationContext(), MODEL_FILE);
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ObjectDetector.createFromBuffer(
ApplicationProvider.getApplicationContext(), modelBuffer));
assertThat(exception)
.hasMessageThat()
.contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer.");
}
@Test
public void detect_failsWithBothAllowAndDenyListOption() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setCategoryAllowlist(Arrays.asList("cat"))
.setCategoryDenylist(Arrays.asList("dog"))
.build();
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() ->
ObjectDetector.createFromOptions(
ApplicationProvider.getApplicationContext(), options));
assertThat(exception)
.hasMessageThat()
.contains("`category_allowlist` and `category_denylist` are mutually exclusive options.");
}
// TODO: Implement detect_succeedsWithFloatImages, detect_succeedsWithOrientation,
// detect_succeedsWithNumThreads, detect_successWithNumThreadsFromBaseOptions,
// detect_failsWithInvalidNegativeNumThreads, detect_failsWithInvalidNumThreadsAsZero.
}
@RunWith(AndroidJUnit4.class)
public static final class RunningModeTest extends ObjectDetectorTest {
@Test
public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception {
for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(mode)
.setResultListener((objectDetectionResult, inputImage) -> {})
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener shouldn't be provided");
}
}
@Test
    public void create_failsWithMissingResultListenerInLiveStreamMode() throws Exception {
IllegalArgumentException exception =
assertThrows(
IllegalArgumentException.class,
() ->
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.build());
assertThat(exception)
.hasMessageThat()
.contains("a user-defined result listener must be provided");
}
@Test
public void detect_failsWithCallingWrongApiInImageMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
public void detect_failsWithCallingWrongApiInVideoMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectAsync(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode");
}
@Test
    public void detect_failsWithCallingWrongApiInLiveStreamMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener((objectDetectionResult, inputImage) -> {})
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
MediaPipeException exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE)));
assertThat(exception).hasMessageThat().contains("not initialized with the image mode");
exception =
assertThrows(
MediaPipeException.class,
() -> objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), 0));
assertThat(exception).hasMessageThat().contains("not initialized with the video mode");
}
@Test
public void detect_successWithImageMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.IMAGE)
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
ObjectDetectionResult results = objectDetector.detect(getImageFromAsset(CAT_AND_DOG_IMAGE));
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
@Test
public void detect_successWithVideoMode() throws Exception {
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.VIDEO)
.setMaxResults(1)
.build();
ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options);
for (int i = 0; i < 3; i++) {
ObjectDetectionResult results =
objectDetector.detectForVideo(getImageFromAsset(CAT_AND_DOG_IMAGE), i);
assertContainsOnlyCat(results, catBoundingBox, CAT_SCORE);
}
}
@Test
public void detect_failsWithOutOfOrderInputTimestamps() throws Exception {
Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
.build();
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
objectDetector.detectAsync(image, 1);
MediaPipeException exception =
assertThrows(MediaPipeException.class, () -> objectDetector.detectAsync(image, 0));
assertThat(exception)
.hasMessageThat()
.contains("having a smaller timestamp than the processed timestamp");
}
}
@Test
    public void detect_successWithLiveStreamMode() throws Exception {
Image image = getImageFromAsset(CAT_AND_DOG_IMAGE);
ObjectDetectorOptions options =
ObjectDetectorOptions.builder()
.setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build())
.setRunningMode(RunningMode.LIVE_STREAM)
.setResultListener(
(objectDetectionResult, inputImage) -> {
assertContainsOnlyCat(objectDetectionResult, catBoundingBox, CAT_SCORE);
assertImageSizeIsExpected(inputImage);
})
.setMaxResults(1)
.build();
try (ObjectDetector objectDetector =
ObjectDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) {
for (int i = 0; i < 3; i++) {
objectDetector.detectAsync(image, i);
}
}
}
}
private static Image getImageFromAsset(String filePath) throws Exception {
AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets();
InputStream istr = assetManager.open(filePath);
return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build();
}
// Checks if results has one and only detection result, which is a cat.
private static void assertContainsOnlyCat(
ObjectDetectionResult result, RectF expectedBoundingBox, float expectedScore) {
assertThat(result.detections()).hasSize(1);
Detection catResult = result.detections().get(0);
assertApproximatelyEqualBoundingBoxes(catResult.boundingBox(), expectedBoundingBox);
// We only support one category for each detected object at this point.
assertIsCat(catResult.categories().get(0), expectedScore);
}
private static void assertIsCat(Category category, float expectedScore) {
assertThat(category.categoryName()).isEqualTo("cat");
// coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite does not support label locale.
assertThat(category.displayName()).isEmpty();
assertThat((double) category.score()).isWithin(SCORE_DIFF_TOLERANCE).of(expectedScore);
assertThat(category.index()).isEqualTo(-1);
}
private static void assertApproximatelyEqualBoundingBoxes(
RectF boundingBox1, RectF boundingBox2) {
assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left);
assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top);
assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right);
assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom);
}
private static void assertImageSizeIsExpected(Image inputImage) {
assertThat(inputImage).isNotNull();
assertThat(inputImage.getWidth()).isEqualTo(IMAGE_WIDTH);
assertThat(inputImage.getHeight()).isEqualTo(IMAGE_HEIGHT);
}
}

View File

@ -1,4 +1,4 @@
load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library", "flatbuffer_py_library")
package(
    default_visibility = [
@ -14,3 +14,13 @@ flatbuffer_cc_library(
name = "metadata_schema_cc", name = "metadata_schema_cc",
srcs = ["metadata_schema.fbs"], srcs = ["metadata_schema.fbs"],
) )
flatbuffer_py_library(
name = "schema_py",
srcs = ["@org_tensorflow//tensorflow/lite/schema:schema.fbs"],
)
flatbuffer_py_library(
name = "metadata_schema_py",
srcs = ["metadata_schema.fbs"],
)

View File

@ -31,7 +31,7 @@ py_library(
name = "category", name = "category",
srcs = ["category.py"], srcs = ["category.py"],
deps = [ deps = [
"//mediapipe/tasks/cc/components/containers:category_py_pb2", "//mediapipe/tasks/cc/components/containers/proto:category_py_pb2",
"//mediapipe/tasks/python/core:optional_dependencies", "//mediapipe/tasks/python/core:optional_dependencies",
], ],
) )

View File

@ -16,7 +16,7 @@
import dataclasses
from typing import Any
from mediapipe.tasks.cc.components.containers import category_pb2 from mediapipe.tasks.cc.components.containers.proto import category_pb2
from mediapipe.tasks.python.core.optional_dependencies import doc_controls
_CategoryProto = category_pb2.Category

View File

@ -27,6 +27,7 @@ pybind_library(
"//mediapipe/framework/api2:builder", "//mediapipe/framework/api2:builder",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/python/pybind:util", "//mediapipe/python/pybind:util",
"//mediapipe/tasks/cc/core:mediapipe_builtin_op_resolver",
"//mediapipe/tasks/cc/core:task_runner", "//mediapipe/tasks/cc/core:task_runner",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver", "@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster",

View File

@ -16,6 +16,7 @@
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/python/pybind/util.h" #include "mediapipe/python/pybind/util.h"
#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h"
#include "mediapipe/tasks/cc/core/task_runner.h" #include "mediapipe/tasks/cc/core/task_runner.h"
#include "pybind11/stl.h" #include "pybind11/stl.h"
#include "pybind11_protobuf/native_proto_caster.h" #include "pybind11_protobuf/native_proto_caster.h"
@ -75,7 +76,7 @@ mode) or not (synchronous mode).)doc");
          }
          auto task_runner = TaskRunner::Create(
              std::move(graph_config),
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(), absl::make_unique<core::MediaPipeBuiltinOpResolver>(),
              std::move(callback));
          RaisePyErrorIfNotOk(task_runner.status());
          return std::move(*task_runner);

View File

@ -0,0 +1,38 @@
load("//mediapipe/tasks/metadata:build_defs.bzl", "stamp_metadata_parser_version")
package(
licenses = ["notice"], # Apache 2.0
)
stamp_metadata_parser_version(
name = "metadata_parser_py",
srcs = ["metadata_parser.py.template"],
outs = ["metadata_parser.py"],
)
py_library(
name = "metadata",
srcs = [
"metadata.py",
":metadata_parser_py",
],
data = ["//mediapipe/tasks/metadata:metadata_schema.fbs"],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
"//mediapipe/tasks/cc/metadata/python:_pywrap_metadata_version",
"//mediapipe/tasks/metadata:metadata_schema_py",
"//mediapipe/tasks/metadata:schema_py",
"//mediapipe/tasks/python/metadata/flatbuffers_lib:_pywrap_flatbuffers",
"@flatbuffers//:runtime_py",
],
)
py_binary(
name = "metadata_displayer_cli",
srcs = ["metadata_displayer_cli.py"],
visibility = [
"//visibility:public",
],
deps = [":metadata"],
)

View File

@ -0,0 +1,13 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,20 @@
load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension")
package(
default_visibility = ["//mediapipe/tasks:internal"],
licenses = ["notice"], # Apache 2.0
)
pybind_extension(
name = "_pywrap_flatbuffers",
srcs = [
"flatbuffers_lib.cc",
],
features = ["-use_header_modules"],
module_name = "_pywrap_flatbuffers",
deps = [
"@flatbuffers",
"@local_config_python//:python_headers",
"@pybind11",
],
)

View File

@ -0,0 +1,59 @@
/* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/idl.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
namespace tflite {
namespace support {
PYBIND11_MODULE(_pywrap_flatbuffers, m) {
pybind11::class_<flatbuffers::IDLOptions>(m, "IDLOptions")
.def(pybind11::init<>())
.def_readwrite("strict_json", &flatbuffers::IDLOptions::strict_json);
pybind11::class_<flatbuffers::Parser>(m, "Parser")
.def(pybind11::init<const flatbuffers::IDLOptions&>())
.def("parse",
[](flatbuffers::Parser* self, const std::string& source) {
return self->Parse(source.c_str());
})
.def_readonly("builder", &flatbuffers::Parser::builder_)
.def_readonly("error", &flatbuffers::Parser::error_);
pybind11::class_<flatbuffers::FlatBufferBuilder>(m, "FlatBufferBuilder")
.def("clear", &flatbuffers::FlatBufferBuilder::Clear)
.def("push_flat_buffer", [](flatbuffers::FlatBufferBuilder* self,
const std::string& contents) {
self->PushFlatBuffer(reinterpret_cast<const uint8_t*>(contents.c_str()),
contents.length());
});
m.def("generate_text_file", &flatbuffers::GenerateTextFile);
m.def(
"generate_text",
[](const flatbuffers::Parser& parser,
const std::string& buffer) -> std::string {
std::string text;
if (!flatbuffers::GenerateText(
parser, reinterpret_cast<const void*>(buffer.c_str()), &text)) {
return "";
}
return text;
});
}
} // namespace support
} // namespace tflite

View File

@ -0,0 +1,865 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Lite metadata tools."""
import copy
import inspect
import io
import os
import shutil
import sys
import tempfile
import warnings
import zipfile
import flatbuffers
from mediapipe.tasks.cc.metadata.python import _pywrap_metadata_version
from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb
from mediapipe.tasks.metadata import schema_py_generated as _schema_fb
from mediapipe.tasks.python.metadata.flatbuffers_lib import _pywrap_flatbuffers
try:
# If exists, optionally use TensorFlow to open and check files. Used to
# support more than local file systems.
  # In pip requirements, we don't necessarily need tensorflow as a dep.
import tensorflow as tf
_open_file = tf.io.gfile.GFile
_exists_file = tf.io.gfile.exists
except ImportError as e:
# If TensorFlow package doesn't exist, fall back to original open and exists.
_open_file = open
_exists_file = os.path.exists
def _maybe_open_as_binary(filename, mode):
"""Maybe open the binary file, and returns a file-like."""
if hasattr(filename, "read"): # A file-like has read().
return filename
openmode = mode if "b" in mode else mode + "b" # Add binary explicitly.
return _open_file(filename, openmode)
def _open_as_zipfile(filename, mode="r"):
"""Open file as a zipfile.
Args:
filename: str or file-like or path-like, to the zipfile.
mode: str, common file mode for zip.
(See: https://docs.python.org/3/library/zipfile.html)
Returns:
A ZipFile object.
"""
file_like = _maybe_open_as_binary(filename, mode)
return zipfile.ZipFile(file_like, mode)
def _is_zipfile(filename):
"""Checks whether it is a zipfile."""
with _maybe_open_as_binary(filename, "r") as f:
return zipfile.is_zipfile(f)
def get_path_to_datafile(path):
"""Gets the path to the specified file in the data dependencies.
The path is relative to the file calling the function.
It's a simple replacement of
"tensorflow.python.platform.resource_loader.get_path_to_datafile".
Args:
path: a string resource path relative to the calling file.
Returns:
The path to the specified file present in the data attribute of py_test
or py_binary.
"""
data_files_path = os.path.dirname(inspect.getfile(sys._getframe(1))) # pylint: disable=protected-access
return os.path.join(data_files_path, path)
_FLATC_TFLITE_METADATA_SCHEMA_FILE = get_path_to_datafile(
"../../metadata/metadata_schema.fbs")
# TODO: add delete method for associated files.
class MetadataPopulator(object):
"""Packs metadata and associated files into TensorFlow Lite model file.
MetadataPopulator can be used to populate metadata and model associated files
into a model file or a model buffer (in bytearray). It can also help to
  inspect the list of files that have been packed into the model or are supposed to
be packed into the model.
The metadata file (or buffer) should be generated based on the metadata
schema:
third_party/tensorflow/lite/schema/metadata_schema.fbs
Example usage:
  Populate metadata and a label file into an image classifier model.
  First, based on metadata_schema.fbs, generate the metadata for this image
  classifier model using the Flatbuffers API. Attach the label file onto the output
tensor (the tensor of probabilities) in the metadata.
Then, pack the metadata and label file into the model as follows.
```python
  # Populating a metadata file (or a metadata buffer) and associated files to
a model file:
populator = MetadataPopulator.with_model_file(model_file)
# For metadata buffer (bytearray read from the metadata file), use:
# populator.load_metadata_buffer(metadata_buf)
populator.load_metadata_file(metadata_file)
populator.load_associated_files([label.txt])
# For associated file buffer (bytearray read from the file), use:
# populator.load_associated_file_buffers({"label.txt": b"file content"})
populator.populate()
# Populating a metadata file (or a metadata buffer) and associated files to
a model buffer:
populator = MetadataPopulator.with_model_buffer(model_buf)
populator.load_metadata_file(metadata_file)
populator.load_associated_files([label.txt])
populator.populate()
# Writing the updated model buffer into a file.
updated_model_buf = populator.get_model_buffer()
with open("updated_model.tflite", "wb") as f:
f.write(updated_model_buf)
# Transferring metadata and associated files from another TFLite model:
  populator = MetadataPopulator.with_model_buffer(model_buf)
  populator.load_metadata_and_associated_files(src_model_buf)
  populator.populate()
  updated_model_buf = populator.get_model_buffer()
with open("updated_model.tflite", "wb") as f:
f.write(updated_model_buf)
```
Note that existing metadata buffer (if applied) will be overridden by the new
metadata buffer.
"""
# As Zip API is used to concatenate associated files after tflite model file,
# the populating operation is developed based on a model file. For in-memory
# model buffer, we create a tempfile to serve the populating operation.
  # Creating and deleting such a tempfile is handled by the class,
# _MetadataPopulatorWithBuffer.
METADATA_FIELD_NAME = "TFLITE_METADATA"
TFLITE_FILE_IDENTIFIER = b"TFL3"
METADATA_FILE_IDENTIFIER = b"M001"
def __init__(self, model_file):
"""Constructor for MetadataPopulator.
Args:
model_file: valid path to a TensorFlow Lite model file.
Raises:
IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
"""
_assert_model_file_identifier(model_file)
self._model_file = model_file
self._metadata_buf = None
# _associated_files is a dict of file name and file buffer.
self._associated_files = {}
@classmethod
def with_model_file(cls, model_file):
"""Creates a MetadataPopulator object that populates data to a model file.
Args:
model_file: valid path to a TensorFlow Lite model file.
Returns:
MetadataPopulator object.
Raises:
IOError: File not found.
      ValueError: the model does not have the expected flatbuffer identifier.
"""
return cls(model_file)
# TODO: investigate if type check can be applied to model_buf for
# FB.
@classmethod
def with_model_buffer(cls, model_buf):
"""Creates a MetadataPopulator object that populates data to a model buffer.
Args:
model_buf: TensorFlow Lite model buffer in bytearray.
Returns:
A MetadataPopulator(_MetadataPopulatorWithBuffer) object.
Raises:
      ValueError: the model does not have the expected flatbuffer identifier.
"""
return _MetadataPopulatorWithBuffer(model_buf)
def get_model_buffer(self):
"""Gets the buffer of the model with packed metadata and associated files.
Returns:
Model buffer (in bytearray).
"""
with _open_file(self._model_file, "rb") as f:
return f.read()
def get_packed_associated_file_list(self):
"""Gets a list of associated files packed to the model file.
Returns:
List of packed associated files.
"""
if not _is_zipfile(self._model_file):
return []
with _open_as_zipfile(self._model_file, "r") as zf:
return zf.namelist()
def get_recorded_associated_file_list(self):
"""Gets a list of associated files recorded in metadata of the model file.
Associated files may be attached to a model, a subgraph, or an input/output
tensor.
Returns:
List of recorded associated files.
"""
if not self._metadata_buf:
return []
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(
self._metadata_buf, 0))
return [
file.name.decode("utf-8")
for file in self._get_recorded_associated_file_object_list(metadata)
]
def load_associated_file_buffers(self, associated_files):
"""Loads the associated file buffers (in bytearray) to be populated.
Args:
associated_files: a dictionary of associated file names and corresponding
        file buffers, such as {"file.txt": b"file content"}. If file paths are
        passed in as the file names, only the basenames will be populated.
"""
self._associated_files.update({
os.path.basename(name): buffers
for name, buffers in associated_files.items()
})
def load_associated_files(self, associated_files):
"""Loads associated files that to be concatenated after the model file.
Args:
associated_files: list of file paths.
Raises:
IOError:
File not found.
"""
for af_name in associated_files:
_assert_file_exist(af_name)
with _open_file(af_name, "rb") as af:
self.load_associated_file_buffers({af_name: af.read()})
def load_metadata_buffer(self, metadata_buf):
"""Loads the metadata buffer (in bytearray) to be populated.
Args:
metadata_buf: metadata buffer (in bytearray) to be populated.
Raises:
ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
ValueError: Cannot get minimum metadata parser version.
ValueError: The number of SubgraphMetadata is not 1.
ValueError: The number of input/output tensors does not match the number
of input/output tensor metadata.
"""
if not metadata_buf:
raise ValueError("The metadata to be populated is empty.")
self._validate_metadata(metadata_buf)
# Gets the minimum metadata parser version of the metadata_buf.
min_version = _pywrap_metadata_version.GetMinimumMetadataParserVersion(
bytes(metadata_buf))
    # Inserts the minimum metadata parser version into the metadata_buf.
metadata = _metadata_fb.ModelMetadataT.InitFromObj(
_metadata_fb.ModelMetadata.GetRootAsModelMetadata(metadata_buf, 0))
metadata.minParserVersion = min_version
# Remove local file directory in the `name` field of `AssociatedFileT`, and
# make it consistent with the name of the actual file packed in the model.
self._use_basename_for_associated_files_in_metadata(metadata)
b = flatbuffers.Builder(0)
b.Finish(metadata.Pack(b), self.METADATA_FILE_IDENTIFIER)
metadata_buf_with_version = b.Output()
self._metadata_buf = metadata_buf_with_version
def load_metadata_file(self, metadata_file):
"""Loads the metadata file to be populated.
Args:
metadata_file: path to the metadata file to be populated.
Raises:
IOError: File not found.
ValueError: The metadata to be populated is empty.
      ValueError: The metadata does not have the expected flatbuffer identifier.
ValueError: Cannot get minimum metadata parser version.
ValueError: The number of SubgraphMetadata is not 1.
ValueError: The number of input/output tensors does not match the number
of input/output tensor metadata.
"""
_assert_file_exist(metadata_file)
with _open_file(metadata_file, "rb") as f:
metadata_buf = f.read()
self.load_metadata_buffer(bytearray(metadata_buf))
def load_metadata_and_associated_files(self, src_model_buf):
"""Loads the metadata and associated files from another model buffer.
Args:
src_model_buf: source model buffer (in bytearray) with metadata and
associated files.
"""
    # Load the model metadata from src_model_buf if it exists.
metadata_buffer = get_metadata_buffer(src_model_buf)
if metadata_buffer:
self.load_metadata_buffer(metadata_buffer)
    # Load the associated files from src_model_buf if they exist.
if _is_zipfile(io.BytesIO(src_model_buf)):
with _open_as_zipfile(io.BytesIO(src_model_buf)) as zf:
self.load_associated_file_buffers(
{f: zf.read(f) for f in zf.namelist()})
def populate(self):
"""Populates loaded metadata and associated files into the model file."""
self._assert_validate()
self._populate_metadata_buffer()
self._populate_associated_files()
def _assert_validate(self):
"""Validates the metadata and associated files to be populated.
Raises:
ValueError:
File is recorded in the metadata, but is not going to be populated.
File has already been packed.
"""
# Gets files that are recorded in metadata.
recorded_files = self.get_recorded_associated_file_list()
# Gets files that have been packed to self._model_file.
packed_files = self.get_packed_associated_file_list()
# Gets the file name of those associated files to be populated.
to_be_populated_files = self._associated_files.keys()
    # Checks that all files recorded in the metadata will be populated.
for rf in recorded_files:
if rf not in to_be_populated_files and rf not in packed_files:
raise ValueError("File, '{0}', is recorded in the metadata, but has "
"not been loaded into the populator.".format(rf))
for f in to_be_populated_files:
if f in packed_files:
raise ValueError("File, '{0}', has already been packed.".format(f))
if f not in recorded_files:
warnings.warn(
"File, '{0}', does not exist in the metadata. But packing it to "
"tflite model is still allowed.".format(f))
def _copy_archived_files(self, src_zip, file_list, dst_zip):
"""Copy archieved files in file_list from src_zip ro dst_zip."""
if not _is_zipfile(src_zip):
raise ValueError("File, '{0}', is not a zipfile.".format(src_zip))
with _open_as_zipfile(src_zip, "r") as src_zf, \
_open_as_zipfile(dst_zip, "a") as dst_zf:
src_list = src_zf.namelist()
for f in file_list:
if f not in src_list:
raise ValueError(
"File, '{0}', does not exist in the zipfile, {1}.".format(
f, src_zip))
file_buffer = src_zf.read(f)
dst_zf.writestr(f, file_buffer)
def _get_associated_files_from_process_units(self, table, field_name):
"""Gets the files that are attached the process units field of a table.
Args:
table: a Flatbuffers table object that contains fields of an array of
ProcessUnit, such as TensorMetadata and SubGraphMetadata.
field_name: the name of the field in the table that represents an array of
ProcessUnit. If the table is TensorMetadata, field_name can be
"ProcessUnits". If the table is SubGraphMetadata, field_name can be
either "InputProcessUnits" or "OutputProcessUnits".
Returns:
A list of AssociatedFileT objects.
"""
if table is None:
return []
file_list = []
process_units = getattr(table, field_name)
# If the process_units field is not populated, it will be None. Use an
# empty list to skip the check.
for process_unit in process_units or []:
options = process_unit.options
if isinstance(options, (_metadata_fb.BertTokenizerOptionsT,
_metadata_fb.RegexTokenizerOptionsT)):
file_list += self._get_associated_files_from_table(options, "vocabFile")
elif isinstance(options, _metadata_fb.SentencePieceTokenizerOptionsT):
file_list += self._get_associated_files_from_table(
options, "sentencePieceModel")
file_list += self._get_associated_files_from_table(options, "vocabFile")
return file_list
def _get_associated_files_from_table(self, table, field_name):
"""Gets the associated files that are attached a table directly.
Args:
table: a Flatbuffers table object that contains fields of an array of
AssociatedFile, such as TensorMetadata and BertTokenizerOptions.
field_name: the name of the field in the table that represents an array of
        AssociatedFile. If the table is TensorMetadata, field_name can be
"AssociatedFiles". If the table is BertTokenizerOptions, field_name can
be "VocabFile".
Returns:
A list of AssociatedFileT objects.
"""
if table is None:
return []
# If the associated file field is not populated,
# `getattr(table, field_name)` will be None. Return an empty list.
return getattr(table, field_name) or []
def _get_recorded_associated_file_object_list(self, metadata):
"""Gets a list of AssociatedFileT objects recorded in the metadata.
Associated files may be attached to a model, a subgraph, or an input/output
tensor.
Args:
metadata: the ModelMetadataT object.
Returns:
List of recorded AssociatedFileT objects.
"""
recorded_files = []
# Add associated files attached to ModelMetadata.
recorded_files += self._get_associated_files_from_table(
metadata, "associatedFiles")
# Add associated files attached to each SubgraphMetadata.
for subgraph in metadata.subgraphMetadata or []:
recorded_files += self._get_associated_files_from_table(
subgraph, "associatedFiles")
# Add associated files attached to each input tensor.
for tensor_metadata in subgraph.inputTensorMetadata or []:
recorded_files += self._get_associated_files_from_table(
tensor_metadata, "associatedFiles")
recorded_files += self._get_associated_files_from_process_units(
tensor_metadata, "processUnits")
# Add associated files attached to each output tensor.
for tensor_metadata in subgraph.outputTensorMetadata or []:
recorded_files += self._get_associated_files_from_table(
tensor_metadata, "associatedFiles")
recorded_files += self._get_associated_files_from_process_units(
tensor_metadata, "processUnits")
# Add associated files attached to the input_process_units.
recorded_files += self._get_associated_files_from_process_units(
subgraph, "inputProcessUnits")
# Add associated files attached to the output_process_units.
recorded_files += self._get_associated_files_from_process_units(
subgraph, "outputProcessUnits")
return recorded_files
def _populate_associated_files(self):
"""Concatenates associated files after TensorFlow Lite model file.
If the MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
"""
# Opens up the model file in "appending" mode.
# If self._model_file already has pack files, zipfile will concatenate
# addition files after self._model_file. For example, suppose we have
# self._model_file = old_tflite_file | label1.txt | label2.txt
# Then after trigger populate() to add label3.txt, self._model_file becomes
# self._model_file = old_tflite_file | label1.txt | label2.txt | label3.txt
with tempfile.SpooledTemporaryFile() as temp:
      # (1) Copy content from the model file to the temp file.
with _open_file(self._model_file, "rb") as f:
shutil.copyfileobj(f, temp)
      # (2) Append the associated files to the temp file as a zip.
with _open_as_zipfile(temp, "a") as zf:
for file_name, file_buffer in self._associated_files.items():
zf.writestr(file_name, file_buffer)
# (3) Copy temp file to model file.
temp.seek(0)
with _open_file(self._model_file, "wb") as f:
shutil.copyfileobj(temp, f)
def _populate_metadata_buffer(self):
"""Populates the metadata buffer (in bytearray) into the model file.
Inserts metadata_buf into the metadata field of schema.Model. If the
MetadataPopulator object is created using the method,
with_model_file(model_file), the model file will be updated.
Existing metadata buffer (if applied) will be overridden by the new metadata
buffer.
"""
with _open_file(self._model_file, "rb") as f:
model_buf = f.read()
model = _schema_fb.ModelT.InitFromObj(
_schema_fb.Model.GetRootAsModel(model_buf, 0))
buffer_field = _schema_fb.BufferT()
buffer_field.data = self._metadata_buf
is_populated = False
if not model.metadata:
model.metadata = []
else:
# Check if metadata has already been populated.
for meta in model.metadata:
if meta.name.decode("utf-8") == self.METADATA_FIELD_NAME:
is_populated = True
model.buffers[meta.buffer] = buffer_field
if not is_populated:
if not model.buffers:
model.buffers = []
model.buffers.append(buffer_field)
# Creates a new metadata field.
metadata_field = _schema_fb.MetadataT()
metadata_field.name = self.METADATA_FIELD_NAME
metadata_field.buffer = len(model.buffers) - 1
model.metadata.append(metadata_field)
    # Packs the model back into a flatbuffer binary file.
b = flatbuffers.Builder(0)
b.Finish(model.Pack(b), self.TFLITE_FILE_IDENTIFIER)
model_buf = b.Output()
# Saves the updated model buffer to model file.
# Gets files that have been packed to self._model_file.
packed_files = self.get_packed_associated_file_list()
if packed_files:
# Writes the updated model buffer and associated files into a new model
# file (in memory). Then overwrites the original model file.
with tempfile.SpooledTemporaryFile() as temp:
temp.write(model_buf)
self._copy_archived_files(self._model_file, packed_files, temp)
temp.seek(0)
with _open_file(self._model_file, "wb") as f:
shutil.copyfileobj(temp, f)
else:
with _open_file(self._model_file, "wb") as f:
f.write(model_buf)
def _use_basename_for_associated_files_in_metadata(self, metadata):
"""Removes any associated file local directory (if exists)."""
for file in self._get_recorded_associated_file_object_list(metadata):
file.name = os.path.basename(file.name)
def _validate_metadata(self, metadata_buf):
"""Validates the metadata to be populated."""
_assert_metadata_buffer_identifier(metadata_buf)
# Verify the number of SubgraphMetadata is exactly one.
    # TFLite currently only supports one subgraph.
model_meta = _metadata_fb.ModelMetadata.GetRootAsModelMetadata(
metadata_buf, 0)
if model_meta.SubgraphMetadataLength() != 1:
raise ValueError("The number of SubgraphMetadata should be exactly one, "
"but got {0}.".format(
model_meta.SubgraphMetadataLength()))
# Verify if the number of tensor metadata matches the number of tensors.
with _open_file(self._model_file, "rb") as f:
model_buf = f.read()
model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
num_input_tensors = model.Subgraphs(0).InputsLength()
num_input_meta = model_meta.SubgraphMetadata(0).InputTensorMetadataLength()
if num_input_tensors != num_input_meta:
raise ValueError(
"The number of input tensors ({0}) should match the number of "
"input tensor metadata ({1})".format(num_input_tensors,
num_input_meta))
num_output_tensors = model.Subgraphs(0).OutputsLength()
num_output_meta = model_meta.SubgraphMetadata(
0).OutputTensorMetadataLength()
if num_output_tensors != num_output_meta:
raise ValueError(
"The number of output tensors ({0}) should match the number of "
"output tensor metadata ({1})".format(num_output_tensors,
num_output_meta))
class _MetadataPopulatorWithBuffer(MetadataPopulator):
"""Subclass of MetadtaPopulator that populates metadata to a model buffer.
This class is used to populate metadata into a in-memory model buffer. As we
use Zip API to concatenate associated files after tflite model file, the
populating operation is developed based on a model file. For in-memory model
buffer, we create a tempfile to serve the populating operation. This class is
then used to generate this tempfile, and delete the file when the
MetadataPopulator object is deleted.
"""
def __init__(self, model_buf):
"""Constructor for _MetadataPopulatorWithBuffer.
Args:
model_buf: TensorFlow Lite model buffer in bytearray.
Raises:
ValueError: model_buf is empty.
      ValueError: model_buf does not have the expected flatbuffer identifier.
"""
if not model_buf:
raise ValueError("model_buf cannot be empty.")
with tempfile.NamedTemporaryFile() as temp:
model_file = temp.name
with _open_file(model_file, "wb") as f:
f.write(model_buf)
super().__init__(model_file)
def __del__(self):
"""Destructor of _MetadataPopulatorWithBuffer.
Deletes the tempfile.
"""
if os.path.exists(self._model_file):
os.remove(self._model_file)
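# A minimal usage sketch (illustrative only, not part of this module). The
# loader/populate methods below come from the MetadataPopulator API defined
# earlier in this file and are not shown in this excerpt; model_buf,
# metadata_buf, and "labels.txt" are placeholder inputs.
#
#   populator = _MetadataPopulatorWithBuffer(model_buf)
#   populator.load_metadata_buffer(metadata_buf)
#   populator.load_associated_files(["labels.txt"])
#   populator.populate()
#   populated_model_buf = populator.get_model_buffer()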
class MetadataDisplayer(object):
"""Displays metadata and associated file info in human-readable format."""
def __init__(self, model_buffer, metadata_buffer, associated_file_list):
"""Constructor for MetadataDisplayer.
Args:
model_buffer: valid buffer of the model file.
metadata_buffer: valid buffer of the metadata file.
      associated_file_list: list of associated files in the model file.
"""
_assert_model_buffer_identifier(model_buffer)
_assert_metadata_buffer_identifier(metadata_buffer)
self._model_buffer = model_buffer
self._metadata_buffer = metadata_buffer
self._associated_file_list = associated_file_list
@classmethod
def with_model_file(cls, model_file):
"""Creates a MetadataDisplayer object for the model file.
Args:
model_file: valid path to a TensorFlow Lite model file.
Returns:
MetadataDisplayer object.
Raises:
IOError: File not found.
ValueError: The model does not have metadata.
"""
_assert_file_exist(model_file)
with _open_file(model_file, "rb") as f:
return cls.with_model_buffer(f.read())
@classmethod
def with_model_buffer(cls, model_buffer):
"""Creates a MetadataDisplayer object for a file buffer.
Args:
model_buffer: TensorFlow Lite model buffer in bytearray.
Returns:
      MetadataDisplayer object.

    Raises:
      ValueError: model_buffer is empty or the model does not have metadata.
    """
if not model_buffer:
raise ValueError("model_buffer cannot be empty.")
metadata_buffer = get_metadata_buffer(model_buffer)
if not metadata_buffer:
raise ValueError("The model does not have metadata.")
    associated_file_list = cls._parse_packed_associated_file_list(model_buffer)
return cls(model_buffer, metadata_buffer, associated_file_list)
def get_associated_file_buffer(self, filename):
"""Get the specified associated file content in bytearray.
Args:
filename: name of the file to be extracted.
Returns:
The file content in bytearray.
Raises:
ValueError: if the file does not exist in the model.
"""
if filename not in self._associated_file_list:
raise ValueError(
"The file, {}, does not exist in the model.".format(filename))
with _open_as_zipfile(io.BytesIO(self._model_buffer)) as zf:
return zf.read(filename)
def get_metadata_buffer(self):
"""Get the metadata buffer in bytearray out from the model."""
return copy.deepcopy(self._metadata_buffer)
def get_metadata_json(self):
"""Converts the metadata into a json string."""
return convert_to_json(self._metadata_buffer)
def get_packed_associated_file_list(self):
"""Returns a list of associated files that are packed in the model.
Returns:
A name list of associated files.
"""
return copy.deepcopy(self._associated_file_list)
@staticmethod
  def _parse_packed_associated_file_list(model_buf):
    """Gets a list of associated files packed into the model file.
Args:
model_buf: valid file buffer.
Returns:
List of packed associated files.
"""
try:
with _open_as_zipfile(io.BytesIO(model_buf)) as zf:
return zf.namelist()
except zipfile.BadZipFile:
return []
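# A minimal usage sketch for MetadataDisplayer (illustrative only, not part of
# this module); the model path is hypothetical.
#
#   displayer = MetadataDisplayer.with_model_file("/tmp/model.tflite")
#   print(displayer.get_metadata_json())
#   for name in displayer.get_packed_associated_file_list():
#     file_content = displayer.get_associated_file_buffer(name)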
# Creates an individual function for getting the metadata json file, so that it
# can be used as a standalone util.
def convert_to_json(metadata_buffer):
"""Converts the metadata into a json string.
Args:
metadata_buffer: valid metadata buffer in bytes.
Returns:
Metadata in JSON format.
Raises:
    ValueError: error occurred when parsing the metadata schema file.
"""
opt = _pywrap_flatbuffers.IDLOptions()
opt.strict_json = True
parser = _pywrap_flatbuffers.Parser(opt)
with _open_file(_FLATC_TFLITE_METADATA_SCHEMA_FILE) as f:
metadata_schema_content = f.read()
if not parser.parse(metadata_schema_content):
raise ValueError("Cannot parse metadata schema. Reason: " + parser.error)
return _pywrap_flatbuffers.generate_text(parser, metadata_buffer)
def _assert_file_exist(filename):
"""Checks if a file exists."""
if not _exists_file(filename):
raise IOError("File, '{0}', does not exist.".format(filename))
def _assert_model_file_identifier(model_file):
"""Checks if a model file has the expected TFLite schema identifier."""
_assert_file_exist(model_file)
with _open_file(model_file, "rb") as f:
_assert_model_buffer_identifier(f.read())
def _assert_model_buffer_identifier(model_buf):
if not _schema_fb.Model.ModelBufferHasIdentifier(model_buf, 0):
raise ValueError(
"The model provided does not have the expected identifier, and "
"may not be a valid TFLite model.")
def _assert_metadata_buffer_identifier(metadata_buf):
"""Checks if a metadata buffer has the expected Metadata schema identifier."""
if not _metadata_fb.ModelMetadata.ModelMetadataBufferHasIdentifier(
metadata_buf, 0):
raise ValueError(
"The metadata buffer does not have the expected identifier, and may not"
" be a valid TFLite Metadata.")
def get_metadata_buffer(model_buf):
"""Returns the metadata in the model file as a buffer.
Args:
model_buf: valid buffer of the model file.
Returns:
Metadata buffer. Returns `None` if the model does not have metadata.
"""
tflite_model = _schema_fb.Model.GetRootAsModel(model_buf, 0)
# Gets metadata from the model file.
for i in range(tflite_model.MetadataLength()):
meta = tflite_model.Metadata(i)
if meta.Name().decode("utf-8") == MetadataPopulator.METADATA_FIELD_NAME:
buffer_index = meta.Buffer()
metadata = tflite_model.Buffers(buffer_index)
return metadata.DataAsNumpy().tobytes()
return None
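# A minimal sketch (illustrative only) combining get_metadata_buffer and
# convert_to_json on a model read from disk; the path is hypothetical.
#
#   with _open_file("/tmp/model.tflite", "rb") as f:
#     model_buf = f.read()
#   metadata_buf = get_metadata_buffer(model_buf)
#   if metadata_buf:
#     print(convert_to_json(metadata_buf))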

View File

@ -0,0 +1,34 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLI tool for display metadata."""
from absl import app
from absl import flags
from mediapipe.tasks.python.metadata import metadata
FLAGS = flags.FLAGS
flags.DEFINE_string('model_path', None, 'Path to the TFLite model file.')
flags.DEFINE_string('export_json_path', None, 'Path to the output JSON file.')
def main(_):
displayer = metadata.MetadataDisplayer.with_model_file(FLAGS.model_path)
with open(FLAGS.export_json_path, 'w') as f:
f.write(displayer.get_metadata_json())
if __name__ == '__main__':
app.run(main)
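# Example invocation (illustrative only; the script name and file paths are
# hypothetical):
#
#   python3 display_metadata.py \
#     --model_path=/tmp/model.tflite \
#     --export_json_path=/tmp/metadata.json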

Some files were not shown because too many files have changed in this diff.